task.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495
  1. package backwards_invocation
  2. import (
  3. "encoding/hex"
  4. "fmt"
  5. "github.com/langgenius/dify-plugin-daemon/internal/core/dify_invocation"
  6. "github.com/langgenius/dify-plugin-daemon/internal/core/persistence"
  7. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/access_types"
  8. "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
  9. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  10. "github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  11. "github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
  12. )
  // InvokeDify handles a backwards invocation sent by a plugin.
  //
  // It parses the raw JSON payload in data and builds a BackwardsInvocation
  // handle from it. It rejects callers coming from a model plugin and callers
  // whose manifest lacks the permission for the requested invocation type.
  // Finally, it submits the actual dispatch to routine.Submit.
  //
  // It returns an error only if the payload itself is malformed. Every other
  // failure is reported through the handle's WriteError, followed by
  // EndResponse, and nil is returned.
  func InvokeDify(
  	declaration *plugin_entities.PluginDeclaration,
  	invoke_from access_types.PluginAccessType,
  	session *session_manager.Session,
  	writer BackwardsInvocationWriter,
  	data []byte,
  ) error {
  	// unmarshal invoke data
  	request, err := parser.UnmarshalJsonBytes2Map(data)
  	if err != nil {
  		return fmt.Errorf("unmarshal invoke request failed: %w", err)
  	}
  	if request == nil {
  		return fmt.Errorf("invoke request is empty")
  	}

  	// prepare invocation arguments; a failure here means the payload is malformed
  	requestHandle, err := prepareDifyInvocationArguments(session, writer, request)
  	if err != nil {
  		return err
  	}

  	// model plugins are not allowed to call back into Dify
  	if invoke_from == access_types.PLUGIN_ACCESS_TYPE_MODEL {
  		requestHandle.WriteError(fmt.Errorf("you can not invoke dify from %s", invoke_from))
  		requestHandle.EndResponse()
  		return nil
  	}

  	// check permission declared in the plugin manifest
  	if err := checkPermission(declaration, requestHandle); err != nil {
  		requestHandle.WriteError(err)
  		requestHandle.EndResponse()
  		return nil
  	}

  	// dispatch invocation task asynchronously
  	routine.Submit(func() {
  		// Register EndResponse before dispatching. Deferred calls also run while
  		// a panic unwinds, so the response is always terminated.
  		defer requestHandle.EndResponse()
  		dispatchDifyInvocationTask(requestHandle)
  	})
  	return nil
  }
  56. var (
  57. permissionMapping = map[dify_invocation.InvokeType]map[string]any{
  58. dify_invocation.INVOKE_TYPE_TOOL: {
  59. "func": func(declaration *plugin_entities.PluginDeclaration) bool {
  60. return declaration.Resource.Permission.AllowInvokeTool()
  61. },
  62. "error": "permission denied, you need to enable tool access in plugin manifest",
  63. },
  64. dify_invocation.INVOKE_TYPE_LLM: {
  65. "func": func(declaration *plugin_entities.PluginDeclaration) bool {
  66. return declaration.Resource.Permission.AllowInvokeLLM()
  67. },
  68. "error": "permission denied, you need to enable llm access in plugin manifest",
  69. },
  70. dify_invocation.INVOKE_TYPE_TEXT_EMBEDDING: {
  71. "func": func(declaration *plugin_entities.PluginDeclaration) bool {
  72. return declaration.Resource.Permission.AllowInvokeTextEmbedding()
  73. },
  74. "error": "permission denied, you need to enable text-embedding access in plugin manifest",
  75. },
  76. dify_invocation.INVOKE_TYPE_RERANK: {
  77. "func": func(declaration *plugin_entities.PluginDeclaration) bool {
  78. return declaration.Resource.Permission.AllowInvokeRerank()
  79. },
  80. "error": "permission denied, you need to enable rerank access in plugin manifest",
  81. },
  82. dify_invocation.INVOKE_TYPE_TTS: {
  83. "func": func(declaration *plugin_entities.PluginDeclaration) bool {
  84. return declaration.Resource.Permission.AllowInvokeTTS()
  85. },
  86. "error": "permission denied, you need to enable tts access in plugin manifest",
  87. },
  88. dify_invocation.INVOKE_TYPE_SPEECH2TEXT: {
  89. "func": func(declaration *plugin_entities.PluginDeclaration) bool {
  90. return declaration.Resource.Permission.AllowInvokeSpeech2Text()
  91. },
  92. "error": "permission denied, you need to enable speech2text access in plugin manifest",
  93. },
  94. dify_invocation.INVOKE_TYPE_MODERATION: {
  95. "func": func(declaration *plugin_entities.PluginDeclaration) bool {
  96. return declaration.Resource.Permission.AllowInvokeModeration()
  97. },
  98. "error": "permission denied, you need to enable moderation access in plugin manifest",
  99. },
  100. dify_invocation.INVOKE_TYPE_NODE_PARAMETER_EXTRACTOR: {
  101. "func": func(declaration *plugin_entities.PluginDeclaration) bool {
  102. return declaration.Resource.Permission.AllowInvokeNode()
  103. },
  104. "error": "permission denied, you need to enable node access in plugin manifest",
  105. },
  106. dify_invocation.INVOKE_TYPE_NODE_QUESTION_CLASSIFIER: {
  107. "func": func(declaration *plugin_entities.PluginDeclaration) bool {
  108. return declaration.Resource.Permission.AllowInvokeNode()
  109. },
  110. "error": "permission denied, you need to enable node access in plugin manifest",
  111. },
  112. dify_invocation.INVOKE_TYPE_APP: {
  113. "func": func(declaration *plugin_entities.PluginDeclaration) bool {
  114. return declaration.Resource.Permission.AllowInvokeApp()
  115. },
  116. "error": "permission denied, you need to enable app access in plugin manifest",
  117. },
  118. dify_invocation.INVOKE_TYPE_STORAGE: {
  119. "func": func(declaration *plugin_entities.PluginDeclaration) bool {
  120. return declaration.Resource.Permission.AllowInvokeStorage()
  121. },
  122. "error": "permission denied, you need to enable storage access in plugin manifest",
  123. },
  124. }
  125. )
  // checkPermission decides whether the plugin described by runtime may perform
  // the invocation carried by request_handle. It looks the invocation type up
  // in permissionMapping and runs the entry's "func" check.
  //
  // It returns an error if the type has no entry, if the entry's "func" is
  // missing or of the wrong type, or if the check fails; in the last case the
  // error carries the entry's "error" message. A nil return means the
  // invocation is allowed.
  func checkPermission(runtime *plugin_entities.PluginDeclaration, request_handle *BackwardsInvocation) error {
  	typ := request_handle.Type()
  	permission, ok := permissionMapping[typ]
  	if !ok {
  		return fmt.Errorf("unsupported invoke type: %s", typ)
  	}
  	allowed, ok := permission["func"].(func(runtime *plugin_entities.PluginDeclaration) bool)
  	if !ok {
  		return fmt.Errorf("permission function not found: %s", typ)
  	}
  	if allowed(runtime) {
  		return nil
  	}
  	// The message is data, not a format string. Passing it as the format would
  	// misrender any '%' it contains, and go vet flags that pattern. The checked
  	// assertion also avoids a panic on a malformed entry.
  	msg, ok := permission["error"].(string)
  	if !ok || msg == "" {
  		msg = "permission denied"
  	}
  	return fmt.Errorf("%s", msg)
  }
  // prepareDifyInvocationArguments builds a BackwardsInvocation from a decoded
  // invoke request. The request must contain:
  //   "type"                 - string, the invocation type;
  //   "backwards_request_id" - string, identifies this backwards request;
  //   "request"              - object, the type-specific request payload.
  // It returns an error naming the first missing or mistyped field. The
  // offending request is included with %v, because %s garbles non-string
  // values inside a map[string]any.
  func prepareDifyInvocationArguments(
  	session *session_manager.Session,
  	writer BackwardsInvocationWriter,
  	request map[string]any,
  ) (*BackwardsInvocation, error) {
  	typ, ok := request["type"].(string)
  	if !ok {
  		return nil, fmt.Errorf("invoke request missing type: %v", request)
  	}
  	// get request id
  	backwardsRequestID, ok := request["backwards_request_id"].(string)
  	if !ok {
  		return nil, fmt.Errorf("invoke request missing request_id: %v", request)
  	}
  	// get request
  	detailedRequest, ok := request["request"].(map[string]any)
  	if !ok {
  		return nil, fmt.Errorf("invoke request missing request: %v", request)
  	}
  	return NewBackwardsInvocation(
  		BackwardsInvocationType(typ),
  		backwardsRequestID,
  		session,
  		writer,
  		detailedRequest,
  	), nil
  }
  167. var (
  168. dispatchMapping = map[dify_invocation.InvokeType]func(handle *BackwardsInvocation){
  169. dify_invocation.INVOKE_TYPE_TOOL: func(handle *BackwardsInvocation) {
  170. genericDispatchTask(handle, executeDifyInvocationToolTask)
  171. },
  172. dify_invocation.INVOKE_TYPE_LLM: func(handle *BackwardsInvocation) {
  173. genericDispatchTask(handle, executeDifyInvocationLLMTask)
  174. },
  175. dify_invocation.INVOKE_TYPE_TEXT_EMBEDDING: func(handle *BackwardsInvocation) {
  176. genericDispatchTask(handle, executeDifyInvocationTextEmbeddingTask)
  177. },
  178. dify_invocation.INVOKE_TYPE_RERANK: func(handle *BackwardsInvocation) {
  179. genericDispatchTask(handle, executeDifyInvocationRerankTask)
  180. },
  181. dify_invocation.INVOKE_TYPE_TTS: func(handle *BackwardsInvocation) {
  182. genericDispatchTask(handle, executeDifyInvocationTTSTask)
  183. },
  184. dify_invocation.INVOKE_TYPE_SPEECH2TEXT: func(handle *BackwardsInvocation) {
  185. genericDispatchTask(handle, executeDifyInvocationSpeech2TextTask)
  186. },
  187. dify_invocation.INVOKE_TYPE_MODERATION: func(handle *BackwardsInvocation) {
  188. genericDispatchTask(handle, executeDifyInvocationModerationTask)
  189. },
  190. dify_invocation.INVOKE_TYPE_APP: func(handle *BackwardsInvocation) {
  191. genericDispatchTask(handle, executeDifyInvocationAppTask)
  192. },
  193. dify_invocation.INVOKE_TYPE_NODE_PARAMETER_EXTRACTOR: func(handle *BackwardsInvocation) {
  194. genericDispatchTask(handle, executeDifyInvocationParameterExtractor)
  195. },
  196. dify_invocation.INVOKE_TYPE_NODE_QUESTION_CLASSIFIER: func(handle *BackwardsInvocation) {
  197. genericDispatchTask(handle, executeDifyInvocationQuestionClassifier)
  198. },
  199. dify_invocation.INVOKE_TYPE_STORAGE: func(handle *BackwardsInvocation) {
  200. genericDispatchTask(handle, executeDifyInvocationStorageTask)
  201. },
  202. }
  203. )
  // genericDispatchTask decodes the handle's request map into a *T via
  // parser.MapToStruct and passes it to dispatch. If decoding fails, it writes
  // an error naming the handle's invocation type and does not call dispatch.
  func genericDispatchTask[T any](
  	handle *BackwardsInvocation,
  	dispatch func(
  		handle *BackwardsInvocation,
  		request *T,
  	),
  ) {
  	r, err := parser.MapToStruct[T](handle.RequestData())
  	if err != nil {
  		// This helper serves every invocation type, so report the actual one
  		// rather than always claiming it was a tool request.
  		handle.WriteError(fmt.Errorf("unmarshal invoke %s request failed: %w", handle.Type(), err))
  		return
  	}
  	dispatch(handle, r)
  }
  // dispatchDifyInvocationTask enriches the handle's request map with
  // "tenant_id", "user_id" and "type", then runs the executor registered for
  // the handle's type in dispatchMapping. Failures to resolve the tenant or
  // user id, and unknown types, are written to the handle as errors.
  func dispatchDifyInvocationTask(handle *BackwardsInvocation) {
  	requestData := handle.RequestData()

  	tenantID, err := handle.TenantID()
  	if err != nil {
  		handle.WriteError(fmt.Errorf("get tenant id failed: %w", err))
  		return
  	}
  	requestData["tenant_id"] = tenantID

  	userID, err := handle.UserID()
  	if err != nil {
  		handle.WriteError(fmt.Errorf("get user id failed: %w", err))
  		return
  	}
  	requestData["user_id"] = userID

  	typ := handle.Type()
  	requestData["type"] = typ

  	// Direct keyed lookup instead of scanning every map entry.
  	dispatch, ok := dispatchMapping[typ]
  	if !ok {
  		handle.WriteError(fmt.Errorf("unsupported invoke type: %s", typ))
  		return
  	}
  	dispatch(handle)
  }
  // executeDifyInvocationToolTask passes the tool request to
  // backwardsInvocation.InvokeTool. It relays every item read from the
  // returned stream to the plugin as a "stream" response. An error from the
  // call, or from reading an item, is written to the handle and stops the relay.
  func executeDifyInvocationToolTask(
  	handle *BackwardsInvocation,
  	request *dify_invocation.InvokeToolRequest,
  ) {
  	stream, invokeErr := handle.backwardsInvocation.InvokeTool(request)
  	if invokeErr != nil {
  		handle.WriteError(fmt.Errorf("invoke tool failed: %s", invokeErr.Error()))
  		return
  	}
  	for {
  		if !stream.Next() {
  			return
  		}
  		chunk, readErr := stream.Read()
  		if readErr != nil {
  			handle.WriteError(fmt.Errorf("read tool response failed: %s", readErr.Error()))
  			return
  		}
  		handle.WriteResponse("stream", chunk)
  	}
  }
  // executeDifyInvocationLLMTask passes the LLM request to
  // backwardsInvocation.InvokeLLM. It relays each item of the returned stream
  // to the plugin as a "stream" response. Errors from the call or from a read
  // are written to the handle and end the relay.
  func executeDifyInvocationLLMTask(
  	handle *BackwardsInvocation,
  	request *dify_invocation.InvokeLLMRequest,
  ) {
  	stream, invokeErr := handle.backwardsInvocation.InvokeLLM(request)
  	if invokeErr != nil {
  		handle.WriteError(fmt.Errorf("invoke llm model failed: %s", invokeErr.Error()))
  		return
  	}
  	for stream.Next() {
  		chunk, readErr := stream.Read()
  		if readErr != nil {
  			handle.WriteError(fmt.Errorf("read llm model failed: %s", readErr.Error()))
  			return
  		}
  		handle.WriteResponse("stream", chunk)
  	}
  }
  // executeDifyInvocationTextEmbeddingTask passes the request to
  // backwardsInvocation.InvokeTextEmbedding and writes the single result as a
  // "struct" response. If the call fails, it writes the error instead.
  func executeDifyInvocationTextEmbeddingTask(
  	handle *BackwardsInvocation,
  	request *dify_invocation.InvokeTextEmbeddingRequest,
  ) {
  	if result, callErr := handle.backwardsInvocation.InvokeTextEmbedding(request); callErr != nil {
  		handle.WriteError(fmt.Errorf("invoke text-embedding model failed: %s", callErr.Error()))
  	} else {
  		handle.WriteResponse("struct", result)
  	}
  }
  // executeDifyInvocationRerankTask passes the request to
  // backwardsInvocation.InvokeRerank and writes the single result as a
  // "struct" response. If the call fails, it writes the error instead.
  func executeDifyInvocationRerankTask(
  	handle *BackwardsInvocation,
  	request *dify_invocation.InvokeRerankRequest,
  ) {
  	if result, callErr := handle.backwardsInvocation.InvokeRerank(request); callErr != nil {
  		handle.WriteError(fmt.Errorf("invoke rerank model failed: %s", callErr.Error()))
  	} else {
  		handle.WriteResponse("struct", result)
  	}
  }
  // executeDifyInvocationTTSTask passes the request to
  // backwardsInvocation.InvokeTTS. It relays each item of the returned stream
  // to the plugin as a "stream" response. Errors from the call or from a read
  // are written to the handle and end the relay.
  func executeDifyInvocationTTSTask(
  	handle *BackwardsInvocation,
  	request *dify_invocation.InvokeTTSRequest,
  ) {
  	stream, invokeErr := handle.backwardsInvocation.InvokeTTS(request)
  	if invokeErr != nil {
  		handle.WriteError(fmt.Errorf("invoke tts model failed: %s", invokeErr.Error()))
  		return
  	}
  	for {
  		if !stream.Next() {
  			return
  		}
  		chunk, readErr := stream.Read()
  		if readErr != nil {
  			handle.WriteError(fmt.Errorf("read tts model failed: %s", readErr.Error()))
  			return
  		}
  		handle.WriteResponse("stream", chunk)
  	}
  }
  // executeDifyInvocationSpeech2TextTask passes the request to
  // backwardsInvocation.InvokeSpeech2Text and writes the single result as a
  // "struct" response. If the call fails, it writes the error instead.
  func executeDifyInvocationSpeech2TextTask(
  	handle *BackwardsInvocation,
  	request *dify_invocation.InvokeSpeech2TextRequest,
  ) {
  	if result, callErr := handle.backwardsInvocation.InvokeSpeech2Text(request); callErr != nil {
  		handle.WriteError(fmt.Errorf("invoke speech2text model failed: %s", callErr.Error()))
  	} else {
  		handle.WriteResponse("struct", result)
  	}
  }
  // executeDifyInvocationModerationTask passes the request to
  // backwardsInvocation.InvokeModeration and writes the single result as a
  // "struct" response. If the call fails, it writes the error instead.
  func executeDifyInvocationModerationTask(
  	handle *BackwardsInvocation,
  	request *dify_invocation.InvokeModerationRequest,
  ) {
  	if result, callErr := handle.backwardsInvocation.InvokeModeration(request); callErr != nil {
  		handle.WriteError(fmt.Errorf("invoke moderation model failed: %s", callErr.Error()))
  	} else {
  		handle.WriteResponse("struct", result)
  	}
  }
  // executeDifyInvocationAppTask resolves the caller's user id and stores it in
  // request.User. It then passes the request to backwardsInvocation.InvokeApp
  // and relays each item of the returned stream as a "stream" response. An
  // error while resolving the user id, starting the call, or reading an item is
  // written to the handle and stops processing.
  func executeDifyInvocationAppTask(
  	handle *BackwardsInvocation,
  	request *dify_invocation.InvokeAppRequest,
  ) {
  	// User must be populated before InvokeApp runs. Assigning it after the
  	// call, as the code previously did, had no effect on the invocation.
  	userID, err := handle.UserID()
  	if err != nil {
  		handle.WriteError(fmt.Errorf("get user id failed: %w", err))
  		return
  	}
  	request.User = userID

  	response, err := handle.backwardsInvocation.InvokeApp(request)
  	if err != nil {
  		handle.WriteError(fmt.Errorf("invoke app failed: %w", err))
  		return
  	}
  	for response.Next() {
  		value, err := response.Read()
  		if err != nil {
  			handle.WriteError(fmt.Errorf("read app failed: %w", err))
  			return
  		}
  		handle.WriteResponse("stream", value)
  	}
  }
  // executeDifyInvocationParameterExtractor passes the request to
  // backwardsInvocation.InvokeParameterExtractor and writes the single result
  // as a "struct" response. If the call fails, it writes the error instead.
  func executeDifyInvocationParameterExtractor(
  	handle *BackwardsInvocation,
  	request *dify_invocation.InvokeParameterExtractorRequest,
  ) {
  	if result, callErr := handle.backwardsInvocation.InvokeParameterExtractor(request); callErr != nil {
  		handle.WriteError(fmt.Errorf("invoke parameter extractor failed: %s", callErr.Error()))
  	} else {
  		handle.WriteResponse("struct", result)
  	}
  }
  // executeDifyInvocationQuestionClassifier passes the request to
  // backwardsInvocation.InvokeQuestionClassifier and writes the single result
  // as a "struct" response. If the call fails, it writes the error instead.
  func executeDifyInvocationQuestionClassifier(
  	handle *BackwardsInvocation,
  	request *dify_invocation.InvokeQuestionClassifierRequest,
  ) {
  	if result, callErr := handle.backwardsInvocation.InvokeQuestionClassifier(request); callErr != nil {
  		handle.WriteError(fmt.Errorf("invoke question classifier failed: %s", callErr.Error()))
  	} else {
  		handle.WriteResponse("struct", result)
  	}
  }
  // executeDifyInvocationStorageTask serves a plugin's key/value storage request
  // using the store returned by persistence.GetPersistence. Storage is keyed by
  // the caller's tenant id and the session's plugin id, and values are
  // hex-encoded in responses and requests:
  //   - STORAGE_OPT_GET loads request.Key and responds {"data": <hex string>};
  //   - STORAGE_OPT_SET hex-decodes request.Value, saves it under request.Key
  //     and responds {"data": "ok"};
  //   - STORAGE_OPT_DEL deletes request.Key and responds {"data": "ok"}.
  // Any other operation is reported as an error, so the plugin is never left
  // without a response.
  func executeDifyInvocationStorageTask(
  	handle *BackwardsInvocation,
  	request *dify_invocation.InvokeStorageRequest,
  ) {
  	if handle.session == nil {
  		handle.WriteError(fmt.Errorf("session not found"))
  		return
  	}
  	// Named "store" so it does not shadow the persistence package.
  	store := persistence.GetPersistence()
  	if store == nil {
  		handle.WriteError(fmt.Errorf("persistence not found"))
  		return
  	}
  	tenantID, err := handle.TenantID()
  	if err != nil {
  		handle.WriteError(fmt.Errorf("get tenant id failed: %w", err))
  		return
  	}
  	pluginID := handle.session.PluginUniqueIdentifier.PluginID()

  	switch request.Opt {
  	case dify_invocation.STORAGE_OPT_GET:
  		data, loadErr := store.Load(tenantID, pluginID, request.Key)
  		if loadErr != nil {
  			handle.WriteError(fmt.Errorf("load data failed: %w", loadErr))
  			return
  		}
  		handle.WriteResponse("struct", map[string]any{
  			"data": hex.EncodeToString(data),
  		})
  	case dify_invocation.STORAGE_OPT_SET:
  		data, decodeErr := hex.DecodeString(request.Value)
  		if decodeErr != nil {
  			handle.WriteError(fmt.Errorf("decode data failed: %w", decodeErr))
  			return
  		}
  		if saveErr := store.Save(tenantID, pluginID, request.Key, data); saveErr != nil {
  			handle.WriteError(fmt.Errorf("save data failed: %w", saveErr))
  			return
  		}
  		handle.WriteResponse("struct", map[string]any{
  			"data": "ok",
  		})
  	case dify_invocation.STORAGE_OPT_DEL:
  		if delErr := store.Delete(tenantID, pluginID, request.Key); delErr != nil {
  			handle.WriteError(fmt.Errorf("delete data failed: %w", delErr))
  			return
  		}
  		handle.WriteResponse("struct", map[string]any{
  			"data": "ok",
  		})
  	default:
  		// Previously an unknown operation produced no response at all.
  		handle.WriteError(fmt.Errorf("unsupported storage operation: %v", request.Opt))
  	}
  }