# request.py (3.7 KB)
from collections.abc import Mapping
from typing import Any, Literal, Optional

from pydantic import BaseModel, Field, field_validator

from core.entities.provider_entities import BasicProviderConfig
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageRole,
    PromptMessageTool,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelType
from core.workflow.nodes.parameter_extractor.entities import (
    ModelConfig as ParameterExtractorModelConfig,
)
from core.workflow.nodes.parameter_extractor.entities import (
    ParameterConfig,
)
from core.workflow.nodes.question_classifier.entities import (
    ClassConfig,
    ModelConfig as QuestionClassifierModelConfig,
)
  25. class RequestInvokeTool(BaseModel):
  26. """
  27. Request to invoke a tool
  28. """
  29. class BaseRequestInvokeModel(BaseModel):
  30. provider: str
  31. model: str
  32. model_type: ModelType
  33. class RequestInvokeLLM(BaseRequestInvokeModel):
  34. """
  35. Request to invoke LLM
  36. """
  37. model_type: ModelType = ModelType.LLM
  38. mode: str
  39. model_parameters: dict[str, Any] = Field(default_factory=dict)
  40. prompt_messages: list[PromptMessage] = Field(default_factory=list)
  41. tools: Optional[list[PromptMessageTool]] = Field(default_factory=list)
  42. stop: Optional[list[str]] = Field(default_factory=list)
  43. stream: Optional[bool] = False
  44. @field_validator("prompt_messages", mode="before")
  45. @classmethod
  46. def convert_prompt_messages(cls, v):
  47. if not isinstance(v, list):
  48. raise ValueError("prompt_messages must be a list")
  49. for i in range(len(v)):
  50. if v[i]["role"] == PromptMessageRole.USER.value:
  51. v[i] = UserPromptMessage(**v[i])
  52. elif v[i]["role"] == PromptMessageRole.ASSISTANT.value:
  53. v[i] = AssistantPromptMessage(**v[i])
  54. elif v[i]["role"] == PromptMessageRole.SYSTEM.value:
  55. v[i] = SystemPromptMessage(**v[i])
  56. elif v[i]["role"] == PromptMessageRole.TOOL.value:
  57. v[i] = ToolPromptMessage(**v[i])
  58. else:
  59. v[i] = PromptMessage(**v[i])
  60. return v
  61. class RequestInvokeTextEmbedding(BaseModel):
  62. """
  63. Request to invoke text embedding
  64. """
  65. class RequestInvokeRerank(BaseModel):
  66. """
  67. Request to invoke rerank
  68. """
  69. class RequestInvokeTTS(BaseModel):
  70. """
  71. Request to invoke TTS
  72. """
  73. class RequestInvokeSpeech2Text(BaseModel):
  74. """
  75. Request to invoke speech2text
  76. """
  77. class RequestInvokeModeration(BaseModel):
  78. """
  79. Request to invoke moderation
  80. """
  81. class RequestInvokeParameterExtractorNode(BaseModel):
  82. """
  83. Request to invoke parameter extractor node
  84. """
  85. parameters: list[ParameterConfig]
  86. model: ParameterExtractorModelConfig
  87. instruction: str
  88. query: str
  89. class RequestInvokeQuestionClassifierNode(BaseModel):
  90. """
  91. Request to invoke question classifier node
  92. """
  93. query: str
  94. model: QuestionClassifierModelConfig
  95. classes: list[ClassConfig]
  96. instruction: str
  97. class RequestInvokeApp(BaseModel):
  98. """
  99. Request to invoke app
  100. """
  101. app_id: str
  102. inputs: dict[str, Any]
  103. query: Optional[str] = None
  104. response_mode: Literal["blocking", "streaming"]
  105. conversation_id: Optional[str] = None
  106. user: Optional[str] = None
  107. files: list[dict] = Field(default_factory=list)
  108. class RequestInvokeEncrypt(BaseModel):
  109. """
  110. Request to encryption
  111. """
  112. opt: Literal["encrypt", "decrypt"]
  113. namespace: Literal["endpoint"]
  114. identity: str
  115. data: dict = Field(default_factory=dict)
  116. config: Mapping[str, BasicProviderConfig] = Field(default_factory=Mapping)