invoke_tool.go 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. package service
import (
	"errors"

	"github.com/gin-gonic/gin"
	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon"
	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/access_types"
	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
	"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
)
  13. func createSession[T any](
  14. r *plugin_entities.InvokePluginRequest[T],
  15. access_type access_types.PluginAccessType,
  16. access_action access_types.PluginAccessAction,
  17. cluster_id string,
  18. ) (*session_manager.Session, error) {
  19. runtime := plugin_manager.GetGlobalPluginManager().Get(r.PluginUniqueIdentifier)
  20. session := session_manager.NewSession(
  21. r.TenantId,
  22. r.UserId,
  23. r.PluginUniqueIdentifier,
  24. cluster_id,
  25. access_type,
  26. access_action,
  27. runtime.Configuration(),
  28. )
  29. session.BindRuntime(runtime)
  30. return session, nil
  31. }
  32. func InvokeLLM(
  33. r *plugin_entities.InvokePluginRequest[requests.RequestInvokeLLM],
  34. ctx *gin.Context,
  35. max_timeout_seconds int,
  36. ) {
  37. // create session
  38. session, err := createSession(
  39. r,
  40. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  41. access_types.PLUGIN_ACCESS_ACTION_INVOKE_LLM,
  42. ctx.GetString("cluster_id"),
  43. )
  44. if err != nil {
  45. ctx.JSON(500, gin.H{"error": err.Error()})
  46. return
  47. }
  48. defer session.Close()
  49. baseSSEService(
  50. func() (*stream.Stream[model_entities.LLMResultChunk], error) {
  51. return plugin_daemon.InvokeLLM(session, &r.Data)
  52. },
  53. ctx,
  54. max_timeout_seconds,
  55. )
  56. }
  57. func InvokeTextEmbedding(
  58. r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTextEmbedding],
  59. ctx *gin.Context,
  60. max_timeout_seconds int,
  61. ) {
  62. // create session
  63. session, err := createSession(
  64. r,
  65. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  66. access_types.PLUGIN_ACCESS_ACTION_INVOKE_TEXT_EMBEDDING,
  67. ctx.GetString("cluster_id"))
  68. if err != nil {
  69. ctx.JSON(500, gin.H{"error": err.Error()})
  70. return
  71. }
  72. defer session.Close()
  73. baseSSEService(
  74. func() (*stream.Stream[model_entities.TextEmbeddingResult], error) {
  75. return plugin_daemon.InvokeTextEmbedding(session, &r.Data)
  76. },
  77. ctx,
  78. max_timeout_seconds,
  79. )
  80. }
  81. func InvokeRerank(
  82. r *plugin_entities.InvokePluginRequest[requests.RequestInvokeRerank],
  83. ctx *gin.Context,
  84. max_timeout_seconds int,
  85. ) {
  86. // create session
  87. session, err := createSession(
  88. r,
  89. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  90. access_types.PLUGIN_ACCESS_ACTION_INVOKE_RERANK,
  91. ctx.GetString("cluster_id"),
  92. )
  93. if err != nil {
  94. ctx.JSON(500, gin.H{"error": err.Error()})
  95. return
  96. }
  97. defer session.Close()
  98. baseSSEService(
  99. func() (*stream.Stream[model_entities.RerankResult], error) {
  100. return plugin_daemon.InvokeRerank(session, &r.Data)
  101. },
  102. ctx,
  103. max_timeout_seconds,
  104. )
  105. }
  106. func InvokeTTS(
  107. r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTTS],
  108. ctx *gin.Context,
  109. max_timeout_seconds int,
  110. ) {
  111. // create session
  112. session, err := createSession(
  113. r,
  114. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  115. access_types.PLUGIN_ACCESS_ACTION_INVOKE_TTS,
  116. ctx.GetString("cluster_id"),
  117. )
  118. if err != nil {
  119. ctx.JSON(500, gin.H{"error": err.Error()})
  120. return
  121. }
  122. defer session.Close()
  123. baseSSEService(
  124. func() (*stream.Stream[model_entities.TTSResult], error) {
  125. return plugin_daemon.InvokeTTS(session, &r.Data)
  126. },
  127. ctx,
  128. max_timeout_seconds,
  129. )
  130. }
  131. func InvokeSpeech2Text(
  132. r *plugin_entities.InvokePluginRequest[requests.RequestInvokeSpeech2Text],
  133. ctx *gin.Context,
  134. max_timeout_seconds int,
  135. ) {
  136. // create session
  137. session, err := createSession(
  138. r,
  139. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  140. access_types.PLUGIN_ACCESS_ACTION_INVOKE_SPEECH2TEXT,
  141. ctx.GetString("cluster_id"),
  142. )
  143. if err != nil {
  144. ctx.JSON(500, gin.H{"error": err.Error()})
  145. return
  146. }
  147. defer session.Close()
  148. baseSSEService(
  149. func() (*stream.Stream[model_entities.Speech2TextResult], error) {
  150. return plugin_daemon.InvokeSpeech2Text(session, &r.Data)
  151. },
  152. ctx,
  153. max_timeout_seconds,
  154. )
  155. }
  156. func InvokeModeration(
  157. r *plugin_entities.InvokePluginRequest[requests.RequestInvokeModeration],
  158. ctx *gin.Context,
  159. max_timeout_seconds int,
  160. ) {
  161. // create session
  162. session, err := createSession(
  163. r,
  164. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  165. access_types.PLUGIN_ACCESS_ACTION_INVOKE_MODERATION,
  166. ctx.GetString("cluster_id"),
  167. )
  168. if err != nil {
  169. ctx.JSON(500, gin.H{"error": err.Error()})
  170. return
  171. }
  172. defer session.Close()
  173. baseSSEService(
  174. func() (*stream.Stream[model_entities.ModerationResult], error) {
  175. return plugin_daemon.InvokeModeration(session, &r.Data)
  176. },
  177. ctx,
  178. max_timeout_seconds,
  179. )
  180. }
  181. func ValidateProviderCredentials(
  182. r *plugin_entities.InvokePluginRequest[requests.RequestValidateProviderCredentials],
  183. ctx *gin.Context,
  184. max_timeout_seconds int,
  185. ) {
  186. // create session
  187. session, err := createSession(
  188. r,
  189. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  190. access_types.PLUGIN_ACCESS_ACTION_VALIDATE_PROVIDER_CREDENTIALS,
  191. ctx.GetString("cluster_id"),
  192. )
  193. if err != nil {
  194. ctx.JSON(500, gin.H{"error": err.Error()})
  195. return
  196. }
  197. defer session.Close()
  198. baseSSEService(
  199. func() (*stream.Stream[model_entities.ValidateCredentialsResult], error) {
  200. return plugin_daemon.ValidateProviderCredentials(session, &r.Data)
  201. },
  202. ctx,
  203. max_timeout_seconds,
  204. )
  205. }
  206. func ValidateModelCredentials(
  207. r *plugin_entities.InvokePluginRequest[requests.RequestValidateModelCredentials],
  208. ctx *gin.Context,
  209. max_timeout_seconds int,
  210. ) {
  211. // create session
  212. session, err := createSession(
  213. r,
  214. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  215. access_types.PLUGIN_ACCESS_ACTION_VALIDATE_MODEL_CREDENTIALS,
  216. ctx.GetString("cluster_id"),
  217. )
  218. if err != nil {
  219. ctx.JSON(500, gin.H{"error": err.Error()})
  220. return
  221. }
  222. defer session.Close()
  223. baseSSEService(
  224. func() (*stream.Stream[model_entities.ValidateCredentialsResult], error) {
  225. return plugin_daemon.ValidateModelCredentials(session, &r.Data)
  226. },
  227. ctx,
  228. max_timeout_seconds,
  229. )
  230. }