from collections.abc import Mapping
from typing import Any, Literal, Optional

from pydantic import BaseModel, ConfigDict, Field, field_validator

from core.entities.provider_entities import BasicProviderConfig
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageRole,
    PromptMessageTool,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelType
from core.workflow.nodes.parameter_extractor.entities import (
    ModelConfig as ParameterExtractorModelConfig,
)
from core.workflow.nodes.parameter_extractor.entities import (
    ParameterConfig,
)
from core.workflow.nodes.question_classifier.entities import (
    ClassConfig,
)
from core.workflow.nodes.question_classifier.entities import (
    ModelConfig as QuestionClassifierModelConfig,
)
  27. class RequestInvokeTool(BaseModel):
  28. """
  29. Request to invoke a tool
  30. """
  31. class BaseRequestInvokeModel(BaseModel):
  32. provider: str
  33. model: str
  34. model_type: ModelType
  35. model_config = ConfigDict(protected_namespaces=())
  36. class RequestInvokeLLM(BaseRequestInvokeModel):
  37. """
  38. Request to invoke LLM
  39. """
  40. model_type: ModelType = ModelType.LLM
  41. mode: str
  42. model_parameters: dict[str, Any] = Field(default_factory=dict)
  43. prompt_messages: list[PromptMessage] = Field(default_factory=list)
  44. tools: Optional[list[PromptMessageTool]] = Field(default_factory=list)
  45. stop: Optional[list[str]] = Field(default_factory=list)
  46. stream: Optional[bool] = False
  47. model_config = ConfigDict(protected_namespaces=())
  48. @field_validator("prompt_messages", mode="before")
  49. @classmethod
  50. def convert_prompt_messages(cls, v):
  51. if not isinstance(v, list):
  52. raise ValueError("prompt_messages must be a list")
  53. for i in range(len(v)):
  54. if v[i]["role"] == PromptMessageRole.USER.value:
  55. v[i] = UserPromptMessage(**v[i])
  56. elif v[i]["role"] == PromptMessageRole.ASSISTANT.value:
  57. v[i] = AssistantPromptMessage(**v[i])
  58. elif v[i]["role"] == PromptMessageRole.SYSTEM.value:
  59. v[i] = SystemPromptMessage(**v[i])
  60. elif v[i]["role"] == PromptMessageRole.TOOL.value:
  61. v[i] = ToolPromptMessage(**v[i])
  62. else:
  63. v[i] = PromptMessage(**v[i])
  64. return v
  65. class RequestInvokeTextEmbedding(BaseRequestInvokeModel):
  66. """
  67. Request to invoke text embedding
  68. """
  69. model_type: ModelType = ModelType.TEXT_EMBEDDING
  70. texts: list[str]
  71. class RequestInvokeRerank(BaseRequestInvokeModel):
  72. """
  73. Request to invoke rerank
  74. """
  75. model_type: ModelType = ModelType.RERANK
  76. query: str
  77. docs: list[str]
  78. score_threshold: float
  79. top_n: int
  80. class RequestInvokeTTS(BaseRequestInvokeModel):
  81. """
  82. Request to invoke TTS
  83. """
  84. model_type: ModelType = ModelType.TTS
  85. content_text: str
  86. voice: str
  87. class RequestInvokeSpeech2Text(BaseRequestInvokeModel):
  88. """
  89. Request to invoke speech2text
  90. """
  91. model_type: ModelType = ModelType.SPEECH2TEXT
  92. file: bytes
  93. @field_validator("file", mode="before")
  94. @classmethod
  95. def convert_file(cls, v):
  96. # hex string to bytes
  97. if isinstance(v, str):
  98. return bytes.fromhex(v)
  99. else:
  100. raise ValueError("file must be a hex string")
  101. class RequestInvokeModeration(BaseRequestInvokeModel):
  102. """
  103. Request to invoke moderation
  104. """
  105. model_type: ModelType = ModelType.MODERATION
  106. text: str
  107. class RequestInvokeParameterExtractorNode(BaseModel):
  108. """
  109. Request to invoke parameter extractor node
  110. """
  111. parameters: list[ParameterConfig]
  112. model: ParameterExtractorModelConfig
  113. instruction: str
  114. query: str
  115. class RequestInvokeQuestionClassifierNode(BaseModel):
  116. """
  117. Request to invoke question classifier node
  118. """
  119. query: str
  120. model: QuestionClassifierModelConfig
  121. classes: list[ClassConfig]
  122. instruction: str
  123. class RequestInvokeApp(BaseModel):
  124. """
  125. Request to invoke app
  126. """
  127. app_id: str
  128. inputs: dict[str, Any]
  129. query: Optional[str] = None
  130. response_mode: Literal["blocking", "streaming"]
  131. conversation_id: Optional[str] = None
  132. user: Optional[str] = None
  133. files: list[dict] = Field(default_factory=list)
  134. class RequestInvokeEncrypt(BaseModel):
  135. """
  136. Request to encryption
  137. """
  138. opt: Literal["encrypt", "decrypt", "clear"]
  139. namespace: Literal["endpoint"]
  140. identity: str
  141. data: dict = Field(default_factory=dict)
  142. config: Mapping[str, BasicProviderConfig] = Field(default_factory=Mapping)