request.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. from collections.abc import Mapping
  2. from typing import Any, Literal, Optional
  3. from pydantic import BaseModel, Field, field_validator
  4. from core.entities.provider_entities import BasicProviderConfig
  5. from core.model_runtime.entities.message_entities import (
  6. AssistantPromptMessage,
  7. PromptMessage,
  8. PromptMessageRole,
  9. PromptMessageTool,
  10. SystemPromptMessage,
  11. ToolPromptMessage,
  12. UserPromptMessage,
  13. )
  14. from core.model_runtime.entities.model_entities import ModelType
  15. class RequestInvokeTool(BaseModel):
  16. """
  17. Request to invoke a tool
  18. """
  19. class BaseRequestInvokeModel(BaseModel):
  20. provider: str
  21. model: str
  22. model_type: ModelType
  23. class RequestInvokeLLM(BaseRequestInvokeModel):
  24. """
  25. Request to invoke LLM
  26. """
  27. model_type: ModelType = ModelType.LLM
  28. mode: str
  29. model_parameters: dict[str, Any] = Field(default_factory=dict)
  30. prompt_messages: list[PromptMessage] = Field(default_factory=list)
  31. tools: Optional[list[PromptMessageTool]] = Field(default_factory=list)
  32. stop: Optional[list[str]] = Field(default_factory=list)
  33. stream: Optional[bool] = False
  34. @field_validator("prompt_messages", mode="before")
  35. @classmethod
  36. def convert_prompt_messages(cls, v):
  37. if not isinstance(v, list):
  38. raise ValueError("prompt_messages must be a list")
  39. for i in range(len(v)):
  40. if v[i]["role"] == PromptMessageRole.USER.value:
  41. v[i] = UserPromptMessage(**v[i])
  42. elif v[i]["role"] == PromptMessageRole.ASSISTANT.value:
  43. v[i] = AssistantPromptMessage(**v[i])
  44. elif v[i]["role"] == PromptMessageRole.SYSTEM.value:
  45. v[i] = SystemPromptMessage(**v[i])
  46. elif v[i]["role"] == PromptMessageRole.TOOL.value:
  47. v[i] = ToolPromptMessage(**v[i])
  48. else:
  49. v[i] = PromptMessage(**v[i])
  50. return v
  51. class RequestInvokeTextEmbedding(BaseModel):
  52. """
  53. Request to invoke text embedding
  54. """
  55. class RequestInvokeRerank(BaseModel):
  56. """
  57. Request to invoke rerank
  58. """
  59. class RequestInvokeTTS(BaseModel):
  60. """
  61. Request to invoke TTS
  62. """
  63. class RequestInvokeSpeech2Text(BaseModel):
  64. """
  65. Request to invoke speech2text
  66. """
  67. class RequestInvokeModeration(BaseModel):
  68. """
  69. Request to invoke moderation
  70. """
  71. class RequestInvokeNode(BaseModel):
  72. """
  73. Request to invoke node
  74. """
  75. class RequestInvokeApp(BaseModel):
  76. """
  77. Request to invoke app
  78. """
  79. app_id: str
  80. inputs: dict[str, Any]
  81. query: Optional[str] = None
  82. response_mode: Literal["blocking", "streaming"]
  83. conversation_id: Optional[str] = None
  84. user: Optional[str] = None
  85. files: list[dict] = Field(default_factory=list)
  86. class RequestInvokeEncrypt(BaseModel):
  87. """
  88. Request to encryption
  89. """
  90. opt: Literal["encrypt", "decrypt"]
  91. namespace: Literal["endpoint"]
  92. identity: str
  93. data: dict = Field(default_factory=dict)
  94. config: Mapping[str, BasicProviderConfig] = Field(default_factory=Mapping)