invoke_model.go 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
  1. package service
  2. import (
  3. "github.com/gin-gonic/gin"
  4. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon"
  5. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/access_types"
  6. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
  7. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  8. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
  9. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  10. )
  11. func InvokeLLM(
  12. r *plugin_entities.InvokePluginRequest[requests.RequestInvokeLLM],
  13. ctx *gin.Context,
  14. max_timeout_seconds int,
  15. ) {
  16. // create session
  17. session, err := createSession(
  18. r,
  19. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  20. access_types.PLUGIN_ACCESS_ACTION_INVOKE_LLM,
  21. ctx.GetString("cluster_id"),
  22. )
  23. if err != nil {
  24. ctx.JSON(500, gin.H{"error": err.Error()})
  25. return
  26. }
  27. defer session.Close()
  28. baseSSEService(
  29. func() (*stream.Stream[model_entities.LLMResultChunk], error) {
  30. return plugin_daemon.InvokeLLM(session, &r.Data)
  31. },
  32. ctx,
  33. max_timeout_seconds,
  34. )
  35. }
  36. func InvokeTextEmbedding(
  37. r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTextEmbedding],
  38. ctx *gin.Context,
  39. max_timeout_seconds int,
  40. ) {
  41. // create session
  42. session, err := createSession(
  43. r,
  44. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  45. access_types.PLUGIN_ACCESS_ACTION_INVOKE_TEXT_EMBEDDING,
  46. ctx.GetString("cluster_id"))
  47. if err != nil {
  48. ctx.JSON(500, gin.H{"error": err.Error()})
  49. return
  50. }
  51. defer session.Close()
  52. baseSSEService(
  53. func() (*stream.Stream[model_entities.TextEmbeddingResult], error) {
  54. return plugin_daemon.InvokeTextEmbedding(session, &r.Data)
  55. },
  56. ctx,
  57. max_timeout_seconds,
  58. )
  59. }
  60. func InvokeRerank(
  61. r *plugin_entities.InvokePluginRequest[requests.RequestInvokeRerank],
  62. ctx *gin.Context,
  63. max_timeout_seconds int,
  64. ) {
  65. // create session
  66. session, err := createSession(
  67. r,
  68. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  69. access_types.PLUGIN_ACCESS_ACTION_INVOKE_RERANK,
  70. ctx.GetString("cluster_id"),
  71. )
  72. if err != nil {
  73. ctx.JSON(500, gin.H{"error": err.Error()})
  74. return
  75. }
  76. defer session.Close()
  77. baseSSEService(
  78. func() (*stream.Stream[model_entities.RerankResult], error) {
  79. return plugin_daemon.InvokeRerank(session, &r.Data)
  80. },
  81. ctx,
  82. max_timeout_seconds,
  83. )
  84. }
  85. func InvokeTTS(
  86. r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTTS],
  87. ctx *gin.Context,
  88. max_timeout_seconds int,
  89. ) {
  90. // create session
  91. session, err := createSession(
  92. r,
  93. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  94. access_types.PLUGIN_ACCESS_ACTION_INVOKE_TTS,
  95. ctx.GetString("cluster_id"),
  96. )
  97. if err != nil {
  98. ctx.JSON(500, gin.H{"error": err.Error()})
  99. return
  100. }
  101. defer session.Close()
  102. baseSSEService(
  103. func() (*stream.Stream[model_entities.TTSResult], error) {
  104. return plugin_daemon.InvokeTTS(session, &r.Data)
  105. },
  106. ctx,
  107. max_timeout_seconds,
  108. )
  109. }
  110. func InvokeSpeech2Text(
  111. r *plugin_entities.InvokePluginRequest[requests.RequestInvokeSpeech2Text],
  112. ctx *gin.Context,
  113. max_timeout_seconds int,
  114. ) {
  115. // create session
  116. session, err := createSession(
  117. r,
  118. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  119. access_types.PLUGIN_ACCESS_ACTION_INVOKE_SPEECH2TEXT,
  120. ctx.GetString("cluster_id"),
  121. )
  122. if err != nil {
  123. ctx.JSON(500, gin.H{"error": err.Error()})
  124. return
  125. }
  126. defer session.Close()
  127. baseSSEService(
  128. func() (*stream.Stream[model_entities.Speech2TextResult], error) {
  129. return plugin_daemon.InvokeSpeech2Text(session, &r.Data)
  130. },
  131. ctx,
  132. max_timeout_seconds,
  133. )
  134. }
  135. func InvokeModeration(
  136. r *plugin_entities.InvokePluginRequest[requests.RequestInvokeModeration],
  137. ctx *gin.Context,
  138. max_timeout_seconds int,
  139. ) {
  140. // create session
  141. session, err := createSession(
  142. r,
  143. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  144. access_types.PLUGIN_ACCESS_ACTION_INVOKE_MODERATION,
  145. ctx.GetString("cluster_id"),
  146. )
  147. if err != nil {
  148. ctx.JSON(500, gin.H{"error": err.Error()})
  149. return
  150. }
  151. defer session.Close()
  152. baseSSEService(
  153. func() (*stream.Stream[model_entities.ModerationResult], error) {
  154. return plugin_daemon.InvokeModeration(session, &r.Data)
  155. },
  156. ctx,
  157. max_timeout_seconds,
  158. )
  159. }
  160. func ValidateProviderCredentials(
  161. r *plugin_entities.InvokePluginRequest[requests.RequestValidateProviderCredentials],
  162. ctx *gin.Context,
  163. max_timeout_seconds int,
  164. ) {
  165. // create session
  166. session, err := createSession(
  167. r,
  168. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  169. access_types.PLUGIN_ACCESS_ACTION_VALIDATE_PROVIDER_CREDENTIALS,
  170. ctx.GetString("cluster_id"),
  171. )
  172. if err != nil {
  173. ctx.JSON(500, gin.H{"error": err.Error()})
  174. return
  175. }
  176. defer session.Close()
  177. baseSSEService(
  178. func() (*stream.Stream[model_entities.ValidateCredentialsResult], error) {
  179. return plugin_daemon.ValidateProviderCredentials(session, &r.Data)
  180. },
  181. ctx,
  182. max_timeout_seconds,
  183. )
  184. }
  185. func ValidateModelCredentials(
  186. r *plugin_entities.InvokePluginRequest[requests.RequestValidateModelCredentials],
  187. ctx *gin.Context,
  188. max_timeout_seconds int,
  189. ) {
  190. // create session
  191. session, err := createSession(
  192. r,
  193. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  194. access_types.PLUGIN_ACCESS_ACTION_VALIDATE_MODEL_CREDENTIALS,
  195. ctx.GetString("cluster_id"),
  196. )
  197. if err != nil {
  198. ctx.JSON(500, gin.H{"error": err.Error()})
  199. return
  200. }
  201. defer session.Close()
  202. baseSSEService(
  203. func() (*stream.Stream[model_entities.ValidateCredentialsResult], error) {
  204. return plugin_daemon.ValidateModelCredentials(session, &r.Data)
  205. },
  206. ctx,
  207. max_timeout_seconds,
  208. )
  209. }
  210. func GetTTSModelVoices(
  211. r *plugin_entities.InvokePluginRequest[requests.RequestGetTTSModelVoices],
  212. ctx *gin.Context,
  213. max_timeout_seconds int,
  214. ) {
  215. session, err := createSession(
  216. r,
  217. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  218. access_types.PLUGIN_ACCESS_ACTION_GET_TTS_MODEL_VOICES,
  219. ctx.GetString("cluster_id"),
  220. )
  221. if err != nil {
  222. ctx.JSON(500, gin.H{"error": err.Error()})
  223. return
  224. }
  225. defer session.Close()
  226. baseSSEService(
  227. func() (*stream.Stream[model_entities.GetTTSVoicesResponse], error) {
  228. return plugin_daemon.GetTTSModelVoices(session, &r.Data)
  229. },
  230. ctx,
  231. max_timeout_seconds,
  232. )
  233. }
  234. func GetTextEmbeddingNumTokens(
  235. r *plugin_entities.InvokePluginRequest[requests.RequestGetTextEmbeddingNumTokens],
  236. ctx *gin.Context,
  237. max_timeout_seconds int,
  238. ) {
  239. session, err := createSession(
  240. r,
  241. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  242. access_types.PLUGIN_ACCESS_ACTION_GET_TEXT_EMBEDDING_NUM_TOKENS,
  243. ctx.GetString("cluster_id"),
  244. )
  245. if err != nil {
  246. ctx.JSON(500, gin.H{"error": err.Error()})
  247. return
  248. }
  249. defer session.Close()
  250. baseSSEService(
  251. func() (*stream.Stream[model_entities.GetTextEmbeddingNumTokensResponse], error) {
  252. return plugin_daemon.GetTextEmbeddingNumTokens(session, &r.Data)
  253. },
  254. ctx,
  255. max_timeout_seconds,
  256. )
  257. }
  258. func GetAIModelSchema(
  259. r *plugin_entities.InvokePluginRequest[requests.RequestGetAIModelSchema],
  260. ctx *gin.Context,
  261. max_timeout_seconds int,
  262. ) {
  263. session, err := createSession(
  264. r,
  265. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  266. access_types.PLUGIN_ACCESS_ACTION_GET_AI_MODEL_SCHEMA,
  267. ctx.GetString("cluster_id"),
  268. )
  269. if err != nil {
  270. ctx.JSON(500, gin.H{"error": err.Error()})
  271. return
  272. }
  273. defer session.Close()
  274. baseSSEService(
  275. func() (*stream.Stream[model_entities.GetModelSchemasResponse], error) {
  276. return plugin_daemon.GetAIModelSchema(session, &r.Data)
  277. },
  278. ctx,
  279. max_timeout_seconds,
  280. )
  281. }
  282. func GetLLMNumTokens(
  283. r *plugin_entities.InvokePluginRequest[requests.RequestGetLLMNumTokens],
  284. ctx *gin.Context,
  285. max_timeout_seconds int,
  286. ) {
  287. session, err := createSession(
  288. r,
  289. access_types.PLUGIN_ACCESS_TYPE_MODEL,
  290. access_types.PLUGIN_ACCESS_ACTION_GET_LLM_NUM_TOKENS,
  291. ctx.GetString("cluster_id"),
  292. )
  293. if err != nil {
  294. ctx.JSON(500, gin.H{"error": err.Error()})
  295. return
  296. }
  297. defer session.Close()
  298. baseSSEService(
  299. func() (*stream.Stream[model_entities.LLMGetNumTokensResponse], error) {
  300. return plugin_daemon.GetLLMNumTokens(session, &r.Data)
  301. },
  302. ctx,
  303. max_timeout_seconds,
  304. )
  305. }