|  | @@ -1,9 +1,23 @@
 | 
	
		
			
				|  |  |  from collections.abc import Generator
 | 
	
		
			
				|  |  | +from decimal import Decimal
 | 
	
		
			
				|  |  |  from typing import Optional, Union
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -from core.model_runtime.entities.llm_entities import LLMResult
 | 
	
		
			
				|  |  | -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
 | 
	
		
			
				|  |  | -from core.model_runtime.entities.model_entities import AIModelEntity
 | 
	
		
			
				|  |  | +from core.model_runtime.entities.common_entities import I18nObject
 | 
	
		
			
				|  |  | +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
 | 
	
		
			
				|  |  | +from core.model_runtime.entities.message_entities import (
 | 
	
		
			
				|  |  | +    PromptMessage,
 | 
	
		
			
				|  |  | +    PromptMessageTool,
 | 
	
		
			
				|  |  | +)
 | 
	
		
			
				|  |  | +from core.model_runtime.entities.model_entities import (
 | 
	
		
			
				|  |  | +    AIModelEntity,
 | 
	
		
			
				|  |  | +    DefaultParameterName,
 | 
	
		
			
				|  |  | +    FetchFrom,
 | 
	
		
			
				|  |  | +    ModelPropertyKey,
 | 
	
		
			
				|  |  | +    ModelType,
 | 
	
		
			
				|  |  | +    ParameterRule,
 | 
	
		
			
				|  |  | +    ParameterType,
 | 
	
		
			
				|  |  | +    PriceConfig,
 | 
	
		
			
				|  |  | +)
 | 
	
		
			
				|  |  |  from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -36,8 +50,98 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
        """
        Build the customizable model schema for a TogetherAI-hosted model.

        :param model: model name as configured by the user
        :param credentials: model credentials; keys read here include
            context_size, mode, temperature, top_p, top_k, repetition_penalty,
            max_tokens_to_sample, frequency_penalty, presence_penalty,
            input_price, output_price, unit, currency
        :return: AIModelEntity describing the model, its properties,
            tunable parameter rules and pricing
        :raises ValueError: if credentials 'mode' is neither 'chat' nor 'completion'
        """
        cred_with_endpoint = self._update_endpoint_url(credentials=credentials)

        REPETITION_PENALTY = "repetition_penalty"
        TOP_K = "top_k"
        features = []

        # Resolve and validate the completion mode up front so bad credentials
        # fail fast, before any entity is constructed.
        # NOTE: the previous error path read cred_with_endpoint['completion_type'],
        # a key that is never set, so it raised KeyError instead of the intended
        # ValueError; the message now reports the actual offending 'mode' value.
        mode = cred_with_endpoint.get('mode')
        if mode == 'chat':
            llm_mode = LLMMode.CHAT.value
        elif mode == 'completion':
            llm_mode = LLMMode.COMPLETION.value
        else:
            raise ValueError(f"Unknown completion type {mode}")

        entity = AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            model_type=ModelType.LLM,
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            features=features,
            model_properties={
                ModelPropertyKey.CONTEXT_SIZE: int(cred_with_endpoint.get('context_size', "4096")),
                ModelPropertyKey.MODE: llm_mode,
            },
            parameter_rules=[
                ParameterRule(
                    name=DefaultParameterName.TEMPERATURE.value,
                    label=I18nObject(en_US="Temperature"),
                    type=ParameterType.FLOAT,
                    default=float(cred_with_endpoint.get('temperature', 0.7)),
                    min=0,
                    max=2,
                    precision=2
                ),
                ParameterRule(
                    name=DefaultParameterName.TOP_P.value,
                    label=I18nObject(en_US="Top P"),
                    type=ParameterType.FLOAT,
                    default=float(cred_with_endpoint.get('top_p', 1)),
                    min=0,
                    max=1,
                    precision=2
                ),
                ParameterRule(
                    name=TOP_K,
                    label=I18nObject(en_US="Top K"),
                    type=ParameterType.INT,
                    default=int(cred_with_endpoint.get('top_k', 50)),
                    min=-2147483647,
                    max=2147483647,
                    precision=0
                ),
                ParameterRule(
                    name=REPETITION_PENALTY,
                    label=I18nObject(en_US="Repetition Penalty"),
                    type=ParameterType.FLOAT,
                    default=float(cred_with_endpoint.get('repetition_penalty', 1)),
                    min=-3.4,
                    max=3.4,
                    precision=1
                ),
                ParameterRule(
                    name=DefaultParameterName.MAX_TOKENS.value,
                    label=I18nObject(en_US="Max Tokens"),
                    type=ParameterType.INT,
                    default=512,
                    min=1,
                    max=int(cred_with_endpoint.get('max_tokens_to_sample', 4096)),
                ),
                ParameterRule(
                    name=DefaultParameterName.FREQUENCY_PENALTY.value,
                    label=I18nObject(en_US="Frequency Penalty"),
                    type=ParameterType.FLOAT,
                    # read from cred_with_endpoint for consistency with the other
                    # rules (previously read the raw `credentials` dict)
                    default=float(cred_with_endpoint.get('frequency_penalty', 0)),
                    min=-2,
                    max=2
                ),
                ParameterRule(
                    name=DefaultParameterName.PRESENCE_PENALTY.value,
                    label=I18nObject(en_US="Presence Penalty"),
                    type=ParameterType.FLOAT,
                    default=float(cred_with_endpoint.get('presence_penalty', 0)),
                    min=-2,
                    max=2
                )
            ],
            pricing=PriceConfig(
                input=Decimal(cred_with_endpoint.get('input_price', 0)),
                output=Decimal(cred_with_endpoint.get('output_price', 0)),
                unit=Decimal(cred_with_endpoint.get('unit', 0)),
                currency=cred_with_endpoint.get('currency', "USD")
            ),
        )

        return entity
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
 | 
	
		
			
				|  |  |                         tools: Optional[list[PromptMessageTool]] = None) -> int:
 |