task.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431
  1. package backwards_invocation
  2. import (
  3. "encoding/hex"
  4. "fmt"
  5. "github.com/langgenius/dify-plugin-daemon/internal/core/dify_invocation"
  6. "github.com/langgenius/dify-plugin-daemon/internal/core/persistence"
  7. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/access_types"
  8. "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
  9. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
  10. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  11. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/tool_entities"
  12. "github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  13. "github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
  14. )
  15. // returns error only if payload is not correct
  16. func InvokeDify(
  17. declaration *plugin_entities.PluginDeclaration,
  18. invoke_from access_types.PluginAccessType,
  19. session *session_manager.Session,
  20. writer BackwardsInvocationWriter,
  21. data []byte,
  22. ) error {
  23. // unmarshal invoke data
  24. request, err := parser.UnmarshalJsonBytes2Map(data)
  25. if err != nil {
  26. return fmt.Errorf("unmarshal invoke request failed: %s", err.Error())
  27. }
  28. if request == nil {
  29. return fmt.Errorf("invoke request is empty")
  30. }
  31. // prepare invocation arguments
  32. request_handle, err := prepareDifyInvocationArguments(session, writer, request)
  33. if err != nil {
  34. return err
  35. }
  36. if invoke_from == access_types.PLUGIN_ACCESS_TYPE_MODEL {
  37. request_handle.WriteError(fmt.Errorf("you can not invoke dify from %s", invoke_from))
  38. request_handle.EndResponse()
  39. return nil
  40. }
  41. // check permission
  42. if err := checkPermission(declaration, request_handle); err != nil {
  43. request_handle.WriteError(err)
  44. request_handle.EndResponse()
  45. return nil
  46. }
  47. // dispatch invocation task
  48. routine.Submit(func() {
  49. dispatchDifyInvocationTask(request_handle)
  50. defer request_handle.EndResponse()
  51. })
  52. return nil
  53. }
  54. var (
  55. permissionMapping = map[dify_invocation.InvokeType]map[string]any{
  56. dify_invocation.INVOKE_TYPE_TOOL: {
  57. "func": func(declaration *plugin_entities.PluginDeclaration) bool {
  58. return declaration.Resource.Permission.AllowInvokeTool()
  59. },
  60. "error": "permission denied, you need to enable tool access in plugin manifest",
  61. },
  62. dify_invocation.INVOKE_TYPE_LLM: {
  63. "func": func(declaration *plugin_entities.PluginDeclaration) bool {
  64. return declaration.Resource.Permission.AllowInvokeLLM()
  65. },
  66. "error": "permission denied, you need to enable llm access in plugin manifest",
  67. },
  68. dify_invocation.INVOKE_TYPE_TEXT_EMBEDDING: {
  69. "func": func(declaration *plugin_entities.PluginDeclaration) bool {
  70. return declaration.Resource.Permission.AllowInvokeTextEmbedding()
  71. },
  72. "error": "permission denied, you need to enable text-embedding access in plugin manifest",
  73. },
  74. dify_invocation.INVOKE_TYPE_RERANK: {
  75. "func": func(declaration *plugin_entities.PluginDeclaration) bool {
  76. return declaration.Resource.Permission.AllowInvokeRerank()
  77. },
  78. "error": "permission denied, you need to enable rerank access in plugin manifest",
  79. },
  80. dify_invocation.INVOKE_TYPE_TTS: {
  81. "func": func(declaration *plugin_entities.PluginDeclaration) bool {
  82. return declaration.Resource.Permission.AllowInvokeTTS()
  83. },
  84. "error": "permission denied, you need to enable tts access in plugin manifest",
  85. },
  86. dify_invocation.INVOKE_TYPE_SPEECH2TEXT: {
  87. "func": func(declaration *plugin_entities.PluginDeclaration) bool {
  88. return declaration.Resource.Permission.AllowInvokeSpeech2Text()
  89. },
  90. "error": "permission denied, you need to enable speech2text access in plugin manifest",
  91. },
  92. dify_invocation.INVOKE_TYPE_MODERATION: {
  93. "func": func(declaration *plugin_entities.PluginDeclaration) bool {
  94. return declaration.Resource.Permission.AllowInvokeModeration()
  95. },
  96. "error": "permission denied, you need to enable moderation access in plugin manifest",
  97. },
  98. dify_invocation.INVOKE_TYPE_NODE: {
  99. "func": func(declaration *plugin_entities.PluginDeclaration) bool {
  100. return declaration.Resource.Permission.AllowInvokeNode()
  101. },
  102. "error": "permission denied, you need to enable node access in plugin manifest",
  103. },
  104. dify_invocation.INVOKE_TYPE_APP: {
  105. "func": func(declaration *plugin_entities.PluginDeclaration) bool {
  106. return declaration.Resource.Permission.AllowInvokeApp()
  107. },
  108. "error": "permission denied, you need to enable app access in plugin manifest",
  109. },
  110. dify_invocation.INVOKE_TYPE_STORAGE: {
  111. "func": func(declaration *plugin_entities.PluginDeclaration) bool {
  112. return declaration.Resource.Permission.AllowInvokeStorage()
  113. },
  114. "error": "permission denied, you need to enable storage access in plugin manifest",
  115. },
  116. }
  117. )
  118. func checkPermission(runtime *plugin_entities.PluginDeclaration, request_handle *BackwardsInvocation) error {
  119. permission, ok := permissionMapping[request_handle.Type()]
  120. if !ok {
  121. return fmt.Errorf("unsupported invoke type: %s", request_handle.Type())
  122. }
  123. permission_func, ok := permission["func"].(func(runtime *plugin_entities.PluginDeclaration) bool)
  124. if !ok {
  125. return fmt.Errorf("permission function not found: %s", request_handle.Type())
  126. }
  127. if !permission_func(runtime) {
  128. return fmt.Errorf(permission["error"].(string))
  129. }
  130. return nil
  131. }
  132. func prepareDifyInvocationArguments(
  133. session *session_manager.Session,
  134. writer BackwardsInvocationWriter,
  135. request map[string]any,
  136. ) (*BackwardsInvocation, error) {
  137. typ, ok := request["type"].(string)
  138. if !ok {
  139. return nil, fmt.Errorf("invoke request missing type: %s", request)
  140. }
  141. // get request id
  142. backwards_request_id, ok := request["backwards_request_id"].(string)
  143. if !ok {
  144. return nil, fmt.Errorf("invoke request missing request_id: %s", request)
  145. }
  146. // get request
  147. detailed_request, ok := request["request"].(map[string]any)
  148. if !ok {
  149. return nil, fmt.Errorf("invoke request missing request: %s", request)
  150. }
  151. return NewBackwardsInvocation(
  152. BackwardsInvocationType(typ),
  153. backwards_request_id,
  154. session,
  155. writer,
  156. detailed_request,
  157. ), nil
  158. }
  159. var (
  160. dispatchMapping = map[dify_invocation.InvokeType]func(handle *BackwardsInvocation){
  161. dify_invocation.INVOKE_TYPE_TOOL: func(handle *BackwardsInvocation) {
  162. genericDispatchTask(handle, executeDifyInvocationToolTask)
  163. },
  164. dify_invocation.INVOKE_TYPE_LLM: func(handle *BackwardsInvocation) {
  165. genericDispatchTask(handle, executeDifyInvocationLLMTask)
  166. },
  167. dify_invocation.INVOKE_TYPE_TEXT_EMBEDDING: func(handle *BackwardsInvocation) {
  168. genericDispatchTask(handle, executeDifyInvocationTextEmbeddingTask)
  169. },
  170. dify_invocation.INVOKE_TYPE_RERANK: func(handle *BackwardsInvocation) {
  171. genericDispatchTask(handle, executeDifyInvocationRerankTask)
  172. },
  173. dify_invocation.INVOKE_TYPE_TTS: func(handle *BackwardsInvocation) {
  174. genericDispatchTask(handle, executeDifyInvocationTTSTask)
  175. },
  176. dify_invocation.INVOKE_TYPE_SPEECH2TEXT: func(handle *BackwardsInvocation) {
  177. genericDispatchTask(handle, executeDifyInvocationSpeech2TextTask)
  178. },
  179. dify_invocation.INVOKE_TYPE_MODERATION: func(handle *BackwardsInvocation) {
  180. genericDispatchTask(handle, executeDifyInvocationModerationTask)
  181. },
  182. dify_invocation.INVOKE_TYPE_APP: func(handle *BackwardsInvocation) {
  183. genericDispatchTask(handle, executeDifyInvocationAppTask)
  184. },
  185. dify_invocation.INVOKE_TYPE_STORAGE: func(handle *BackwardsInvocation) {
  186. genericDispatchTask(handle, executeDifyInvocationStorageTask)
  187. },
  188. }
  189. )
  190. func genericDispatchTask[T any](
  191. handle *BackwardsInvocation,
  192. dispatch func(
  193. handle *BackwardsInvocation,
  194. request *T,
  195. ),
  196. ) {
  197. r, err := parser.MapToStruct[T](handle.RequestData())
  198. if err != nil {
  199. handle.WriteError(fmt.Errorf("unmarshal invoke tool request failed: %s", err.Error()))
  200. return
  201. }
  202. dispatch(handle, r)
  203. }
  204. func dispatchDifyInvocationTask(handle *BackwardsInvocation) {
  205. request_data := handle.RequestData()
  206. tenant_id, err := handle.TenantID()
  207. if err != nil {
  208. handle.WriteError(fmt.Errorf("get tenant id failed: %s", err.Error()))
  209. return
  210. }
  211. request_data["tenant_id"] = tenant_id
  212. user_id, err := handle.UserID()
  213. if err != nil {
  214. handle.WriteError(fmt.Errorf("get user id failed: %s", err.Error()))
  215. return
  216. }
  217. request_data["user_id"] = user_id
  218. typ := handle.Type()
  219. request_data["type"] = typ
  220. for t, v := range dispatchMapping {
  221. if t == handle.Type() {
  222. v(handle)
  223. return
  224. }
  225. }
  226. handle.WriteError(fmt.Errorf("unsupported invoke type: %s", handle.Type()))
  227. }
  228. func executeDifyInvocationToolTask(
  229. handle *BackwardsInvocation,
  230. request *dify_invocation.InvokeToolRequest,
  231. ) {
  232. response, err := dify_invocation.InvokeTool(request)
  233. if err != nil {
  234. handle.WriteError(fmt.Errorf("invoke tool failed: %s", err.Error()))
  235. return
  236. }
  237. response.Async(func(t tool_entities.ToolResponseChunk) {
  238. handle.WriteResponse("stream", t)
  239. })
  240. }
  241. func executeDifyInvocationLLMTask(
  242. handle *BackwardsInvocation,
  243. request *dify_invocation.InvokeLLMRequest,
  244. ) {
  245. response, err := dify_invocation.InvokeLLM(request)
  246. if err != nil {
  247. handle.WriteError(fmt.Errorf("invoke llm model failed: %s", err.Error()))
  248. return
  249. }
  250. response.Async(func(t model_entities.LLMResultChunk) {
  251. handle.WriteResponse("stream", t)
  252. })
  253. }
  254. func executeDifyInvocationTextEmbeddingTask(
  255. handle *BackwardsInvocation,
  256. request *dify_invocation.InvokeTextEmbeddingRequest,
  257. ) {
  258. response, err := dify_invocation.InvokeTextEmbedding(request)
  259. if err != nil {
  260. handle.WriteError(fmt.Errorf("invoke text-embedding model failed: %s", err.Error()))
  261. return
  262. }
  263. handle.WriteResponse("struct", response)
  264. }
  265. func executeDifyInvocationRerankTask(
  266. handle *BackwardsInvocation,
  267. request *dify_invocation.InvokeRerankRequest,
  268. ) {
  269. response, err := dify_invocation.InvokeRerank(request)
  270. if err != nil {
  271. handle.WriteError(fmt.Errorf("invoke rerank model failed: %s", err.Error()))
  272. return
  273. }
  274. handle.WriteResponse("struct", response)
  275. }
  276. func executeDifyInvocationTTSTask(
  277. handle *BackwardsInvocation,
  278. request *dify_invocation.InvokeTTSRequest,
  279. ) {
  280. response, err := dify_invocation.InvokeTTS(request)
  281. if err != nil {
  282. handle.WriteError(fmt.Errorf("invoke tts model failed: %s", err.Error()))
  283. return
  284. }
  285. response.Async(func(t model_entities.TTSResult) {
  286. handle.WriteResponse("struct", t)
  287. })
  288. }
  289. func executeDifyInvocationSpeech2TextTask(
  290. handle *BackwardsInvocation,
  291. request *dify_invocation.InvokeSpeech2TextRequest,
  292. ) {
  293. response, err := dify_invocation.InvokeSpeech2Text(request)
  294. if err != nil {
  295. handle.WriteError(fmt.Errorf("invoke speech2text model failed: %s", err.Error()))
  296. return
  297. }
  298. handle.WriteResponse("struct", response)
  299. }
  300. func executeDifyInvocationModerationTask(
  301. handle *BackwardsInvocation,
  302. request *dify_invocation.InvokeModerationRequest,
  303. ) {
  304. response, err := dify_invocation.InvokeModeration(request)
  305. if err != nil {
  306. handle.WriteError(fmt.Errorf("invoke moderation model failed: %s", err.Error()))
  307. return
  308. }
  309. handle.WriteResponse("struct", response)
  310. }
  311. func executeDifyInvocationAppTask(
  312. handle *BackwardsInvocation,
  313. request *dify_invocation.InvokeAppRequest,
  314. ) {
  315. response, err := dify_invocation.InvokeApp(request)
  316. if err != nil {
  317. handle.WriteError(fmt.Errorf("invoke app failed: %s", err.Error()))
  318. return
  319. }
  320. user_id, err := handle.UserID()
  321. if err != nil {
  322. handle.WriteError(fmt.Errorf("get user id failed: %s", err.Error()))
  323. return
  324. }
  325. request.User = user_id
  326. response.Async(func(t map[string]any) {
  327. handle.WriteResponse("stream", t)
  328. })
  329. }
  330. func executeDifyInvocationStorageTask(
  331. handle *BackwardsInvocation,
  332. request *dify_invocation.InvokeStorageRequest,
  333. ) {
  334. if handle.session == nil {
  335. handle.WriteError(fmt.Errorf("session not found"))
  336. return
  337. }
  338. persistence := persistence.GetPersistence()
  339. if persistence == nil {
  340. handle.WriteError(fmt.Errorf("persistence not found"))
  341. return
  342. }
  343. tenant_id, err := handle.TenantID()
  344. if err != nil {
  345. handle.WriteError(fmt.Errorf("get tenant id failed: %s", err.Error()))
  346. return
  347. }
  348. plugin_id := handle.session.PluginUniqueIdentifier
  349. if request.Opt == dify_invocation.STORAGE_OPT_GET {
  350. data, err := persistence.Load(tenant_id, plugin_id.PluginID(), request.Key)
  351. if err != nil {
  352. handle.WriteError(fmt.Errorf("load data failed: %s", err.Error()))
  353. return
  354. }
  355. handle.WriteResponse("struct", map[string]any{
  356. "data": hex.EncodeToString(data),
  357. })
  358. } else if request.Opt == dify_invocation.STORAGE_OPT_SET {
  359. data, err := hex.DecodeString(request.Value)
  360. if err != nil {
  361. handle.WriteError(fmt.Errorf("decode data failed: %s", err.Error()))
  362. return
  363. }
  364. if err := persistence.Save(tenant_id, plugin_id.PluginID(), request.Key, data); err != nil {
  365. handle.WriteError(fmt.Errorf("save data failed: %s", err.Error()))
  366. return
  367. }
  368. handle.WriteResponse("struct", map[string]any{
  369. "data": "ok",
  370. })
  371. } else if request.Opt == dify_invocation.STORAGE_OPT_DEL {
  372. if err := persistence.Delete(tenant_id, plugin_id.PluginID(), request.Key); err != nil {
  373. handle.WriteError(fmt.Errorf("delete data failed: %s", err.Error()))
  374. return
  375. }
  376. handle.WriteResponse("struct", map[string]any{
  377. "data": "ok",
  378. })
  379. }
  380. }