model_invocation_utils.py

  1. """
  2. For some reason, model will be used in tools like WebScraperTool, WikipediaSearchTool etc.
  3. Therefore, a model manager is needed to list/invoke/validate models.
  4. """
import json
from typing import cast

from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel, ModelPropertyKey
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from models.tools import ToolModelInvoke


class InvokeModelError(Exception):
    pass


class ModelInvocationUtils:
    @staticmethod
    def get_max_llm_context_tokens(
        tenant_id: str,
    ) -> int:
        """
        Get the maximum context size, in tokens, of the tenant's default LLM.
        """
        model_manager = ModelManager()
        model_instance = model_manager.get_default_model_instance(
            tenant_id=tenant_id, model_type=ModelType.LLM,
        )

        if not model_instance:
            raise InvokeModelError('Model not found')

        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
        schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)

        if not schema:
            raise InvokeModelError('No model schema found')

        max_tokens = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
        if max_tokens is None:
            # fall back to a conservative default when the schema does not
            # declare a context size
            return 2048

        return max_tokens
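
    # A minimal usage sketch (hypothetical caller; assumes `tenant_id` comes
    # from the invoking tool's context and that the tenant has a default LLM
    # configured):
    #
    #   context_size = ModelInvocationUtils.get_max_llm_context_tokens(tenant_id)
    #   if ModelInvocationUtils.calculate_tokens(tenant_id, messages) > context_size:
    #       messages = messages[-1:]  # e.g. keep only the latest message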

    @staticmethod
    def calculate_tokens(
        tenant_id: str,
        prompt_messages: list[PromptMessage]
    ) -> int:
        """
        Calculate the token count of the prompt messages using the tenant's
        default LLM.
        """
        # get model instance
        model_manager = ModelManager()
        model_instance = model_manager.get_default_model_instance(
            tenant_id=tenant_id, model_type=ModelType.LLM
        )

        if not model_instance:
            raise InvokeModelError('Model not found')

        # get tokens
        tokens = model_instance.get_llm_num_tokens(prompt_messages)

        return tokens
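
    # A sketch of counting tokens before an invocation. `UserPromptMessage`
    # is assumed to come from the same module as PromptMessage above; the
    # message content is a made-up example:
    #
    #   from core.model_runtime.entities.message_entities import UserPromptMessage
    #
    #   messages = [UserPromptMessage(content='Summarize this page: ...')]
    #   prompt_tokens = ModelInvocationUtils.calculate_tokens(tenant_id, messages)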

    @staticmethod
    def invoke(
        user_id: str, tenant_id: str,
        tool_type: str, tool_name: str,
        prompt_messages: list[PromptMessage]
    ) -> LLMResult:
        """
        Invoke the tenant's default model in the user's own context.

        :param user_id: user id
        :param tenant_id: tenant id, the tenant id of the creator of the tool
        :param tool_type: tool type
        :param tool_name: tool name
        :param prompt_messages: prompt messages
        :return: LLMResult
        """
        # get model manager
        model_manager = ModelManager()
        # get model instance
        model_instance = model_manager.get_default_model_instance(
            tenant_id=tenant_id, model_type=ModelType.LLM,
        )

        # get prompt tokens
        prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)

        # fixed sampling parameters for tool-side invocations
        model_parameters = {
            'temperature': 0.8,
            'top_p': 0.8,
        }

        # create and persist the invoke record before calling the model; on
        # failure the row remains with an empty response
        tool_model_invoke = ToolModelInvoke(
            user_id=user_id,
            tenant_id=tenant_id,
            provider=model_instance.provider,
            tool_type=tool_type,
            tool_name=tool_name,
            model_parameters=json.dumps(model_parameters),
            prompt_messages=json.dumps(jsonable_encoder(prompt_messages)),
            model_response='',
            prompt_tokens=prompt_tokens,
            answer_tokens=0,
            answer_unit_price=0,
            answer_price_unit=0,
            provider_response_latency=0,
            total_price=0,
            currency='USD',
        )

        db.session.add(tool_model_invoke)
        db.session.commit()

        try:
            response: LLMResult = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=model_parameters,
                tools=[], stop=[], stream=False, user=user_id, callbacks=[]
            )
        except InvokeRateLimitError as e:
            raise InvokeModelError(f'Invoke rate limit error: {e}')
        except InvokeBadRequestError as e:
            raise InvokeModelError(f'Invoke bad request error: {e}')
        except InvokeConnectionError as e:
            raise InvokeModelError(f'Invoke connection error: {e}')
        except InvokeAuthorizationError:
            # no provider detail is attached for authorization failures
            raise InvokeModelError('Invoke authorization error')
        except InvokeServerUnavailableError as e:
            raise InvokeModelError(f'Invoke server unavailable error: {e}')
        except Exception as e:
            raise InvokeModelError(f'Invoke error: {e}')

        # update tool model invoke
        tool_model_invoke.model_response = response.message.content
        if response.usage:
            tool_model_invoke.answer_tokens = response.usage.completion_tokens
            tool_model_invoke.answer_unit_price = response.usage.completion_unit_price
            tool_model_invoke.answer_price_unit = response.usage.completion_price_unit
            tool_model_invoke.provider_response_latency = response.usage.latency
            tool_model_invoke.total_price = response.usage.total_price
            tool_model_invoke.currency = response.usage.currency
        db.session.commit()

        return response
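

# An end-to-end sketch of how a tool might call invoke(). The identifiers
# below are hypothetical; SystemPromptMessage and UserPromptMessage are
# assumed to come from the same module as PromptMessage, and a real caller
# runs inside an application context with the database initialized:
#
#   from core.model_runtime.entities.message_entities import (
#       SystemPromptMessage,
#       UserPromptMessage,
#   )
#
#   try:
#       result = ModelInvocationUtils.invoke(
#           user_id=user_id,
#           tenant_id=tenant_id,
#           tool_type='builtin',
#           tool_name='webscraper',
#           prompt_messages=[
#               SystemPromptMessage(content='Summarize the following page.'),
#               UserPromptMessage(content=page_text),
#           ],
#       )
#       summary = result.message.content
#   except InvokeModelError as e:
#       ...  # surface the error to the tool's caller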