# llm.py — Dify plugin large-language-model scaffold (Go-template source;
# {{ .PluginName }} placeholders are filled in at plugin generation time).
import logging
from collections.abc import Generator
from typing import Optional, Union

from dify_plugin.entities import I18nObject
from dify_plugin.entities.model import (
    AIModelEntity,
    FetchFrom,
    ModelType,
)
from dify_plugin.entities.model.llm import (
    LLMResult,
)
from dify_plugin.entities.model.message import (
    PromptMessage,
    PromptMessageTool,
)
from dify_plugin.errors.model import (
    CredentialsValidateFailedError,
)
from dify_plugin.interfaces.model.large_language_model import LargeLanguageModel
# Module-level logger, named after this module per the standard logging convention.
logger = logging.getLogger(__name__)
  21. class {{ .PluginName | SnakeToCamel }}LargeLanguageModel(LargeLanguageModel):
  22. """
  23. Model class for {{ .PluginName }} large language model.
  24. """
  25. def _invoke(
  26. self,
  27. model: str,
  28. credentials: dict,
  29. prompt_messages: list[PromptMessage],
  30. model_parameters: dict,
  31. tools: Optional[list[PromptMessageTool]] = None,
  32. stop: Optional[list[str]] = None,
  33. stream: bool = True,
  34. user: Optional[str] = None,
  35. ) -> Union[LLMResult, Generator]:
  36. """
  37. Invoke large language model
  38. :param model: model name
  39. :param credentials: model credentials
  40. :param prompt_messages: prompt messages
  41. :param model_parameters: model parameters
  42. :param tools: tools for tool calling
  43. :param stop: stop words
  44. :param stream: is stream response
  45. :param user: unique user id
  46. :return: full response or stream response chunk generator result
  47. """
  48. pass
  49. def get_num_tokens(
  50. self,
  51. model: str,
  52. credentials: dict,
  53. prompt_messages: list[PromptMessage],
  54. tools: Optional[list[PromptMessageTool]] = None,
  55. ) -> int:
  56. """
  57. Get number of tokens for given prompt messages
  58. :param model: model name
  59. :param credentials: model credentials
  60. :param prompt_messages: prompt messages
  61. :param tools: tools for tool calling
  62. :return:
  63. """
  64. return 0
  65. def validate_credentials(self, model: str, credentials: dict) -> None:
  66. """
  67. Validate model credentials
  68. :param model: model name
  69. :param credentials: model credentials
  70. :return:
  71. """
  72. try:
  73. pass
  74. except Exception as ex:
  75. raise CredentialsValidateFailedError(str(ex))
  76. def get_customizable_model_schema(
  77. self, model: str, credentials: dict
  78. ) -> AIModelEntity:
  79. """
  80. If your model supports fine-tuning, this method returns the schema of the base model
  81. but renamed to the fine-tuned model name.
  82. :param model: model name
  83. :param credentials: credentials
  84. :return: model schema
  85. """
  86. entity = AIModelEntity(
  87. model=model,
  88. label=I18nObject(zh_Hans=model, en_US=model),
  89. model_type=ModelType.LLM,
  90. features=[],
  91. fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
  92. model_properties={},
  93. parameter_rules=[],
  94. )
  95. return entity