- """
- In some cases, tools such as WebScraperTool and WikipediaSearchTool need to invoke a model.
- Therefore, a model manager is needed to list/invoke/validate models for tools.
- """
- import json
- from typing import cast
- from core.model_manager import ModelManager
- from core.model_runtime.entities.llm_entities import LLMResult
- from core.model_runtime.entities.message_entities import PromptMessage
- from core.model_runtime.entities.model_entities import ModelType
- from core.model_runtime.errors.invoke import (
- InvokeAuthorizationError,
- InvokeBadRequestError,
- InvokeConnectionError,
- InvokeRateLimitError,
- InvokeServerUnavailableError,
- )
- from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel, ModelPropertyKey
- from core.model_runtime.utils.encoders import jsonable_encoder
- from extensions.ext_database import db
- from models.tools import ToolModelInvoke
- 
- class InvokeModelError(Exception):
- pass
- 
- class ModelInvocationUtils:
- @staticmethod
- def get_max_llm_context_tokens(
- tenant_id: str,
- ) -> int:
- """
- get the maximum LLM context size (in tokens) of the tenant's default model
- """
- model_manager = ModelManager()
- model_instance = model_manager.get_default_model_instance(
- tenant_id=tenant_id, model_type=ModelType.LLM,
- )
- if not model_instance:
- raise InvokeModelError('Model not found')
-
- llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
- schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
- if not schema:
- raise InvokeModelError('No model schema found')
- max_tokens = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
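- # fall back to a 2048-token default below when the schema does not declare a context size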
- if max_tokens is None:
- return 2048
-
- return max_tokens
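- 
- # Usage sketch (hedged; the tenant id below is a placeholder, not a value
- # from this codebase):
- #
- #   context_size = ModelInvocationUtils.get_max_llm_context_tokens(
- #       tenant_id='example-tenant-id',
- #   )
- #   prompt_budget = context_size - 512  # leave headroom for the completion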
- 
- @staticmethod
- def calculate_tokens(
- tenant_id: str,
- prompt_messages: list[PromptMessage]
- ) -> int:
- """
- calculate the number of tokens in the prompt messages using the tenant's default model
- """
- # get model instance
- model_manager = ModelManager()
- model_instance = model_manager.get_default_model_instance(
- tenant_id=tenant_id, model_type=ModelType.LLM
- )
- if not model_instance:
- raise InvokeModelError('Model not found')
-
- # get tokens
- tokens = model_instance.get_llm_num_tokens(prompt_messages)
- return tokens
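- 
- # Usage sketch (hedged; UserPromptMessage comes from
- # core.model_runtime.entities.message_entities, and the tenant id is a placeholder):
- #
- #   from core.model_runtime.entities.message_entities import UserPromptMessage
- #
- #   tokens = ModelInvocationUtils.calculate_tokens(
- #       tenant_id='example-tenant-id',
- #       prompt_messages=[UserPromptMessage(content='Summarize this page.')],
- #   )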
- 
- @staticmethod
- def invoke(
- user_id: str, tenant_id: str,
- tool_type: str, tool_name: str,
- prompt_messages: list[PromptMessage]
- ) -> LLMResult:
- """
- invoke model with parameters in user's own context
- :param user_id: user id
- :param tenant_id: tenant id, the tenant id of the creator of the tool
- :param tool_type: tool type
- :param tool_name: tool name
- :param prompt_messages: prompt messages
- :return: LLMResult
- """
- # get model manager
- model_manager = ModelManager()
- # get model instance
- model_instance = model_manager.get_default_model_instance(
- tenant_id=tenant_id, model_type=ModelType.LLM,
- )
- # get prompt tokens
- prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
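- # fixed sampling parameters applied to every tool-side invocation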
- model_parameters = {
- 'temperature': 0.8,
- 'top_p': 0.8,
- }
- # create tool model invoke
- tool_model_invoke = ToolModelInvoke(
- user_id=user_id,
- tenant_id=tenant_id,
- provider=model_instance.provider,
- tool_type=tool_type,
- tool_name=tool_name,
- model_parameters=json.dumps(model_parameters),
- prompt_messages=json.dumps(jsonable_encoder(prompt_messages)),
- model_response='',
- prompt_tokens=prompt_tokens,
- answer_tokens=0,
- answer_unit_price=0,
- answer_price_unit=0,
- provider_response_latency=0,
- total_price=0,
- currency='USD',
- )
- db.session.add(tool_model_invoke)
- db.session.commit()
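- # the invoke record above is committed before the call; runtime invoke errors
- # below are mapped onto a single InvokeModelError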
- try:
- response: LLMResult = model_instance.invoke_llm(
- prompt_messages=prompt_messages,
- model_parameters=model_parameters,
- tools=[], stop=[], stream=False, user=user_id, callbacks=[]
- )
- except InvokeRateLimitError as e:
- raise InvokeModelError(f'Invoke rate limit error: {e}')
- except InvokeBadRequestError as e:
- raise InvokeModelError(f'Invoke bad request error: {e}')
- except InvokeConnectionError as e:
- raise InvokeModelError(f'Invoke connection error: {e}')
- except InvokeAuthorizationError:
- raise InvokeModelError('Invoke authorization error')
- except InvokeServerUnavailableError as e:
- raise InvokeModelError(f'Invoke server unavailable error: {e}')
- except Exception as e:
- raise InvokeModelError(f'Invoke error: {e}')
- # update tool model invoke
- tool_model_invoke.model_response = response.message.content
- if response.usage:
- tool_model_invoke.answer_tokens = response.usage.completion_tokens
- tool_model_invoke.answer_unit_price = response.usage.completion_unit_price
- tool_model_invoke.answer_price_unit = response.usage.completion_price_unit
- tool_model_invoke.provider_response_latency = response.usage.latency
- tool_model_invoke.total_price = response.usage.total_price
- tool_model_invoke.currency = response.usage.currency
- db.session.commit()
- return response
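- 
- # Usage sketch (hedged; every literal value below is a placeholder, and the
- # tool_type/tool_name pair depends on the calling tool):
- #
- #   result = ModelInvocationUtils.invoke(
- #       user_id='example-user-id',
- #       tenant_id='example-tenant-id',
- #       tool_type='builtin',
- #       tool_name='web_scraper',
- #       prompt_messages=[UserPromptMessage(content='Extract the key points.')],
- #   )
- #   print(result.message.content)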