task.go 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. package backwards_invocation
  import (
	"errors"
	"fmt"

	"github.com/langgenius/dify-plugin-daemon/internal/core/dify_invocation"
	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/tool_entities"
	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
	"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
)
  12. func InvokeDify(
  13. runtime entities.PluginRuntimeInterface,
  14. invoke_from PluginAccessType,
  15. session *session_manager.Session, data []byte,
  16. ) error {
  17. // unmarshal invoke data
  18. request, err := parser.UnmarshalJsonBytes2Map(data)
  19. if err != nil {
  20. return fmt.Errorf("unmarshal invoke request failed: %s", err.Error())
  21. }
  22. if request == nil {
  23. return fmt.Errorf("invoke request is empty")
  24. }
  25. // prepare invocation arguments
  26. request_handle, err := prepareDifyInvocationArguments(session, request)
  27. if err != nil {
  28. return err
  29. }
  30. if invoke_from == PLUGIN_ACCESS_TYPE_MODEL {
  31. request_handle.WriteError(fmt.Errorf("you can not invoke dify from %s", invoke_from))
  32. request_handle.EndResponse()
  33. return nil
  34. }
  35. // check permission
  36. if err := checkPermission(runtime, request_handle); err != nil {
  37. request_handle.WriteError(err)
  38. request_handle.EndResponse()
  39. return nil
  40. }
  41. // dispatch invocation task
  42. routine.Submit(func() {
  43. dispatchDifyInvocationTask(request_handle)
  44. defer request_handle.EndResponse()
  45. })
  46. return nil
  47. }
  48. var (
  49. permissionMapping = map[dify_invocation.InvokeType]map[string]any{
  50. dify_invocation.INVOKE_TYPE_TOOL: {
  51. "func": func(runtime entities.PluginRuntimeTimeLifeInterface) bool {
  52. return runtime.Configuration().Resource.Permission.AllowInvokeTool()
  53. },
  54. "error": "permission denied, you need to enable tool access in plugin manifest",
  55. },
  56. dify_invocation.INVOKE_TYPE_LLM: {
  57. "func": func(runtime entities.PluginRuntimeTimeLifeInterface) bool {
  58. return runtime.Configuration().Resource.Permission.AllowInvokeLLM()
  59. },
  60. "error": "permission denied, you need to enable llm access in plugin manifest",
  61. },
  62. dify_invocation.INVOKE_TYPE_TEXT_EMBEDDING: {
  63. "func": func(runtime entities.PluginRuntimeTimeLifeInterface) bool {
  64. return runtime.Configuration().Resource.Permission.AllowInvokeTextEmbedding()
  65. },
  66. "error": "permission denied, you need to enable text-embedding access in plugin manifest",
  67. },
  68. dify_invocation.INVOKE_TYPE_RERANK: {
  69. "func": func(runtime entities.PluginRuntimeTimeLifeInterface) bool {
  70. return runtime.Configuration().Resource.Permission.AllowInvokeRerank()
  71. },
  72. "error": "permission denied, you need to enable rerank access in plugin manifest",
  73. },
  74. dify_invocation.INVOKE_TYPE_TTS: {
  75. "func": func(runtime entities.PluginRuntimeTimeLifeInterface) bool {
  76. return runtime.Configuration().Resource.Permission.AllowInvokeTTS()
  77. },
  78. "error": "permission denied, you need to enable tts access in plugin manifest",
  79. },
  80. dify_invocation.INVOKE_TYPE_SPEECH2TEXT: {
  81. "func": func(runtime entities.PluginRuntimeTimeLifeInterface) bool {
  82. return runtime.Configuration().Resource.Permission.AllowInvokeSpeech2Text()
  83. },
  84. "error": "permission denied, you need to enable speech2text access in plugin manifest",
  85. },
  86. dify_invocation.INVOKE_TYPE_MODERATION: {
  87. "func": func(runtime entities.PluginRuntimeTimeLifeInterface) bool {
  88. return runtime.Configuration().Resource.Permission.AllowInvokeModeration()
  89. },
  90. "error": "permission denied, you need to enable moderation access in plugin manifest",
  91. },
  92. dify_invocation.INVOKE_TYPE_NODE: {
  93. "func": func(runtime entities.PluginRuntimeTimeLifeInterface) bool {
  94. return runtime.Configuration().Resource.Permission.AllowInvokeNode()
  95. },
  96. "error": "permission denied, you need to enable node access in plugin manifest",
  97. },
  98. }
  99. )
  100. func checkPermission(runtime entities.PluginRuntimeTimeLifeInterface, request_handle *BackwardsInvocation) error {
  101. permission, ok := permissionMapping[request_handle.Type()]
  102. if !ok {
  103. return fmt.Errorf("unsupported invoke type: %s", request_handle.Type())
  104. }
  105. permission_func, ok := permission["func"].(func(runtime entities.PluginRuntimeTimeLifeInterface) bool)
  106. if !ok {
  107. return fmt.Errorf("permission function not found: %s", request_handle.Type())
  108. }
  109. if !permission_func(runtime) {
  110. return fmt.Errorf(permission["error"].(string))
  111. }
  112. return nil
  113. }
  114. func prepareDifyInvocationArguments(session *session_manager.Session, request map[string]any) (*BackwardsInvocation, error) {
  115. typ, ok := request["type"].(string)
  116. if !ok {
  117. return nil, fmt.Errorf("invoke request missing type: %s", request)
  118. }
  119. // get request id
  120. backwards_request_id, ok := request["backwards_request_id"].(string)
  121. if !ok {
  122. return nil, fmt.Errorf("invoke request missing request_id: %s", request)
  123. }
  124. // get request
  125. detailed_request, ok := request["request"].(map[string]any)
  126. if !ok {
  127. return nil, fmt.Errorf("invoke request missing request: %s", request)
  128. }
  129. return NewBackwardsInvocation(
  130. BackwardsInvocationType(typ),
  131. backwards_request_id, session, detailed_request,
  132. ), nil
  133. }
  134. var (
  135. dispatchMapping = map[dify_invocation.InvokeType]func(handle *BackwardsInvocation){
  136. dify_invocation.INVOKE_TYPE_TOOL: func(handle *BackwardsInvocation) {
  137. genericDispatchTask[dify_invocation.InvokeToolRequest](handle, executeDifyInvocationToolTask)
  138. },
  139. dify_invocation.INVOKE_TYPE_LLM: func(handle *BackwardsInvocation) {
  140. genericDispatchTask[dify_invocation.InvokeLLMRequest](handle, executeDifyInvocationLLMTask)
  141. },
  142. dify_invocation.INVOKE_TYPE_TEXT_EMBEDDING: func(handle *BackwardsInvocation) {
  143. genericDispatchTask[dify_invocation.InvokeTextEmbeddingRequest](handle, executeDifyInvocationTextEmbeddingTask)
  144. },
  145. dify_invocation.INVOKE_TYPE_RERANK: func(handle *BackwardsInvocation) {
  146. genericDispatchTask[dify_invocation.InvokeRerankRequest](handle, executeDifyInvocationRerankTask)
  147. },
  148. dify_invocation.INVOKE_TYPE_TTS: func(handle *BackwardsInvocation) {
  149. genericDispatchTask[dify_invocation.InvokeTTSRequest](handle, executeDifyInvocationTTSTask)
  150. },
  151. dify_invocation.INVOKE_TYPE_SPEECH2TEXT: func(handle *BackwardsInvocation) {
  152. genericDispatchTask[dify_invocation.InvokeSpeech2TextRequest](handle, executeDifyInvocationSpeech2TextTask)
  153. },
  154. dify_invocation.INVOKE_TYPE_MODERATION: func(handle *BackwardsInvocation) {
  155. genericDispatchTask[dify_invocation.InvokeModerationRequest](handle, executeDifyInvocationModerationTask)
  156. },
  157. }
  158. )
  159. func genericDispatchTask[T any](
  160. handle *BackwardsInvocation,
  161. dispatch func(
  162. handle *BackwardsInvocation,
  163. request *T,
  164. ),
  165. ) {
  166. r, err := parser.MapToStruct[T](handle.RequestData())
  167. if err != nil {
  168. handle.WriteError(fmt.Errorf("unmarshal invoke tool request failed: %s", err.Error()))
  169. return
  170. }
  171. dispatch(handle, r)
  172. }
  173. func dispatchDifyInvocationTask(handle *BackwardsInvocation) {
  174. request_data := handle.RequestData()
  175. tenant_id, err := handle.TenantID()
  176. if err != nil {
  177. handle.WriteError(fmt.Errorf("get tenant id failed: %s", err.Error()))
  178. return
  179. }
  180. request_data["tenant_id"] = tenant_id
  181. user_id, err := handle.UserID()
  182. if err != nil {
  183. handle.WriteError(fmt.Errorf("get user id failed: %s", err.Error()))
  184. return
  185. }
  186. request_data["user_id"] = user_id
  187. typ := handle.Type()
  188. request_data["type"] = typ
  189. for t, v := range dispatchMapping {
  190. if t == handle.Type() {
  191. v(handle)
  192. return
  193. }
  194. }
  195. handle.WriteError(fmt.Errorf("unsupported invoke type: %s", handle.Type()))
  196. }
  197. func executeDifyInvocationToolTask(
  198. handle *BackwardsInvocation,
  199. request *dify_invocation.InvokeToolRequest,
  200. ) {
  201. response, err := dify_invocation.InvokeTool(request)
  202. if err != nil {
  203. handle.WriteError(fmt.Errorf("invoke tool failed: %s", err.Error()))
  204. return
  205. }
  206. response.Wrap(func(t tool_entities.ToolResponseChunk) {
  207. handle.WriteResponse("stream", t)
  208. })
  209. }
  210. func executeDifyInvocationLLMTask(
  211. handle *BackwardsInvocation,
  212. request *dify_invocation.InvokeLLMRequest,
  213. ) {
  214. response, err := dify_invocation.InvokeLLM(request)
  215. if err != nil {
  216. handle.WriteError(fmt.Errorf("invoke llm model failed: %s", err.Error()))
  217. return
  218. }
  219. response.Wrap(func(t model_entities.LLMResultChunk) {
  220. handle.WriteResponse("stream", t)
  221. })
  222. }
  223. func executeDifyInvocationTextEmbeddingTask(
  224. handle *BackwardsInvocation,
  225. request *dify_invocation.InvokeTextEmbeddingRequest,
  226. ) {
  227. response, err := dify_invocation.InvokeTextEmbedding(request)
  228. if err != nil {
  229. handle.WriteError(fmt.Errorf("invoke text-embedding model failed: %s", err.Error()))
  230. return
  231. }
  232. handle.WriteResponse("struct", response)
  233. }
  234. func executeDifyInvocationRerankTask(
  235. handle *BackwardsInvocation,
  236. request *dify_invocation.InvokeRerankRequest,
  237. ) {
  238. response, err := dify_invocation.InvokeRerank(request)
  239. if err != nil {
  240. handle.WriteError(fmt.Errorf("invoke rerank model failed: %s", err.Error()))
  241. return
  242. }
  243. handle.WriteResponse("struct", response)
  244. }
  245. func executeDifyInvocationTTSTask(
  246. handle *BackwardsInvocation,
  247. request *dify_invocation.InvokeTTSRequest,
  248. ) {
  249. response, err := dify_invocation.InvokeTTS(request)
  250. if err != nil {
  251. handle.WriteError(fmt.Errorf("invoke tts model failed: %s", err.Error()))
  252. return
  253. }
  254. response.Wrap(func(t model_entities.TTSResult) {
  255. handle.WriteResponse("struct", t)
  256. })
  257. }
  258. func executeDifyInvocationSpeech2TextTask(
  259. handle *BackwardsInvocation,
  260. request *dify_invocation.InvokeSpeech2TextRequest,
  261. ) {
  262. response, err := dify_invocation.InvokeSpeech2Text(request)
  263. if err != nil {
  264. handle.WriteError(fmt.Errorf("invoke speech2text model failed: %s", err.Error()))
  265. return
  266. }
  267. handle.WriteResponse("struct", response)
  268. }
  269. func executeDifyInvocationModerationTask(
  270. handle *BackwardsInvocation,
  271. request *dify_invocation.InvokeModerationRequest,
  272. ) {
  273. response, err := dify_invocation.InvokeModeration(request)
  274. if err != nil {
  275. handle.WriteError(fmt.Errorf("invoke moderation model failed: %s", err.Error()))
  276. return
  277. }
  278. handle.WriteResponse("struct", response)
  279. }