request.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. from typing import Any, Literal, Optional
  2. from pydantic import BaseModel, Field, field_validator
  3. from core.model_runtime.entities.message_entities import (
  4. AssistantPromptMessage,
  5. PromptMessage,
  6. PromptMessageRole,
  7. PromptMessageTool,
  8. SystemPromptMessage,
  9. ToolPromptMessage,
  10. UserPromptMessage,
  11. )
  12. from core.model_runtime.entities.model_entities import ModelType
  13. class RequestInvokeTool(BaseModel):
  14. """
  15. Request to invoke a tool
  16. """
  17. class BaseRequestInvokeModel(BaseModel):
  18. provider: str
  19. model: str
  20. model_type: ModelType
  21. class RequestInvokeLLM(BaseRequestInvokeModel):
  22. """
  23. Request to invoke LLM
  24. """
  25. model_type: ModelType = ModelType.LLM
  26. mode: str
  27. model_parameters: dict[str, Any] = Field(default_factory=dict)
  28. prompt_messages: list[PromptMessage]
  29. tools: Optional[list[PromptMessageTool]] = Field(default_factory=list)
  30. stop: Optional[list[str]] = Field(default_factory=list)
  31. stream: Optional[bool] = False
  32. @field_validator('prompt_messages', mode='before')
  33. def convert_prompt_messages(cls, v):
  34. if not isinstance(v, list):
  35. raise ValueError('prompt_messages must be a list')
  36. for i in range(len(v)):
  37. if v[i]['role'] == PromptMessageRole.USER.value:
  38. v[i] = UserPromptMessage(**v[i])
  39. elif v[i]['role'] == PromptMessageRole.ASSISTANT.value:
  40. v[i] = AssistantPromptMessage(**v[i])
  41. elif v[i]['role'] == PromptMessageRole.SYSTEM.value:
  42. v[i] = SystemPromptMessage(**v[i])
  43. elif v[i]['role'] == PromptMessageRole.TOOL.value:
  44. v[i] = ToolPromptMessage(**v[i])
  45. else:
  46. v[i] = PromptMessage(**v[i])
  47. return v
  48. class RequestInvokeTextEmbedding(BaseModel):
  49. """
  50. Request to invoke text embedding
  51. """
  52. class RequestInvokeRerank(BaseModel):
  53. """
  54. Request to invoke rerank
  55. """
  56. class RequestInvokeTTS(BaseModel):
  57. """
  58. Request to invoke TTS
  59. """
  60. class RequestInvokeSpeech2Text(BaseModel):
  61. """
  62. Request to invoke speech2text
  63. """
  64. class RequestInvokeModeration(BaseModel):
  65. """
  66. Request to invoke moderation
  67. """
  68. class RequestInvokeNode(BaseModel):
  69. """
  70. Request to invoke node
  71. """
  72. class RequestInvokeApp(BaseModel):
  73. """
  74. Request to invoke app
  75. """
  76. app_id: str
  77. inputs: dict[str, Any]
  78. query: Optional[str] = None
  79. response_mode: Literal["blocking", "streaming"]
  80. conversation_id: Optional[str] = None
  81. user: Optional[str] = None
  82. files: list[dict] = Field(default_factory=list)
  83. stream: bool = Field(default=False)