invoke_tool.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. package service
  2. import (
  3. "errors"
  4. "github.com/gin-gonic/gin"
  5. "github.com/langgenius/dify-plugin-daemon/internal/core/persistence"
  6. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon"
  7. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/access_types"
  8. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
  9. "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
  10. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
  11. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  12. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
  13. "github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  14. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  15. )
  16. func createSession[T any](
  17. r *plugin_entities.InvokePluginRequest[T],
  18. access_type access_types.PluginAccessType,
  19. access_action access_types.PluginAccessAction,
  20. cluster_id string,
  21. ) (*session_manager.Session, error) {
  22. persistence := persistence.GetPersistence()
  23. if persistence == nil {
  24. return nil, errors.New("persistence not found")
  25. }
  26. plugin_identity := parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion)
  27. runtime := plugin_manager.GetGlobalPluginManager().Get(plugin_identity)
  28. session := session_manager.NewSession(
  29. r.TenantId,
  30. r.UserId,
  31. parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion),
  32. cluster_id,
  33. access_type,
  34. access_action,
  35. runtime.Configuration(),
  36. persistence,
  37. )
  38. session.BindRuntime(runtime)
  39. return session, nil
  40. }
  41. func InvokeLLM(
  42. r *plugin_entities.InvokePluginRequest[requests.RequestInvokeLLM],
  43. ctx *gin.Context,
  44. max_timeout_seconds int,
  45. ) {
  46. // create session
  47. session, err := createSession(
  48. r,
  49. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  50. access_types.PLUGIN_ACCESS_ACTION_INVOKE_LLM,
  51. ctx.GetString("cluster_id"),
  52. )
  53. if err != nil {
  54. ctx.JSON(500, gin.H{"error": err.Error()})
  55. return
  56. }
  57. defer session.Close()
  58. baseSSEService(
  59. func() (*stream.StreamResponse[model_entities.LLMResultChunk], error) {
  60. return plugin_daemon.InvokeLLM(session, &r.Data)
  61. },
  62. ctx,
  63. max_timeout_seconds,
  64. )
  65. }
  66. func InvokeTextEmbedding(
  67. r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTextEmbedding],
  68. ctx *gin.Context,
  69. max_timeout_seconds int,
  70. ) {
  71. // create session
  72. session, err := createSession(
  73. r,
  74. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  75. access_types.PLUGIN_ACCESS_ACTION_INVOKE_TEXT_EMBEDDING,
  76. ctx.GetString("cluster_id"))
  77. if err != nil {
  78. ctx.JSON(500, gin.H{"error": err.Error()})
  79. return
  80. }
  81. defer session.Close()
  82. baseSSEService(
  83. func() (*stream.StreamResponse[model_entities.TextEmbeddingResult], error) {
  84. return plugin_daemon.InvokeTextEmbedding(session, &r.Data)
  85. },
  86. ctx,
  87. max_timeout_seconds,
  88. )
  89. }
  90. func InvokeRerank(
  91. r *plugin_entities.InvokePluginRequest[requests.RequestInvokeRerank],
  92. ctx *gin.Context,
  93. max_timeout_seconds int,
  94. ) {
  95. // create session
  96. session, err := createSession(
  97. r,
  98. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  99. access_types.PLUGIN_ACCESS_ACTION_INVOKE_RERANK,
  100. ctx.GetString("cluster_id"),
  101. )
  102. if err != nil {
  103. ctx.JSON(500, gin.H{"error": err.Error()})
  104. return
  105. }
  106. defer session.Close()
  107. baseSSEService(
  108. func() (*stream.StreamResponse[model_entities.RerankResult], error) {
  109. return plugin_daemon.InvokeRerank(session, &r.Data)
  110. },
  111. ctx,
  112. max_timeout_seconds,
  113. )
  114. }
  115. func InvokeTTS(
  116. r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTTS],
  117. ctx *gin.Context,
  118. max_timeout_seconds int,
  119. ) {
  120. // create session
  121. session, err := createSession(
  122. r,
  123. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  124. access_types.PLUGIN_ACCESS_ACTION_INVOKE_TTS,
  125. ctx.GetString("cluster_id"),
  126. )
  127. if err != nil {
  128. ctx.JSON(500, gin.H{"error": err.Error()})
  129. return
  130. }
  131. defer session.Close()
  132. baseSSEService(
  133. func() (*stream.StreamResponse[model_entities.TTSResult], error) {
  134. return plugin_daemon.InvokeTTS(session, &r.Data)
  135. },
  136. ctx,
  137. max_timeout_seconds,
  138. )
  139. }
  140. func InvokeSpeech2Text(
  141. r *plugin_entities.InvokePluginRequest[requests.RequestInvokeSpeech2Text],
  142. ctx *gin.Context,
  143. max_timeout_seconds int,
  144. ) {
  145. // create session
  146. session, err := createSession(
  147. r,
  148. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  149. access_types.PLUGIN_ACCESS_ACTION_INVOKE_SPEECH2TEXT,
  150. ctx.GetString("cluster_id"),
  151. )
  152. if err != nil {
  153. ctx.JSON(500, gin.H{"error": err.Error()})
  154. return
  155. }
  156. defer session.Close()
  157. baseSSEService(
  158. func() (*stream.StreamResponse[model_entities.Speech2TextResult], error) {
  159. return plugin_daemon.InvokeSpeech2Text(session, &r.Data)
  160. },
  161. ctx,
  162. max_timeout_seconds,
  163. )
  164. }
  165. func InvokeModeration(
  166. r *plugin_entities.InvokePluginRequest[requests.RequestInvokeModeration],
  167. ctx *gin.Context,
  168. max_timeout_seconds int,
  169. ) {
  170. // create session
  171. session, err := createSession(
  172. r,
  173. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  174. access_types.PLUGIN_ACCESS_ACTION_INVOKE_MODERATION,
  175. ctx.GetString("cluster_id"),
  176. )
  177. if err != nil {
  178. ctx.JSON(500, gin.H{"error": err.Error()})
  179. return
  180. }
  181. defer session.Close()
  182. baseSSEService(
  183. func() (*stream.StreamResponse[model_entities.ModerationResult], error) {
  184. return plugin_daemon.InvokeModeration(session, &r.Data)
  185. },
  186. ctx,
  187. max_timeout_seconds,
  188. )
  189. }
  190. func ValidateProviderCredentials(
  191. r *plugin_entities.InvokePluginRequest[requests.RequestValidateProviderCredentials],
  192. ctx *gin.Context,
  193. max_timeout_seconds int,
  194. ) {
  195. // create session
  196. session, err := createSession(
  197. r,
  198. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  199. access_types.PLUGIN_ACCESS_ACTION_VALIDATE_PROVIDER_CREDENTIALS,
  200. ctx.GetString("cluster_id"),
  201. )
  202. if err != nil {
  203. ctx.JSON(500, gin.H{"error": err.Error()})
  204. return
  205. }
  206. defer session.Close()
  207. baseSSEService(
  208. func() (*stream.StreamResponse[model_entities.ValidateCredentialsResult], error) {
  209. return plugin_daemon.ValidateProviderCredentials(session, &r.Data)
  210. },
  211. ctx,
  212. max_timeout_seconds,
  213. )
  214. }
  215. func ValidateModelCredentials(
  216. r *plugin_entities.InvokePluginRequest[requests.RequestValidateModelCredentials],
  217. ctx *gin.Context,
  218. max_timeout_seconds int,
  219. ) {
  220. // create session
  221. session, err := createSession(
  222. r,
  223. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  224. access_types.PLUGIN_ACCESS_ACTION_VALIDATE_MODEL_CREDENTIALS,
  225. ctx.GetString("cluster_id"),
  226. )
  227. if err != nil {
  228. ctx.JSON(500, gin.H{"error": err.Error()})
  229. return
  230. }
  231. defer session.Close()
  232. baseSSEService(
  233. func() (*stream.StreamResponse[model_entities.ValidateCredentialsResult], error) {
  234. return plugin_daemon.ValidateModelCredentials(session, &r.Data)
  235. },
  236. ctx,
  237. max_timeout_seconds,
  238. )
  239. }