# request.py — request models for plugin backwards invocation.
  1. from typing import Any, Literal, Optional
  2. from pydantic import BaseModel, ConfigDict, Field, field_validator
  3. from core.entities.provider_entities import BasicProviderConfig
  4. from core.model_runtime.entities.message_entities import (
  5. AssistantPromptMessage,
  6. PromptMessage,
  7. PromptMessageRole,
  8. PromptMessageTool,
  9. SystemPromptMessage,
  10. ToolPromptMessage,
  11. UserPromptMessage,
  12. )
  13. from core.model_runtime.entities.model_entities import ModelType
  14. from core.workflow.nodes.parameter_extractor.entities import (
  15. ModelConfig as ParameterExtractorModelConfig,
  16. )
  17. from core.workflow.nodes.parameter_extractor.entities import (
  18. ParameterConfig,
  19. )
  20. from core.workflow.nodes.question_classifier.entities import (
  21. ClassConfig,
  22. )
  23. from core.workflow.nodes.question_classifier.entities import (
  24. ModelConfig as QuestionClassifierModelConfig,
  25. )
  26. class RequestInvokeTool(BaseModel):
  27. """
  28. Request to invoke a tool
  29. """
  30. class BaseRequestInvokeModel(BaseModel):
  31. provider: str
  32. model: str
  33. model_type: ModelType
  34. model_config = ConfigDict(protected_namespaces=())
  35. class RequestInvokeLLM(BaseRequestInvokeModel):
  36. """
  37. Request to invoke LLM
  38. """
  39. model_type: ModelType = ModelType.LLM
  40. mode: str
  41. model_parameters: dict[str, Any] = Field(default_factory=dict)
  42. prompt_messages: list[PromptMessage] = Field(default_factory=list)
  43. tools: Optional[list[PromptMessageTool]] = Field(default_factory=list)
  44. stop: Optional[list[str]] = Field(default_factory=list)
  45. stream: Optional[bool] = False
  46. model_config = ConfigDict(protected_namespaces=())
  47. @field_validator("prompt_messages", mode="before")
  48. @classmethod
  49. def convert_prompt_messages(cls, v):
  50. if not isinstance(v, list):
  51. raise ValueError("prompt_messages must be a list")
  52. for i in range(len(v)):
  53. if v[i]["role"] == PromptMessageRole.USER.value:
  54. v[i] = UserPromptMessage(**v[i])
  55. elif v[i]["role"] == PromptMessageRole.ASSISTANT.value:
  56. v[i] = AssistantPromptMessage(**v[i])
  57. elif v[i]["role"] == PromptMessageRole.SYSTEM.value:
  58. v[i] = SystemPromptMessage(**v[i])
  59. elif v[i]["role"] == PromptMessageRole.TOOL.value:
  60. v[i] = ToolPromptMessage(**v[i])
  61. else:
  62. v[i] = PromptMessage(**v[i])
  63. return v
  64. class RequestInvokeTextEmbedding(BaseRequestInvokeModel):
  65. """
  66. Request to invoke text embedding
  67. """
  68. model_type: ModelType = ModelType.TEXT_EMBEDDING
  69. texts: list[str]
  70. class RequestInvokeRerank(BaseRequestInvokeModel):
  71. """
  72. Request to invoke rerank
  73. """
  74. model_type: ModelType = ModelType.RERANK
  75. query: str
  76. docs: list[str]
  77. score_threshold: float
  78. top_n: int
  79. class RequestInvokeTTS(BaseRequestInvokeModel):
  80. """
  81. Request to invoke TTS
  82. """
  83. model_type: ModelType = ModelType.TTS
  84. content_text: str
  85. voice: str
  86. class RequestInvokeSpeech2Text(BaseRequestInvokeModel):
  87. """
  88. Request to invoke speech2text
  89. """
  90. model_type: ModelType = ModelType.SPEECH2TEXT
  91. file: bytes
  92. @field_validator("file", mode="before")
  93. @classmethod
  94. def convert_file(cls, v):
  95. # hex string to bytes
  96. if isinstance(v, str):
  97. return bytes.fromhex(v)
  98. else:
  99. raise ValueError("file must be a hex string")
  100. class RequestInvokeModeration(BaseRequestInvokeModel):
  101. """
  102. Request to invoke moderation
  103. """
  104. model_type: ModelType = ModelType.MODERATION
  105. text: str
  106. class RequestInvokeParameterExtractorNode(BaseModel):
  107. """
  108. Request to invoke parameter extractor node
  109. """
  110. parameters: list[ParameterConfig]
  111. model: ParameterExtractorModelConfig
  112. instruction: str
  113. query: str
  114. class RequestInvokeQuestionClassifierNode(BaseModel):
  115. """
  116. Request to invoke question classifier node
  117. """
  118. query: str
  119. model: QuestionClassifierModelConfig
  120. classes: list[ClassConfig]
  121. instruction: str
  122. class RequestInvokeApp(BaseModel):
  123. """
  124. Request to invoke app
  125. """
  126. app_id: str
  127. inputs: dict[str, Any]
  128. query: Optional[str] = None
  129. response_mode: Literal["blocking", "streaming"]
  130. conversation_id: Optional[str] = None
  131. user: Optional[str] = None
  132. files: list[dict] = Field(default_factory=list)
  133. class RequestInvokeEncrypt(BaseModel):
  134. """
  135. Request to encryption
  136. """
  137. opt: Literal["encrypt", "decrypt", "clear"]
  138. namespace: Literal["endpoint"]
  139. identity: str
  140. data: dict = Field(default_factory=dict)
  141. config: list[BasicProviderConfig] = Field(default_factory=list)