install_plugin.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541
  1. package service
  2. import (
  3. "fmt"
  4. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
  5. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_packager/decoder"
  6. "github.com/langgenius/dify-plugin-daemon/internal/db"
  7. "github.com/langgenius/dify-plugin-daemon/internal/types/app"
  8. "github.com/langgenius/dify-plugin-daemon/internal/types/entities"
  9. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  10. "github.com/langgenius/dify-plugin-daemon/internal/types/models"
  11. "github.com/langgenius/dify-plugin-daemon/internal/types/models/curd"
  12. "github.com/langgenius/dify-plugin-daemon/internal/utils/cache/helper"
  13. "github.com/langgenius/dify-plugin-daemon/internal/utils/log"
  14. "github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
  15. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  16. "gorm.io/gorm"
  17. )
  // InstallPluginResponse is the payload returned when an installation is requested.
  type InstallPluginResponse struct {
  	// AllInstalled is true when every requested plugin was already installed,
  	// so no install task had to be created.
  	AllInstalled bool `json:"all_installed"`
  	// TaskID identifies the install task that tracks the pending plugins;
  	// it is empty when AllInstalled is true.
  	TaskID string `json:"task_id"`
  }
  22. type InstallPluginOnDoneHandler func(
  23. pluginUniqueIdentifier plugin_entities.PluginUniqueIdentifier,
  24. declaration *plugin_entities.PluginDeclaration,
  25. ) error
  // InstallPluginRuntimeToTenant installs the runtimes of the given plugin
  // packages for a tenant and tracks the progress in an InstallTask.
  //
  // The runtime type is derived from config.Platform (AWS Lambda or local), and
  // any other platform is rejected. The plugin declaration of every identifier
  // is resolved first, which also ensures the package has been uploaded.
  //
  // If a record for the identifier already exists in the plugins table, onDone
  // is called right away and the entry is recorded as "Installed". If that
  // onDone call fails, the whole call returns an error instead of silently
  // reporting success. Every other plugin gets a job that is submitted through
  // routine.WithMaxRoutine with a limit of 3. The job updates the persisted task
  // as it progresses and calls onDone when the plugin manager emits its done
  // event.
  //
  // Parameters:
  //   - config: selects the runtime (AWS Lambda or local).
  //   - tenant_id: tenant the plugins are installed for.
  //   - plugin_unique_identifiers: packages to install.
  //   - source, meta: forwarded to the plugin manager for each installation.
  //   - onDone: called once each plugin is ready. A failure inside an async job
  //     marks that plugin and the task as failed.
  //
  // It returns AllInstalled=true with an empty TaskID when nothing had to be
  // installed. Otherwise it returns the ID of the newly created task. Errors are
  // returned only for failures that happen before any job is submitted.
  func InstallPluginRuntimeToTenant(
  	config *app.Config,
  	tenant_id string,
  	plugin_unique_identifiers []plugin_entities.PluginUniqueIdentifier,
  	source string,
  	meta map[string]any,
  	onDone InstallPluginOnDoneHandler, // invoked when a plugin is ready; asynchronously for plugins that still need installing
  ) (*InstallPluginResponse, error) {
  	response := &InstallPluginResponse{}
  
  	var runtimeType plugin_entities.PluginRuntimeType
  	switch config.Platform {
  	case app.PLATFORM_AWS_LAMBDA:
  		runtimeType = plugin_entities.PLUGIN_RUNTIME_TYPE_AWS
  	case app.PLATFORM_LOCAL:
  		runtimeType = plugin_entities.PLUGIN_RUNTIME_TYPE_LOCAL
  	default:
  		return nil, fmt.Errorf("unsupported platform: %s", config.Platform)
  	}
  
  	task := &models.InstallTask{
  		Status:           models.InstallTaskStatusRunning,
  		TenantID:         tenant_id,
  		TotalPlugins:     len(plugin_unique_identifiers),
  		CompletedPlugins: 0,
  		Plugins:          []models.InstallTaskPluginStatus{},
  	}
  
  	// pendingInstallation pairs a plugin that still needs a runtime with the
  	// declaration resolved below, so it does not have to be fetched twice.
  	type pendingInstallation struct {
  		identifier  plugin_entities.PluginUniqueIdentifier
  		declaration *plugin_entities.PluginDeclaration
  	}
  	pending := []pendingInstallation{}
  
  	for i, pluginUniqueIdentifier := range plugin_unique_identifiers {
  		// fetch the declaration first: it also ensures the package was uploaded
  		pluginDeclaration, err := helper.CombinedGetPluginDeclaration(
  			pluginUniqueIdentifier,
  			tenant_id,
  			runtimeType,
  		)
  		if err != nil {
  			return nil, err
  		}
  
  		// check if the plugin runtime is already installed
  		_, err = db.GetOne[models.Plugin](
  			db.Equal("plugin_unique_identifier", pluginUniqueIdentifier.String()),
  		)
  
  		task.Plugins = append(task.Plugins, models.InstallTaskPluginStatus{
  			PluginUniqueIdentifier: pluginUniqueIdentifier,
  			PluginID:               pluginUniqueIdentifier.PluginID(),
  			Status:                 models.InstallTaskStatusPending,
  			Icon:                   pluginDeclaration.Icon,
  			Labels:                 pluginDeclaration.Label,
  			Message:                "",
  		})
  
  		if err == nil {
  			// Only the caller-side bookkeeping is left. Its failure must not be
  			// reported as a successful installation.
  			if doneErr := onDone(pluginUniqueIdentifier, pluginDeclaration); doneErr != nil {
  				return nil, fmt.Errorf(
  					"plugin %s is already installed but finishing its installation failed: %w",
  					pluginUniqueIdentifier.String(),
  					doneErr,
  				)
  			}
  			task.CompletedPlugins++
  			task.Plugins[i].Status = models.InstallTaskStatusSuccess
  			task.Plugins[i].Message = "Installed"
  			continue
  		}
  
  		if err != db.ErrDatabaseNotFound {
  			return nil, err
  		}
  
  		pending = append(pending, pendingInstallation{
  			identifier:  pluginUniqueIdentifier,
  			declaration: pluginDeclaration,
  		})
  	}
  
  	if len(pending) == 0 {
  		response.AllInstalled = true
  		response.TaskID = ""
  		return response, nil
  	}
  
  	if err := db.Create(task); err != nil {
  		return nil, err
  	}
  	response.TaskID = task.ID
  
  	manager := plugin_manager.Manager()
  	tasks := make([]func(), 0, len(pending))
  	for _, item := range pending {
  		// copy per-iteration values so every job captures its own
  		pluginUniqueIdentifier := item.identifier
  		declaration := item.declaration
  		taskID := task.ID
  
  		tasks = append(tasks, func() {
  			// updateTaskStatus reloads the task under a write lock, finds this
  			// plugin's entry and applies modifier to both before saving. It does
  			// nothing when the task or the entry has been removed in the meantime.
  			updateTaskStatus := func(modifier func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus)) {
  				if err := db.WithTransaction(func(tx *gorm.DB) error {
  					current, err := db.GetOne[models.InstallTask](
  						db.WithTransactionContext(tx),
  						db.Equal("id", taskID),
  						db.WLock(), // write lock, multiple tasks can't update the same task
  					)
  					if err == db.ErrDatabaseNotFound {
  						return nil
  					}
  					if err != nil {
  						return err
  					}
  
  					var pluginStatus *models.InstallTaskPluginStatus
  					for i := range current.Plugins {
  						if current.Plugins[i].PluginUniqueIdentifier == pluginUniqueIdentifier {
  							pluginStatus = &current.Plugins[i]
  							break
  						}
  					}
  					if pluginStatus == nil {
  						return nil
  					}
  
  					modifier(&current, pluginStatus)
  					return db.Update(&current, tx)
  				}); err != nil {
  					log.Error("failed to update install task status %s", err.Error())
  				}
  			}
  
  			// failWith marks both this plugin and the whole task as failed.
  			failWith := func(message string) {
  				updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) {
  					task.Status = models.InstallTaskStatusFailed
  					plugin.Status = models.InstallTaskStatusFailed
  					plugin.Message = message
  				})
  			}
  
  			updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) {
  				plugin.Status = models.InstallTaskStatusRunning
  				plugin.Message = "Installing"
  			})
  
  			// named installStream so the `stream` package is not shadowed
  			var installStream *stream.Stream[plugin_manager.PluginInstallResponse]
  			var err error
  			switch config.Platform {
  			case app.PLATFORM_AWS_LAMBDA:
  				pkgFile, readErr := manager.GetPackage(pluginUniqueIdentifier)
  				if readErr != nil {
  					failWith("Failed to read plugin package")
  					return
  				}
  				zipDecoder, decodeErr := decoder.NewZipPluginDecoder(pkgFile)
  				if decodeErr != nil {
  					failWith(decodeErr.Error())
  					return
  				}
  				installStream, err = manager.InstallToAWSFromPkg(zipDecoder, source, meta)
  			case app.PLATFORM_LOCAL:
  				installStream, err = manager.InstallToLocal(pluginUniqueIdentifier, source, meta)
  			default:
  				failWith("Unsupported platform")
  				return
  			}
  			if err != nil {
  				failWith(err.Error())
  				return
  			}
  
  			for installStream.Next() {
  				message, err := installStream.Read()
  				if err != nil {
  					failWith(err.Error())
  					return
  				}
  
  				switch message.Event {
  				case plugin_manager.PluginInstallEventError:
  					failWith(message.Data)
  					return
  				case plugin_manager.PluginInstallEventDone:
  					if err := onDone(pluginUniqueIdentifier, declaration); err != nil {
  						failWith("Failed to create plugin")
  						return
  					}
  				}
  			}
  
  			updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) {
  				plugin.Status = models.InstallTaskStatusSuccess
  				plugin.Message = "Installed"
  				task.CompletedPlugins++
  				// the task succeeds once every plugin has completed
  				if task.CompletedPlugins == task.TotalPlugins {
  					task.Status = models.InstallTaskStatusSuccess
  				}
  			})
  		})
  	}
  
  	// submit async tasks
  	routine.WithMaxRoutine(3, tasks)
  
  	return response, nil
  }
  228. func InstallPluginFromIdentifiers(
  229. config *app.Config,
  230. tenant_id string,
  231. plugin_unique_identifiers []plugin_entities.PluginUniqueIdentifier,
  232. source string,
  233. meta map[string]any,
  234. ) *entities.Response {
  235. response, err := InstallPluginRuntimeToTenant(
  236. config,
  237. tenant_id,
  238. plugin_unique_identifiers,
  239. source,
  240. meta,
  241. func(
  242. pluginUniqueIdentifier plugin_entities.PluginUniqueIdentifier,
  243. declaration *plugin_entities.PluginDeclaration,
  244. ) error {
  245. runtimeType := plugin_entities.PluginRuntimeType("")
  246. switch config.Platform {
  247. case app.PLATFORM_AWS_LAMBDA:
  248. runtimeType = plugin_entities.PLUGIN_RUNTIME_TYPE_AWS
  249. case app.PLATFORM_LOCAL:
  250. runtimeType = plugin_entities.PLUGIN_RUNTIME_TYPE_LOCAL
  251. default:
  252. return fmt.Errorf("unsupported platform: %s", config.Platform)
  253. }
  254. _, _, err := curd.InstallPlugin(
  255. tenant_id,
  256. pluginUniqueIdentifier,
  257. runtimeType,
  258. declaration,
  259. source,
  260. meta,
  261. )
  262. return err
  263. })
  264. if err != nil {
  265. return entities.NewErrorResponse(-500, err.Error())
  266. }
  267. return entities.NewSuccessResponse(response)
  268. }
  269. func UpgradePlugin(
  270. config *app.Config,
  271. tenant_id string,
  272. source string,
  273. meta map[string]any,
  274. original_plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
  275. new_plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
  276. ) *entities.Response {
  277. if original_plugin_unique_identifier == new_plugin_unique_identifier {
  278. return entities.NewErrorResponse(-400, "original and new plugin unique identifier are the same")
  279. }
  280. if original_plugin_unique_identifier.PluginID() != new_plugin_unique_identifier.PluginID() {
  281. return entities.NewErrorResponse(-400, "original and new plugin id are different")
  282. }
  283. // uninstall the original plugin
  284. installation, err := db.GetOne[models.PluginInstallation](
  285. db.Equal("tenant_id", tenant_id),
  286. db.Equal("plugin_unique_identifier", original_plugin_unique_identifier.String()),
  287. db.Equal("source", source),
  288. )
  289. if err == db.ErrDatabaseNotFound {
  290. return entities.NewErrorResponse(-404, "Plugin installation not found for this tenant")
  291. }
  292. if err != nil {
  293. return entities.NewErrorResponse(-500, err.Error())
  294. }
  295. // install the new plugin runtime
  296. response, err := InstallPluginRuntimeToTenant(
  297. config,
  298. tenant_id,
  299. []plugin_entities.PluginUniqueIdentifier{new_plugin_unique_identifier},
  300. source,
  301. meta,
  302. func(
  303. pluginUniqueIdentifier plugin_entities.PluginUniqueIdentifier,
  304. declaration *plugin_entities.PluginDeclaration,
  305. ) error {
  306. // uninstall the original plugin
  307. upgradeResponse, err := curd.UpgradePlugin(
  308. tenant_id,
  309. original_plugin_unique_identifier,
  310. new_plugin_unique_identifier,
  311. declaration,
  312. plugin_entities.PluginRuntimeType(installation.RuntimeType),
  313. source,
  314. meta,
  315. )
  316. if err != nil {
  317. return err
  318. }
  319. if upgradeResponse.IsOriginalPluginDeleted {
  320. // delete the plugin if no installation left
  321. manager := plugin_manager.Manager()
  322. if string(upgradeResponse.DeletedPlugin.InstallType) == string(
  323. plugin_entities.PLUGIN_RUNTIME_TYPE_LOCAL,
  324. ) {
  325. err = manager.UninstallFromLocal(
  326. plugin_entities.PluginUniqueIdentifier(upgradeResponse.DeletedPlugin.PluginUniqueIdentifier),
  327. )
  328. if err != nil {
  329. return err
  330. }
  331. }
  332. }
  333. return nil
  334. },
  335. )
  336. if err != nil {
  337. return entities.NewErrorResponse(-500, err.Error())
  338. }
  339. return entities.NewSuccessResponse(response)
  340. }
  341. func FetchPluginInstallationTasks(
  342. tenant_id string,
  343. page int,
  344. page_size int,
  345. ) *entities.Response {
  346. tasks, err := db.GetAll[models.InstallTask](
  347. db.Equal("tenant_id", tenant_id),
  348. db.OrderBy("created_at", true),
  349. db.Page(page, page_size),
  350. )
  351. if err != nil {
  352. return entities.NewErrorResponse(-500, err.Error())
  353. }
  354. return entities.NewSuccessResponse(tasks)
  355. }
  356. func FetchPluginInstallationTask(
  357. tenant_id string,
  358. task_id string,
  359. ) *entities.Response {
  360. task, err := db.GetOne[models.InstallTask](
  361. db.Equal("id", task_id),
  362. db.Equal("tenant_id", tenant_id),
  363. )
  364. if err != nil {
  365. return entities.NewErrorResponse(-500, err.Error())
  366. }
  367. return entities.NewSuccessResponse(task)
  368. }
  369. func DeletePluginInstallationTask(
  370. tenant_id string,
  371. task_id string,
  372. ) *entities.Response {
  373. err := db.DeleteByCondition(
  374. models.InstallTask{
  375. Model: models.Model{
  376. ID: task_id,
  377. },
  378. TenantID: tenant_id,
  379. },
  380. )
  381. if err != nil {
  382. return entities.NewErrorResponse(-500, err.Error())
  383. }
  384. return entities.NewSuccessResponse(true)
  385. }
  386. func DeletePluginInstallationItemFromTask(
  387. tenant_id string,
  388. task_id string,
  389. identifier plugin_entities.PluginUniqueIdentifier,
  390. ) *entities.Response {
  391. item, err := db.GetOne[models.InstallTask](
  392. db.Equal("task_id", task_id),
  393. db.Equal("tenant_id", tenant_id),
  394. )
  395. if err != nil {
  396. return entities.NewErrorResponse(-500, err.Error())
  397. }
  398. plugins := []models.InstallTaskPluginStatus{}
  399. for _, plugin := range item.Plugins {
  400. if plugin.PluginUniqueIdentifier != identifier {
  401. plugins = append(plugins, plugin)
  402. }
  403. }
  404. if len(plugins) == 0 {
  405. err = db.Delete(&item)
  406. } else {
  407. item.Plugins = plugins
  408. err = db.Update(&item)
  409. }
  410. if err != nil {
  411. return entities.NewErrorResponse(-500, err.Error())
  412. }
  413. return entities.NewSuccessResponse(true)
  414. }
  415. func FetchPluginFromIdentifier(
  416. pluginUniqueIdentifier plugin_entities.PluginUniqueIdentifier,
  417. ) *entities.Response {
  418. _, err := db.GetOne[models.Plugin](
  419. db.Equal("plugin_unique_identifier", pluginUniqueIdentifier.String()),
  420. )
  421. if err == db.ErrDatabaseNotFound {
  422. return entities.NewSuccessResponse(false)
  423. }
  424. if err != nil {
  425. return entities.NewErrorResponse(-500, err.Error())
  426. }
  427. return entities.NewSuccessResponse(true)
  428. }
  // UninstallPlugin removes the installation plugin_installation_id of tenant_id.
  //
  // A missing installation yields -404. When curd.UninstallPlugin reports that
  // the plugin itself was deleted (no installations left) and the removed
  // installation used the local runtime, the local runtime is uninstalled too.
  // Any failure along the way yields a -500 response.
  func UninstallPlugin(
  	tenant_id string,
  	plugin_installation_id string,
  ) *entities.Response {
  	installation, err := db.GetOne[models.PluginInstallation](
  		db.Equal("tenant_id", tenant_id),
  		db.Equal("id", plugin_installation_id),
  	)
  	switch {
  	case err == db.ErrDatabaseNotFound:
  		return entities.NewErrorResponse(-404, "Plugin installation not found for this tenant")
  	case err != nil:
  		return entities.NewErrorResponse(-500, err.Error())
  	}
  
  	identifier, err := plugin_entities.NewPluginUniqueIdentifier(installation.PluginUniqueIdentifier)
  	if err != nil {
  		return entities.NewErrorResponse(-500, fmt.Sprintf("failed to parse plugin unique identifier: %v", err))
  	}
  
  	removal, err := curd.UninstallPlugin(tenant_id, identifier, installation.ID)
  	if err != nil {
  		return entities.NewErrorResponse(-500, fmt.Sprintf("Failed to uninstall plugin: %s", err.Error()))
  	}
  
  	// the runtime is only torn down once no installation references the plugin
  	if !removal.IsPluginDeleted {
  		return entities.NewSuccessResponse(true)
  	}
  
  	manager := plugin_manager.Manager()
  	if removal.Installation.RuntimeType != string(plugin_entities.PLUGIN_RUNTIME_TYPE_LOCAL) {
  		return entities.NewSuccessResponse(true)
  	}
  
  	err = manager.UninstallFromLocal(identifier)
  	if err != nil {
  		return entities.NewErrorResponse(-500, fmt.Sprintf("Failed to uninstall plugin: %s", err.Error()))
  	}
  	return entities.NewSuccessResponse(true)
  }