task.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. package backwards_invocation
  2. import (
  3. "encoding/hex"
  4. "fmt"
  5. "github.com/langgenius/dify-plugin-daemon/internal/core/dify_invocation"
  6. "github.com/langgenius/dify-plugin-daemon/internal/core/persistence"
  7. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/access_types"
  8. "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
  9. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
  10. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  11. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/tool_entities"
  12. "github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  13. "github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
  14. )
  15. func InvokeDify(
  16. declaration *plugin_entities.PluginDeclaration,
  17. invoke_from access_types.PluginAccessType,
  18. session *session_manager.Session,
  19. writer BackwardsInvocationWriter,
  20. data []byte,
  21. ) error {
  22. // unmarshal invoke data
  23. request, err := parser.UnmarshalJsonBytes2Map(data)
  24. if err != nil {
  25. return fmt.Errorf("unmarshal invoke request failed: %s", err.Error())
  26. }
  27. if request == nil {
  28. return fmt.Errorf("invoke request is empty")
  29. }
  30. // prepare invocation arguments
  31. request_handle, err := prepareDifyInvocationArguments(session, writer, request)
  32. if err != nil {
  33. return err
  34. }
  35. if invoke_from == access_types.PLUGIN_ACCESS_TYPE_MODEL {
  36. request_handle.WriteError(fmt.Errorf("you can not invoke dify from %s", invoke_from))
  37. request_handle.EndResponse()
  38. return nil
  39. }
  40. // check permission
  41. if err := checkPermission(declaration, request_handle); err != nil {
  42. request_handle.WriteError(err)
  43. request_handle.EndResponse()
  44. return nil
  45. }
  46. // dispatch invocation task
  47. routine.Submit(func() {
  48. dispatchDifyInvocationTask(request_handle)
  49. defer request_handle.EndResponse()
  50. })
  51. return nil
  52. }
  53. var (
  54. permissionMapping = map[dify_invocation.InvokeType]map[string]any{
  55. dify_invocation.INVOKE_TYPE_TOOL: {
  56. "func": func(declaration *plugin_entities.PluginDeclaration) bool {
  57. return declaration.Resource.Permission.AllowInvokeTool()
  58. },
  59. "error": "permission denied, you need to enable tool access in plugin manifest",
  60. },
  61. dify_invocation.INVOKE_TYPE_LLM: {
  62. "func": func(declaration *plugin_entities.PluginDeclaration) bool {
  63. return declaration.Resource.Permission.AllowInvokeLLM()
  64. },
  65. "error": "permission denied, you need to enable llm access in plugin manifest",
  66. },
  67. dify_invocation.INVOKE_TYPE_TEXT_EMBEDDING: {
  68. "func": func(declaration *plugin_entities.PluginDeclaration) bool {
  69. return declaration.Resource.Permission.AllowInvokeTextEmbedding()
  70. },
  71. "error": "permission denied, you need to enable text-embedding access in plugin manifest",
  72. },
  73. dify_invocation.INVOKE_TYPE_RERANK: {
  74. "func": func(declaration *plugin_entities.PluginDeclaration) bool {
  75. return declaration.Resource.Permission.AllowInvokeRerank()
  76. },
  77. "error": "permission denied, you need to enable rerank access in plugin manifest",
  78. },
  79. dify_invocation.INVOKE_TYPE_TTS: {
  80. "func": func(declaration *plugin_entities.PluginDeclaration) bool {
  81. return declaration.Resource.Permission.AllowInvokeTTS()
  82. },
  83. "error": "permission denied, you need to enable tts access in plugin manifest",
  84. },
  85. dify_invocation.INVOKE_TYPE_SPEECH2TEXT: {
  86. "func": func(declaration *plugin_entities.PluginDeclaration) bool {
  87. return declaration.Resource.Permission.AllowInvokeSpeech2Text()
  88. },
  89. "error": "permission denied, you need to enable speech2text access in plugin manifest",
  90. },
  91. dify_invocation.INVOKE_TYPE_MODERATION: {
  92. "func": func(declaration *plugin_entities.PluginDeclaration) bool {
  93. return declaration.Resource.Permission.AllowInvokeModeration()
  94. },
  95. "error": "permission denied, you need to enable moderation access in plugin manifest",
  96. },
  97. dify_invocation.INVOKE_TYPE_NODE: {
  98. "func": func(declaration *plugin_entities.PluginDeclaration) bool {
  99. return declaration.Resource.Permission.AllowInvokeNode()
  100. },
  101. "error": "permission denied, you need to enable node access in plugin manifest",
  102. },
  103. dify_invocation.INVOKE_TYPE_APP: {
  104. "func": func(declaration *plugin_entities.PluginDeclaration) bool {
  105. return declaration.Resource.Permission.AllowInvokeApp()
  106. },
  107. "error": "permission denied, you need to enable app access in plugin manifest",
  108. },
  109. }
  110. )
  111. func checkPermission(runtime *plugin_entities.PluginDeclaration, request_handle *BackwardsInvocation) error {
  112. permission, ok := permissionMapping[request_handle.Type()]
  113. if !ok {
  114. return fmt.Errorf("unsupported invoke type: %s", request_handle.Type())
  115. }
  116. permission_func, ok := permission["func"].(func(runtime *plugin_entities.PluginDeclaration) bool)
  117. if !ok {
  118. return fmt.Errorf("permission function not found: %s", request_handle.Type())
  119. }
  120. if !permission_func(runtime) {
  121. return fmt.Errorf(permission["error"].(string))
  122. }
  123. return nil
  124. }
  125. func prepareDifyInvocationArguments(
  126. session *session_manager.Session,
  127. writer BackwardsInvocationWriter,
  128. request map[string]any,
  129. ) (*BackwardsInvocation, error) {
  130. typ, ok := request["type"].(string)
  131. if !ok {
  132. return nil, fmt.Errorf("invoke request missing type: %s", request)
  133. }
  134. // get request id
  135. backwards_request_id, ok := request["backwards_request_id"].(string)
  136. if !ok {
  137. return nil, fmt.Errorf("invoke request missing request_id: %s", request)
  138. }
  139. // get request
  140. detailed_request, ok := request["request"].(map[string]any)
  141. if !ok {
  142. return nil, fmt.Errorf("invoke request missing request: %s", request)
  143. }
  144. return NewBackwardsInvocation(
  145. BackwardsInvocationType(typ),
  146. backwards_request_id,
  147. session,
  148. writer,
  149. detailed_request,
  150. ), nil
  151. }
  152. var (
  153. dispatchMapping = map[dify_invocation.InvokeType]func(handle *BackwardsInvocation){
  154. dify_invocation.INVOKE_TYPE_TOOL: func(handle *BackwardsInvocation) {
  155. genericDispatchTask(handle, executeDifyInvocationToolTask)
  156. },
  157. dify_invocation.INVOKE_TYPE_LLM: func(handle *BackwardsInvocation) {
  158. genericDispatchTask(handle, executeDifyInvocationLLMTask)
  159. },
  160. dify_invocation.INVOKE_TYPE_TEXT_EMBEDDING: func(handle *BackwardsInvocation) {
  161. genericDispatchTask(handle, executeDifyInvocationTextEmbeddingTask)
  162. },
  163. dify_invocation.INVOKE_TYPE_RERANK: func(handle *BackwardsInvocation) {
  164. genericDispatchTask(handle, executeDifyInvocationRerankTask)
  165. },
  166. dify_invocation.INVOKE_TYPE_TTS: func(handle *BackwardsInvocation) {
  167. genericDispatchTask(handle, executeDifyInvocationTTSTask)
  168. },
  169. dify_invocation.INVOKE_TYPE_SPEECH2TEXT: func(handle *BackwardsInvocation) {
  170. genericDispatchTask(handle, executeDifyInvocationSpeech2TextTask)
  171. },
  172. dify_invocation.INVOKE_TYPE_MODERATION: func(handle *BackwardsInvocation) {
  173. genericDispatchTask(handle, executeDifyInvocationModerationTask)
  174. },
  175. dify_invocation.INVOKE_TYPE_APP: func(handle *BackwardsInvocation) {
  176. genericDispatchTask(handle, executeDifyInvocationAppTask)
  177. },
  178. dify_invocation.INVOKE_TYPE_STORAGE: func(handle *BackwardsInvocation) {
  179. genericDispatchTask(handle, executeDifyInvocationStorageTask)
  180. },
  181. }
  182. )
  183. func genericDispatchTask[T any](
  184. handle *BackwardsInvocation,
  185. dispatch func(
  186. handle *BackwardsInvocation,
  187. request *T,
  188. ),
  189. ) {
  190. r, err := parser.MapToStruct[T](handle.RequestData())
  191. if err != nil {
  192. handle.WriteError(fmt.Errorf("unmarshal invoke tool request failed: %s", err.Error()))
  193. return
  194. }
  195. dispatch(handle, r)
  196. }
  197. func dispatchDifyInvocationTask(handle *BackwardsInvocation) {
  198. request_data := handle.RequestData()
  199. tenant_id, err := handle.TenantID()
  200. if err != nil {
  201. handle.WriteError(fmt.Errorf("get tenant id failed: %s", err.Error()))
  202. return
  203. }
  204. request_data["tenant_id"] = tenant_id
  205. user_id, err := handle.UserID()
  206. if err != nil {
  207. handle.WriteError(fmt.Errorf("get user id failed: %s", err.Error()))
  208. return
  209. }
  210. request_data["user_id"] = user_id
  211. typ := handle.Type()
  212. request_data["type"] = typ
  213. for t, v := range dispatchMapping {
  214. if t == handle.Type() {
  215. v(handle)
  216. return
  217. }
  218. }
  219. handle.WriteError(fmt.Errorf("unsupported invoke type: %s", handle.Type()))
  220. }
  221. func executeDifyInvocationToolTask(
  222. handle *BackwardsInvocation,
  223. request *dify_invocation.InvokeToolRequest,
  224. ) {
  225. response, err := dify_invocation.InvokeTool(request)
  226. if err != nil {
  227. handle.WriteError(fmt.Errorf("invoke tool failed: %s", err.Error()))
  228. return
  229. }
  230. response.Wrap(func(t tool_entities.ToolResponseChunk) {
  231. handle.WriteResponse("stream", t)
  232. })
  233. }
  234. func executeDifyInvocationLLMTask(
  235. handle *BackwardsInvocation,
  236. request *dify_invocation.InvokeLLMRequest,
  237. ) {
  238. response, err := dify_invocation.InvokeLLM(request)
  239. if err != nil {
  240. handle.WriteError(fmt.Errorf("invoke llm model failed: %s", err.Error()))
  241. return
  242. }
  243. response.Wrap(func(t model_entities.LLMResultChunk) {
  244. handle.WriteResponse("stream", t)
  245. })
  246. }
  247. func executeDifyInvocationTextEmbeddingTask(
  248. handle *BackwardsInvocation,
  249. request *dify_invocation.InvokeTextEmbeddingRequest,
  250. ) {
  251. response, err := dify_invocation.InvokeTextEmbedding(request)
  252. if err != nil {
  253. handle.WriteError(fmt.Errorf("invoke text-embedding model failed: %s", err.Error()))
  254. return
  255. }
  256. handle.WriteResponse("struct", response)
  257. }
  258. func executeDifyInvocationRerankTask(
  259. handle *BackwardsInvocation,
  260. request *dify_invocation.InvokeRerankRequest,
  261. ) {
  262. response, err := dify_invocation.InvokeRerank(request)
  263. if err != nil {
  264. handle.WriteError(fmt.Errorf("invoke rerank model failed: %s", err.Error()))
  265. return
  266. }
  267. handle.WriteResponse("struct", response)
  268. }
  269. func executeDifyInvocationTTSTask(
  270. handle *BackwardsInvocation,
  271. request *dify_invocation.InvokeTTSRequest,
  272. ) {
  273. response, err := dify_invocation.InvokeTTS(request)
  274. if err != nil {
  275. handle.WriteError(fmt.Errorf("invoke tts model failed: %s", err.Error()))
  276. return
  277. }
  278. response.Wrap(func(t model_entities.TTSResult) {
  279. handle.WriteResponse("struct", t)
  280. })
  281. }
  282. func executeDifyInvocationSpeech2TextTask(
  283. handle *BackwardsInvocation,
  284. request *dify_invocation.InvokeSpeech2TextRequest,
  285. ) {
  286. response, err := dify_invocation.InvokeSpeech2Text(request)
  287. if err != nil {
  288. handle.WriteError(fmt.Errorf("invoke speech2text model failed: %s", err.Error()))
  289. return
  290. }
  291. handle.WriteResponse("struct", response)
  292. }
  293. func executeDifyInvocationModerationTask(
  294. handle *BackwardsInvocation,
  295. request *dify_invocation.InvokeModerationRequest,
  296. ) {
  297. response, err := dify_invocation.InvokeModeration(request)
  298. if err != nil {
  299. handle.WriteError(fmt.Errorf("invoke moderation model failed: %s", err.Error()))
  300. return
  301. }
  302. handle.WriteResponse("struct", response)
  303. }
  304. func executeDifyInvocationAppTask(
  305. handle *BackwardsInvocation,
  306. request *dify_invocation.InvokeAppRequest,
  307. ) {
  308. response, err := dify_invocation.InvokeApp(request)
  309. if err != nil {
  310. handle.WriteError(fmt.Errorf("invoke app failed: %s", err.Error()))
  311. return
  312. }
  313. user_id, err := handle.UserID()
  314. if err != nil {
  315. handle.WriteError(fmt.Errorf("get user id failed: %s", err.Error()))
  316. return
  317. }
  318. request.User = user_id
  319. response.Wrap(func(t map[string]any) {
  320. handle.WriteResponse("stream", t)
  321. })
  322. }
  323. func executeDifyInvocationStorageTask(
  324. handle *BackwardsInvocation,
  325. request *dify_invocation.InvokeStorageRequest,
  326. ) {
  327. if handle.session == nil {
  328. handle.WriteError(fmt.Errorf("session not found"))
  329. return
  330. }
  331. persistence := persistence.GetPersistence()
  332. if persistence == nil {
  333. handle.WriteError(fmt.Errorf("persistence not found"))
  334. return
  335. }
  336. tenant_id, err := handle.TenantID()
  337. if err != nil {
  338. handle.WriteError(fmt.Errorf("get tenant id failed: %s", err.Error()))
  339. return
  340. }
  341. plugin_id := handle.session.PluginIdentity
  342. if request.Opt == dify_invocation.STORAGE_OPT_GET {
  343. data, err := persistence.Load(tenant_id, plugin_id, request.Key)
  344. if err != nil {
  345. handle.WriteError(fmt.Errorf("load data failed: %s", err.Error()))
  346. return
  347. }
  348. handle.WriteResponse("struct", map[string]any{
  349. "data": hex.EncodeToString(data),
  350. })
  351. } else if request.Opt == dify_invocation.STORAGE_OPT_SET {
  352. data, err := hex.DecodeString(request.Value)
  353. if err != nil {
  354. handle.WriteError(fmt.Errorf("decode data failed: %s", err.Error()))
  355. return
  356. }
  357. if err := persistence.Save(tenant_id, plugin_id, request.Key, data); err != nil {
  358. handle.WriteError(fmt.Errorf("save data failed: %s", err.Error()))
  359. return
  360. }
  361. handle.WriteResponse("struct", map[string]any{
  362. "data": "ok",
  363. })
  364. } else if request.Opt == dify_invocation.STORAGE_OPT_DEL {
  365. if err := persistence.Delete(tenant_id, plugin_id, request.Key); err != nil {
  366. handle.WriteError(fmt.Errorf("delete data failed: %s", err.Error()))
  367. return
  368. }
  369. handle.WriteResponse("struct", map[string]any{
  370. "data": "ok",
  371. })
  372. }
  373. }