plugin_model.py

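"""
Mock of PluginModelManager for tests: stubs provider discovery, model schema
lookup, and LLM invocation so the model runtime can be exercised without a
live plugin daemon.
"""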

import datetime
import uuid
from collections.abc import Generator, Sequence
from decimal import Decimal
from json import dumps

# import monkeypatch
from typing import Optional

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool
from core.model_runtime.entities.model_entities import (
    AIModelEntity,
    FetchFrom,
    ModelFeature,
    ModelPropertyKey,
    ModelType,
)
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.plugin.manager.model import PluginModelManager


class MockModelClass(PluginModelManager):
    def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]:
        """
        Fetch model providers for the given tenant.
        """
        return [
            PluginModelProviderEntity(
                id=uuid.uuid4().hex,
                created_at=datetime.datetime.now(),
                updated_at=datetime.datetime.now(),
                provider="openai",
                tenant_id=tenant_id,
                plugin_unique_identifier="langgenius/openai/openai",
                plugin_id="langgenius/openai",
                declaration=ProviderEntity(
                    provider="openai",
                    label=I18nObject(
                        en_US="OpenAI",
                        zh_Hans="OpenAI",
                    ),
                    description=I18nObject(
                        en_US="OpenAI",
                        zh_Hans="OpenAI",
                    ),
                    icon_small=I18nObject(
                        en_US="https://example.com/icon_small.png",
                        zh_Hans="https://example.com/icon_small.png",
                    ),
                    icon_large=I18nObject(
                        en_US="https://example.com/icon_large.png",
                        zh_Hans="https://example.com/icon_large.png",
                    ),
                    supported_model_types=[ModelType.LLM],
                    configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
                    models=[
                        AIModelEntity(
                            model="gpt-3.5-turbo",
                            label=I18nObject(
                                en_US="gpt-3.5-turbo",
                                zh_Hans="gpt-3.5-turbo",
                            ),
                            model_type=ModelType.LLM,
                            fetch_from=FetchFrom.PREDEFINED_MODEL,
                            model_properties={},
                            features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL],
                        ),
                        AIModelEntity(
                            model="gpt-3.5-turbo-instruct",
                            label=I18nObject(
                                en_US="gpt-3.5-turbo-instruct",
                                zh_Hans="gpt-3.5-turbo-instruct",
                            ),
                            model_type=ModelType.LLM,
                            fetch_from=FetchFrom.PREDEFINED_MODEL,
                            model_properties={
                                ModelPropertyKey.MODE: LLMMode.COMPLETION,
                            },
                            features=[],
                        ),
                    ],
                ),
            )
        ]
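
    # Usage sketch (hypothetical test code, not part of the mock): the mock
    # always advertises a single "openai" provider with the two models above.
    #   providers = MockModelClass().fetch_model_providers(tenant_id="test-tenant")
    #   assert providers[0].declaration.provider == "openai"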

    def get_model_schema(
        self,
        tenant_id: str,
        user_id: str,
        plugin_id: str,
        provider: str,
        model_type: str,
        model: str,
        credentials: dict,
    ) -> AIModelEntity | None:
        """
        Return a mock schema for the requested model; tool-call features are
        advertised only for gpt-3.5-turbo.
        """
        return AIModelEntity(
            model=model,
            label=I18nObject(
                en_US="OpenAI",
                zh_Hans="OpenAI",
            ),
            model_type=ModelType(model_type),
            fetch_from=FetchFrom.PREDEFINED_MODEL,
            model_properties={},
            features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL] if model == "gpt-3.5-turbo" else [],
        )
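
    # Hypothetical call (argument values are illustrative only):
    #   schema = MockModelClass().get_model_schema(
    #       tenant_id="t", user_id="u", plugin_id="langgenius/openai",
    #       provider="openai", model_type="llm", model="gpt-3.5-turbo", credentials={},
    #   )
    #   assert ModelFeature.TOOL_CALL in schema.features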

    @staticmethod
    def generate_function_call(
        tools: Optional[list[PromptMessageTool]],
    ) -> Optional[AssistantPromptMessage.ToolCall]:
        """
        Build a deterministic mock tool call from the first tool: each required
        parameter is filled with a fixed placeholder value for its type.
        """
        if not tools:
            return None
        function: PromptMessageTool = tools[0]
        function_name = function.name
        function_parameters = function.parameters
        function_parameters_type = function_parameters["type"]
        if function_parameters_type != "object":
            return None
        function_parameters_properties = function_parameters["properties"]
        # "required" is optional in JSON Schema, so fall back to an empty list.
        function_parameters_required = function_parameters.get("required", [])
        parameters = {}
        for parameter_name, parameter in function_parameters_properties.items():
            if parameter_name not in function_parameters_required:
                continue
            parameter_type = parameter["type"]
            if parameter_type == "string":
                if "enum" in parameter:
                    if len(parameter["enum"]) == 0:
                        continue
                    parameters[parameter_name] = parameter["enum"][0]
                else:
                    parameters[parameter_name] = "kawaii"
            elif parameter_type == "integer":
                parameters[parameter_name] = 114514
            elif parameter_type == "number":
                parameters[parameter_name] = 1919810.0
            elif parameter_type == "boolean":
                parameters[parameter_name] = True
        return AssistantPromptMessage.ToolCall(
            id=str(uuid.uuid4()),
            type="function",
            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                name=function_name,
                arguments=dumps(parameters),
            ),
        )
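
    # For example (hypothetical schema): a tool whose parameters are
    #   {"type": "object",
    #    "properties": {"city": {"type": "string"}, "days": {"type": "integer"}},
    #    "required": ["city", "days"]}
    # yields a ToolCall whose function.arguments is '{"city": "kawaii", "days": 114514}'.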

    @staticmethod
    def mocked_chat_create_sync(
        model: str,
        prompt_messages: list[PromptMessage],
        tools: Optional[list[PromptMessageTool]] = None,
    ) -> LLMResult:
        """
        Return a complete (non-streaming) mock chat response with fixed usage.
        """
        tool_call = MockModelClass.generate_function_call(tools=tools)
        return LLMResult(
            id=str(uuid.uuid4()),
            model=model,
            prompt_messages=prompt_messages,
            message=AssistantPromptMessage(content="elaina", tool_calls=[tool_call] if tool_call else []),
            usage=LLMUsage(
                prompt_tokens=2,
                completion_tokens=1,
                total_tokens=3,
                # String literals keep the Decimal prices exact.
                prompt_unit_price=Decimal("0.0001"),
                completion_unit_price=Decimal("0.0002"),
                prompt_price_unit=Decimal("1"),
                prompt_price=Decimal("0.0001"),
                completion_price_unit=Decimal("1"),
                completion_price=Decimal("0.0002"),
                total_price=Decimal("0.0003"),
                currency="USD",
                latency=0.001,
            ),
        )

    @staticmethod
    def mocked_chat_create_stream(
        model: str,
        prompt_messages: list[PromptMessage],
        tools: Optional[list[PromptMessageTool]] = None,
    ) -> Generator[LLMResultChunk, None, None]:
        """
        Stream a mock chat response one character at a time, then yield a final
        empty chunk that carries no usage.
        """
        tool_call = MockModelClass.generate_function_call(tools=tools)
        full_text = "Hello, world!\n\n```python\nprint('Hello, world!')\n```"
        for i in range(len(full_text) + 1):
            if i == len(full_text):
                # Final chunk: empty content, closing the stream.
                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=AssistantPromptMessage(
                            content="",
                            tool_calls=[tool_call] if tool_call else [],
                        ),
                    ),
                )
            else:
                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=AssistantPromptMessage(
                            content=full_text[i],
                            tool_calls=[tool_call] if tool_call else [],
                        ),
                        usage=LLMUsage(
                            prompt_tokens=2,
                            completion_tokens=17,
                            total_tokens=19,
                            prompt_unit_price=Decimal("0.0001"),
                            completion_unit_price=Decimal("0.0002"),
                            prompt_price_unit=Decimal("1"),
                            prompt_price=Decimal("0.0001"),
                            completion_price_unit=Decimal("1"),
                            completion_price=Decimal("0.0002"),
                            total_price=Decimal("0.0003"),
                            currency="USD",
                            latency=0.001,
                        ),
                    ),
                )
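
    # Draining the stream (hypothetical): one chunk per character of full_text,
    # then the final empty chunk.
    #   chunks = list(MockModelClass.mocked_chat_create_stream(model="gpt-3.5-turbo", prompt_messages=[]))
    #   text = "".join(chunk.delta.message.content for chunk in chunks)
    #   assert text.startswith("Hello, world!")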

    def invoke_llm(
        self: PluginModelManager,
        *,
        tenant_id: str,
        user_id: str,
        plugin_id: str,
        provider: str,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: Optional[dict] = None,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
    ):
        # The daemon interface always streams; a caller passing stream=False is
        # expected to assemble the chunks into a full LLMResult itself.
        return MockModelClass.mocked_chat_create_stream(model=model, prompt_messages=prompt_messages, tools=tools)
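

# A minimal sketch of wiring the mock into tests with pytest's monkeypatching,
# assuming pytest is available (the fixture name `setup_model_mock` is an
# assumption, not part of the original mock):
import pytest


@pytest.fixture
def setup_model_mock(monkeypatch: pytest.MonkeyPatch) -> Generator[None, None, None]:
    # Point the real manager's methods at the mock for the test's duration;
    # monkeypatch restores the originals on teardown.
    monkeypatch.setattr(PluginModelManager, "fetch_model_providers", MockModelClass.fetch_model_providers)
    monkeypatch.setattr(PluginModelManager, "get_model_schema", MockModelClass.get_model_schema)
    monkeypatch.setattr(PluginModelManager, "invoke_llm", MockModelClass.invoke_llm)
    yield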