invoke_tool.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. package service
  import (
  	"fmt"

  	"github.com/gin-gonic/gin"
  	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon"
  	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/access_types"
  	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
  	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
  	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
  	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
  	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  	"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  )
  14. func createSession[T any](
  15. r *plugin_entities.InvokePluginRequest[T],
  16. access_type access_types.PluginAccessType,
  17. access_action access_types.PluginAccessAction,
  18. cluster_id string,
  19. ) (*session_manager.Session, error) {
  20. plugin_identity := parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion)
  21. runtime := plugin_manager.GetGlobalPluginManager().Get(plugin_identity)
  22. session := session_manager.NewSession(
  23. r.TenantId,
  24. r.UserId,
  25. parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion),
  26. cluster_id,
  27. access_type,
  28. access_action,
  29. runtime.Configuration(),
  30. )
  31. session.BindRuntime(runtime)
  32. return session, nil
  33. }
  34. func InvokeLLM(
  35. r *plugin_entities.InvokePluginRequest[requests.RequestInvokeLLM],
  36. ctx *gin.Context,
  37. max_timeout_seconds int,
  38. ) {
  39. // create session
  40. session, err := createSession(
  41. r,
  42. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  43. access_types.PLUGIN_ACCESS_ACTION_INVOKE_LLM,
  44. ctx.GetString("cluster_id"),
  45. )
  46. if err != nil {
  47. ctx.JSON(500, gin.H{"error": err.Error()})
  48. return
  49. }
  50. defer session.Close()
  51. baseSSEService(
  52. func() (*stream.StreamResponse[model_entities.LLMResultChunk], error) {
  53. return plugin_daemon.InvokeLLM(session, &r.Data)
  54. },
  55. ctx,
  56. max_timeout_seconds,
  57. )
  58. }
  59. func InvokeTextEmbedding(
  60. r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTextEmbedding],
  61. ctx *gin.Context,
  62. max_timeout_seconds int,
  63. ) {
  64. // create session
  65. session, err := createSession(
  66. r,
  67. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  68. access_types.PLUGIN_ACCESS_ACTION_INVOKE_TEXT_EMBEDDING,
  69. ctx.GetString("cluster_id"))
  70. if err != nil {
  71. ctx.JSON(500, gin.H{"error": err.Error()})
  72. return
  73. }
  74. defer session.Close()
  75. baseSSEService(
  76. func() (*stream.StreamResponse[model_entities.TextEmbeddingResult], error) {
  77. return plugin_daemon.InvokeTextEmbedding(session, &r.Data)
  78. },
  79. ctx,
  80. max_timeout_seconds,
  81. )
  82. }
  83. func InvokeRerank(
  84. r *plugin_entities.InvokePluginRequest[requests.RequestInvokeRerank],
  85. ctx *gin.Context,
  86. max_timeout_seconds int,
  87. ) {
  88. // create session
  89. session, err := createSession(
  90. r,
  91. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  92. access_types.PLUGIN_ACCESS_ACTION_INVOKE_RERANK,
  93. ctx.GetString("cluster_id"),
  94. )
  95. if err != nil {
  96. ctx.JSON(500, gin.H{"error": err.Error()})
  97. return
  98. }
  99. defer session.Close()
  100. baseSSEService(
  101. func() (*stream.StreamResponse[model_entities.RerankResult], error) {
  102. return plugin_daemon.InvokeRerank(session, &r.Data)
  103. },
  104. ctx,
  105. max_timeout_seconds,
  106. )
  107. }
  108. func InvokeTTS(
  109. r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTTS],
  110. ctx *gin.Context,
  111. max_timeout_seconds int,
  112. ) {
  113. // create session
  114. session, err := createSession(
  115. r,
  116. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  117. access_types.PLUGIN_ACCESS_ACTION_INVOKE_TTS,
  118. ctx.GetString("cluster_id"),
  119. )
  120. if err != nil {
  121. ctx.JSON(500, gin.H{"error": err.Error()})
  122. return
  123. }
  124. defer session.Close()
  125. baseSSEService(
  126. func() (*stream.StreamResponse[model_entities.TTSResult], error) {
  127. return plugin_daemon.InvokeTTS(session, &r.Data)
  128. },
  129. ctx,
  130. max_timeout_seconds,
  131. )
  132. }
  133. func InvokeSpeech2Text(
  134. r *plugin_entities.InvokePluginRequest[requests.RequestInvokeSpeech2Text],
  135. ctx *gin.Context,
  136. max_timeout_seconds int,
  137. ) {
  138. // create session
  139. session, err := createSession(
  140. r,
  141. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  142. access_types.PLUGIN_ACCESS_ACTION_INVOKE_SPEECH2TEXT,
  143. ctx.GetString("cluster_id"),
  144. )
  145. if err != nil {
  146. ctx.JSON(500, gin.H{"error": err.Error()})
  147. return
  148. }
  149. defer session.Close()
  150. baseSSEService(
  151. func() (*stream.StreamResponse[model_entities.Speech2TextResult], error) {
  152. return plugin_daemon.InvokeSpeech2Text(session, &r.Data)
  153. },
  154. ctx,
  155. max_timeout_seconds,
  156. )
  157. }
  158. func InvokeModeration(
  159. r *plugin_entities.InvokePluginRequest[requests.RequestInvokeModeration],
  160. ctx *gin.Context,
  161. max_timeout_seconds int,
  162. ) {
  163. // create session
  164. session, err := createSession(
  165. r,
  166. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  167. access_types.PLUGIN_ACCESS_ACTION_INVOKE_MODERATION,
  168. ctx.GetString("cluster_id"),
  169. )
  170. if err != nil {
  171. ctx.JSON(500, gin.H{"error": err.Error()})
  172. return
  173. }
  174. defer session.Close()
  175. baseSSEService(
  176. func() (*stream.StreamResponse[model_entities.ModerationResult], error) {
  177. return plugin_daemon.InvokeModeration(session, &r.Data)
  178. },
  179. ctx,
  180. max_timeout_seconds,
  181. )
  182. }
  183. func ValidateProviderCredentials(
  184. r *plugin_entities.InvokePluginRequest[requests.RequestValidateProviderCredentials],
  185. ctx *gin.Context,
  186. max_timeout_seconds int,
  187. ) {
  188. // create session
  189. session, err := createSession(
  190. r,
  191. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  192. access_types.PLUGIN_ACCESS_ACTION_VALIDATE_PROVIDER_CREDENTIALS,
  193. ctx.GetString("cluster_id"),
  194. )
  195. if err != nil {
  196. ctx.JSON(500, gin.H{"error": err.Error()})
  197. return
  198. }
  199. defer session.Close()
  200. baseSSEService(
  201. func() (*stream.StreamResponse[model_entities.ValidateCredentialsResult], error) {
  202. return plugin_daemon.ValidateProviderCredentials(session, &r.Data)
  203. },
  204. ctx,
  205. max_timeout_seconds,
  206. )
  207. }
  208. func ValidateModelCredentials(
  209. r *plugin_entities.InvokePluginRequest[requests.RequestValidateModelCredentials],
  210. ctx *gin.Context,
  211. max_timeout_seconds int,
  212. ) {
  213. // create session
  214. session, err := createSession(
  215. r,
  216. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  217. access_types.PLUGIN_ACCESS_ACTION_VALIDATE_MODEL_CREDENTIALS,
  218. ctx.GetString("cluster_id"),
  219. )
  220. if err != nil {
  221. ctx.JSON(500, gin.H{"error": err.Error()})
  222. return
  223. }
  224. defer session.Close()
  225. baseSSEService(
  226. func() (*stream.StreamResponse[model_entities.ValidateCredentialsResult], error) {
  227. return plugin_daemon.ValidateModelCredentials(session, &r.Data)
  228. },
  229. ctx,
  230. max_timeout_seconds,
  231. )
  232. }