invoke_tool.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. package service
  2. import (
  3. "github.com/gin-gonic/gin"
  4. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon"
  5. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/access_types"
  6. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
  7. "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
  8. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
  9. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  10. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
  11. "github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  12. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  13. )
  14. func createSession[T any](
  15. r *plugin_entities.InvokePluginRequest[T],
  16. access_type access_types.PluginAccessType,
  17. access_action access_types.PluginAccessAction,
  18. cluster_id string,
  19. ) *session_manager.Session {
  20. plugin_identity := parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion)
  21. runtime := plugin_manager.GetGlobalPluginManager().Get(plugin_identity)
  22. session := session_manager.NewSession(
  23. r.TenantId,
  24. r.UserId,
  25. parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion),
  26. cluster_id,
  27. access_type,
  28. access_action,
  29. runtime.Configuration(),
  30. )
  31. session.BindRuntime(runtime)
  32. return session
  33. }
  34. func InvokeLLM(
  35. r *plugin_entities.InvokePluginRequest[requests.RequestInvokeLLM],
  36. ctx *gin.Context,
  37. max_timeout_seconds int,
  38. ) {
  39. // create session
  40. session := createSession(
  41. r,
  42. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  43. access_types.PLUGIN_ACCESS_ACTION_INVOKE_LLM,
  44. ctx.GetString("cluster_id"),
  45. )
  46. defer session.Close()
  47. baseSSEService(
  48. func() (*stream.StreamResponse[model_entities.LLMResultChunk], error) {
  49. return plugin_daemon.InvokeLLM(session, &r.Data)
  50. },
  51. ctx,
  52. max_timeout_seconds,
  53. )
  54. }
  55. func InvokeTextEmbedding(
  56. r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTextEmbedding],
  57. ctx *gin.Context,
  58. max_timeout_seconds int,
  59. ) {
  60. // create session
  61. session := createSession(
  62. r,
  63. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  64. access_types.PLUGIN_ACCESS_ACTION_INVOKE_TEXT_EMBEDDING,
  65. ctx.GetString("cluster_id"))
  66. defer session.Close()
  67. baseSSEService(
  68. func() (*stream.StreamResponse[model_entities.TextEmbeddingResult], error) {
  69. return plugin_daemon.InvokeTextEmbedding(session, &r.Data)
  70. },
  71. ctx,
  72. max_timeout_seconds,
  73. )
  74. }
  75. func InvokeRerank(
  76. r *plugin_entities.InvokePluginRequest[requests.RequestInvokeRerank],
  77. ctx *gin.Context,
  78. max_timeout_seconds int,
  79. ) {
  80. // create session
  81. session := createSession(
  82. r,
  83. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  84. access_types.PLUGIN_ACCESS_ACTION_INVOKE_RERANK,
  85. ctx.GetString("cluster_id"),
  86. )
  87. defer session.Close()
  88. baseSSEService(
  89. func() (*stream.StreamResponse[model_entities.RerankResult], error) {
  90. return plugin_daemon.InvokeRerank(session, &r.Data)
  91. },
  92. ctx,
  93. max_timeout_seconds,
  94. )
  95. }
  96. func InvokeTTS(
  97. r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTTS],
  98. ctx *gin.Context,
  99. max_timeout_seconds int,
  100. ) {
  101. // create session
  102. session := createSession(
  103. r,
  104. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  105. access_types.PLUGIN_ACCESS_ACTION_INVOKE_TTS,
  106. ctx.GetString("cluster_id"),
  107. )
  108. defer session.Close()
  109. baseSSEService(
  110. func() (*stream.StreamResponse[model_entities.TTSResult], error) {
  111. return plugin_daemon.InvokeTTS(session, &r.Data)
  112. },
  113. ctx,
  114. max_timeout_seconds,
  115. )
  116. }
  117. func InvokeSpeech2Text(
  118. r *plugin_entities.InvokePluginRequest[requests.RequestInvokeSpeech2Text],
  119. ctx *gin.Context,
  120. max_timeout_seconds int,
  121. ) {
  122. // create session
  123. session := createSession(
  124. r,
  125. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  126. access_types.PLUGIN_ACCESS_ACTION_INVOKE_SPEECH2TEXT,
  127. ctx.GetString("cluster_id"),
  128. )
  129. defer session.Close()
  130. baseSSEService(
  131. func() (*stream.StreamResponse[model_entities.Speech2TextResult], error) {
  132. return plugin_daemon.InvokeSpeech2Text(session, &r.Data)
  133. },
  134. ctx,
  135. max_timeout_seconds,
  136. )
  137. }
  138. func InvokeModeration(
  139. r *plugin_entities.InvokePluginRequest[requests.RequestInvokeModeration],
  140. ctx *gin.Context,
  141. max_timeout_seconds int,
  142. ) {
  143. // create session
  144. session := createSession(
  145. r,
  146. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  147. access_types.PLUGIN_ACCESS_ACTION_INVOKE_MODERATION,
  148. ctx.GetString("cluster_id"),
  149. )
  150. defer session.Close()
  151. baseSSEService(
  152. func() (*stream.StreamResponse[model_entities.ModerationResult], error) {
  153. return plugin_daemon.InvokeModeration(session, &r.Data)
  154. },
  155. ctx,
  156. max_timeout_seconds,
  157. )
  158. }
  159. func ValidateProviderCredentials(
  160. r *plugin_entities.InvokePluginRequest[requests.RequestValidateProviderCredentials],
  161. ctx *gin.Context,
  162. max_timeout_seconds int,
  163. ) {
  164. // create session
  165. session := createSession(
  166. r,
  167. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  168. access_types.PLUGIN_ACCESS_ACTION_VALIDATE_PROVIDER_CREDENTIALS,
  169. ctx.GetString("cluster_id"),
  170. )
  171. defer session.Close()
  172. baseSSEService(
  173. func() (*stream.StreamResponse[model_entities.ValidateCredentialsResult], error) {
  174. return plugin_daemon.ValidateProviderCredentials(session, &r.Data)
  175. },
  176. ctx,
  177. max_timeout_seconds,
  178. )
  179. }
  180. func ValidateModelCredentials(
  181. r *plugin_entities.InvokePluginRequest[requests.RequestValidateModelCredentials],
  182. ctx *gin.Context,
  183. max_timeout_seconds int,
  184. ) {
  185. // create session
  186. session := createSession(
  187. r,
  188. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  189. access_types.PLUGIN_ACCESS_ACTION_VALIDATE_MODEL_CREDENTIALS,
  190. ctx.GetString("cluster_id"),
  191. )
  192. defer session.Close()
  193. baseSSEService(
  194. func() (*stream.StreamResponse[model_entities.ValidateCredentialsResult], error) {
  195. return plugin_daemon.ValidateModelCredentials(session, &r.Data)
  196. },
  197. ctx,
  198. max_timeout_seconds,
  199. )
  200. }