install_plugin.go 15 KB


  1. package service
  2. import (
  3. "fmt"
  4. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
  5. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_packager/decoder"
  6. "github.com/langgenius/dify-plugin-daemon/internal/db"
  7. "github.com/langgenius/dify-plugin-daemon/internal/types/app"
  8. "github.com/langgenius/dify-plugin-daemon/internal/types/entities"
  9. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  10. "github.com/langgenius/dify-plugin-daemon/internal/types/models"
  11. "github.com/langgenius/dify-plugin-daemon/internal/types/models/curd"
  12. "github.com/langgenius/dify-plugin-daemon/internal/utils/cache/helper"
  13. "github.com/langgenius/dify-plugin-daemon/internal/utils/log"
  14. "github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
  15. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  16. "gorm.io/gorm"
  17. )
  // InstallPluginResponse is the result of InstallPluginRuntimeToTenant.
  //
  // When every requested plugin was already installed, AllInstalled is true
  // and TaskID is empty. Otherwise TaskID holds the ID of the install task
  // that tracks the background installation.
  type InstallPluginResponse struct {
  	// AllInstalled reports that no installation work had to be scheduled.
  	AllInstalled bool `json:"all_installed"`
  	// TaskID is the install task's ID, or "" when AllInstalled is true.
  	TaskID string `json:"task_id"`
  }
  22. type InstallPluginOnDoneHandler func(
  23. plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
  24. declaration *plugin_entities.PluginDeclaration,
  25. ) error
  // InstallPluginRuntimeToTenant makes the given plugins available to a tenant.
  //
  // Every identifier must resolve to a declaration through
  // helper.CombinedGetPluginDeclaration; otherwise that error is returned.
  // Plugins that already have a row in the plugin table are not reinstalled.
  // Instead, onDone is called for them synchronously, and an error from it is
  // returned to the caller. If that covers every plugin, the response has
  // AllInstalled set and no install task is created.
  //
  // Otherwise, a models.InstallTask row is created and its ID is returned as
  // TaskID. The missing runtimes are then installed by background jobs
  // submitted via routine.WithMaxRoutine with a limit of 3, presumably at most
  // three concurrently. Each job does the following:
  //   - records its progress on the task under a row write lock;
  //   - installs the runtime for config.Platform (AWS Lambda from the package
  //     file, or local);
  //   - calls onDone when the install stream reports completion.
  // Any failure marks both the plugin entry and the task as failed. The task
  // is marked successful once CompletedPlugins reaches TotalPlugins.
  func InstallPluginRuntimeToTenant(
  	config *app.Config,
  	tenant_id string,
  	plugin_unique_identifiers []plugin_entities.PluginUniqueIdentifier,
  	source string,
  	meta map[string]any,
  	onDone InstallPluginOnDoneHandler, // installation is asynchronous, so onDone may run in a background job
  ) (*InstallPluginResponse, error) {
  	response := &InstallPluginResponse{}
  	pluginsWaitForInstallation := []plugin_entities.PluginUniqueIdentifier{}

  	task := &models.InstallTask{
  		Status:           models.InstallTaskStatusRunning,
  		TenantID:         tenant_id,
  		TotalPlugins:     len(plugin_unique_identifiers),
  		CompletedPlugins: 0,
  		Plugins:          []models.InstallTaskPluginStatus{},
  	}

  	for i, pluginUniqueIdentifier := range plugin_unique_identifiers {
  		// fetch the declaration first: installing requires the package to be uploaded
  		pluginDeclaration, err := helper.CombinedGetPluginDeclaration(pluginUniqueIdentifier)
  		if err != nil {
  			return nil, err
  		}

  		// check whether the plugin runtime is already installed
  		_, err = db.GetOne[models.Plugin](
  			db.Equal("plugin_unique_identifier", pluginUniqueIdentifier.String()),
  		)

  		task.Plugins = append(task.Plugins, models.InstallTaskPluginStatus{
  			PluginUniqueIdentifier: pluginUniqueIdentifier,
  			PluginID:               pluginUniqueIdentifier.PluginID(),
  			Status:                 models.InstallTaskStatusPending,
  			Icon:                   pluginDeclaration.Icon,
  			Labels:                 pluginDeclaration.Label,
  			Message:                "",
  		})

  		if err == nil {
  			// The runtime already exists, so only the caller's follow-up work is
  			// needed. Propagate its failure instead of reporting the plugin as
  			// installed.
  			if err := onDone(pluginUniqueIdentifier, pluginDeclaration); err != nil {
  				return nil, fmt.Errorf("finishing installation of already installed plugin %s: %w", pluginUniqueIdentifier.String(), err)
  			}
  			task.CompletedPlugins++
  			task.Plugins[i].Status = models.InstallTaskStatusSuccess
  			task.Plugins[i].Message = "Installed"
  			continue
  		}
  		if err != db.ErrDatabaseNotFound {
  			return nil, err
  		}
  		pluginsWaitForInstallation = append(pluginsWaitForInstallation, pluginUniqueIdentifier)
  	}

  	if len(pluginsWaitForInstallation) == 0 {
  		response.AllInstalled = true
  		response.TaskID = ""
  		return response, nil
  	}

  	manager := plugin_manager.Manager()

  	// Resolve all declarations before the task row exists. A lookup failure
  	// after db.Create would otherwise leave a task stuck in the running state,
  	// with no background job left to update it.
  	declarations := make([]*plugin_entities.PluginDeclaration, 0, len(pluginsWaitForInstallation))
  	for _, pluginUniqueIdentifier := range pluginsWaitForInstallation {
  		declaration, err := manager.GetDeclaration(pluginUniqueIdentifier)
  		if err != nil {
  			return nil, err
  		}
  		declarations = append(declarations, declaration)
  	}

  	if err := db.Create(task); err != nil {
  		return nil, err
  	}
  	response.TaskID = task.ID
  	taskID := task.ID

  	tasks := make([]func(), 0, len(pluginsWaitForInstallation))
  	for i := range pluginsWaitForInstallation {
  		// per-iteration copies, so every job installs its own plugin
  		pluginUniqueIdentifier := pluginsWaitForInstallation[i]
  		declaration := declarations[i]

  		tasks = append(tasks, func() {
  			// updateTaskStatus applies modifier to this plugin's entry in the
  			// install task. It runs inside a transaction that holds a row write
  			// lock, so concurrent jobs do not overwrite each other's progress.
  			updateTaskStatus := func(modifier func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus)) {
  				if err := db.WithTransaction(func(tx *gorm.DB) error {
  					current, err := db.GetOne[models.InstallTask](
  						db.WithTransactionContext(tx),
  						db.Equal("id", taskID),
  						db.WLock(), // write lock, multiple tasks can't update the same task
  					)
  					if err == db.ErrDatabaseNotFound {
  						// the task no longer exists (e.g. it was deleted); nothing to record
  						return nil
  					}
  					if err != nil {
  						return err
  					}

  					var pluginStatus *models.InstallTaskPluginStatus
  					for j := range current.Plugins {
  						if current.Plugins[j].PluginUniqueIdentifier == pluginUniqueIdentifier {
  							pluginStatus = &current.Plugins[j]
  							break
  						}
  					}
  					if pluginStatus == nil {
  						// the entry was removed from the task; nothing to record
  						return nil
  					}

  					modifier(&current, pluginStatus)
  					return db.Update(&current, tx)
  				}); err != nil {
  					log.Error("failed to update install task status %s", err.Error())
  				}
  			}

  			// markFailed records message on this plugin's entry and fails the whole task.
  			markFailed := func(message string) {
  				updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) {
  					task.Status = models.InstallTaskStatusFailed
  					plugin.Status = models.InstallTaskStatusFailed
  					plugin.Message = message
  				})
  			}

  			updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) {
  				plugin.Status = models.InstallTaskStatusRunning
  				plugin.Message = "Installing"
  			})

  			// err is declared once here and only assigned with = below, so the
  			// check after the switch sees the error from whichever case ran.
  			var (
  				installStream *stream.Stream[plugin_manager.PluginInstallResponse]
  				err           error
  			)
  			switch config.Platform {
  			case app.PLATFORM_AWS_LAMBDA:
  				var pkgFile []byte
  				pkgFile, err = manager.GetPackage(pluginUniqueIdentifier)
  				if err != nil {
  					markFailed("Failed to read plugin package")
  					return
  				}
  				var zipDecoder *decoder.ZipPluginDecoder
  				zipDecoder, err = decoder.NewZipPluginDecoder(pkgFile)
  				if err != nil {
  					markFailed(err.Error())
  					return
  				}
  				installStream, err = manager.InstallToAWSFromPkg(zipDecoder, source, meta)
  			case app.PLATFORM_LOCAL:
  				installStream, err = manager.InstallToLocal(pluginUniqueIdentifier, source, meta)
  			default:
  				markFailed("Unsupported platform")
  				return
  			}
  			if err != nil {
  				markFailed(err.Error())
  				return
  			}

  			for installStream.Next() {
  				message, err := installStream.Read()
  				if err != nil {
  					markFailed(err.Error())
  					return
  				}
  				switch message.Event {
  				case plugin_manager.PluginInstallEventError:
  					markFailed(message.Data)
  					return
  				case plugin_manager.PluginInstallEventDone:
  					if err := onDone(pluginUniqueIdentifier, declaration); err != nil {
  						// keep the underlying cause visible; the task only shows a generic message
  						log.Error("failed to create plugin %s: %s", pluginUniqueIdentifier.String(), err.Error())
  						markFailed("Failed to create plugin")
  						return
  					}
  				}
  			}

  			// NOTE(review): reaching this point only means the stream ended. If it can
  			// end without a Done event, onDone was never called, yet the plugin is
  			// still reported as installed. Confirm the install stream's contract.
  			updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) {
  				plugin.Status = models.InstallTaskStatusSuccess
  				plugin.Message = "Installed"
  				task.CompletedPlugins++
  				// check if all plugins are installed
  				if task.CompletedPlugins == task.TotalPlugins {
  					task.Status = models.InstallTaskStatusSuccess
  				}
  			})
  		})
  	}

  	// submit the background jobs (limit 3, presumably the concurrency cap)
  	routine.WithMaxRoutine(3, tasks)
  	return response, nil
  }
  212. func InstallPluginFromIdentifiers(
  213. config *app.Config,
  214. tenant_id string,
  215. plugin_unique_identifiers []plugin_entities.PluginUniqueIdentifier,
  216. source string,
  217. meta map[string]any,
  218. ) *entities.Response {
  219. response, err := InstallPluginRuntimeToTenant(
  220. config,
  221. tenant_id,
  222. plugin_unique_identifiers,
  223. source,
  224. meta,
  225. func(
  226. plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
  227. declaration *plugin_entities.PluginDeclaration,
  228. ) error {
  229. runtime_type := plugin_entities.PluginRuntimeType("")
  230. switch config.Platform {
  231. case app.PLATFORM_AWS_LAMBDA:
  232. runtime_type = plugin_entities.PLUGIN_RUNTIME_TYPE_AWS
  233. case app.PLATFORM_LOCAL:
  234. runtime_type = plugin_entities.PLUGIN_RUNTIME_TYPE_LOCAL
  235. default:
  236. return fmt.Errorf("unsupported platform: %s", config.Platform)
  237. }
  238. _, _, err := curd.InstallPlugin(
  239. tenant_id,
  240. plugin_unique_identifier,
  241. runtime_type,
  242. declaration,
  243. source,
  244. meta,
  245. )
  246. return err
  247. })
  248. if err != nil {
  249. return entities.NewErrorResponse(-500, err.Error())
  250. }
  251. return entities.NewSuccessResponse(response)
  252. }
  253. func UpgradePlugin(
  254. config *app.Config,
  255. tenant_id string,
  256. source string,
  257. meta map[string]any,
  258. original_plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
  259. new_plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
  260. ) *entities.Response {
  261. if original_plugin_unique_identifier == new_plugin_unique_identifier {
  262. return entities.NewErrorResponse(-400, "original and new plugin unique identifier are the same")
  263. }
  264. if original_plugin_unique_identifier.PluginID() != new_plugin_unique_identifier.PluginID() {
  265. return entities.NewErrorResponse(-400, "original and new plugin id are different")
  266. }
  267. // uninstall the original plugin
  268. installation, err := db.GetOne[models.PluginInstallation](
  269. db.Equal("tenant_id", tenant_id),
  270. db.Equal("plugin_unique_identifier", original_plugin_unique_identifier.String()),
  271. db.Equal("source", source),
  272. )
  273. if err == db.ErrDatabaseNotFound {
  274. return entities.NewErrorResponse(-404, "Plugin installation not found for this tenant")
  275. }
  276. if err != nil {
  277. return entities.NewErrorResponse(-500, err.Error())
  278. }
  279. // install the new plugin runtime
  280. response, err := InstallPluginRuntimeToTenant(
  281. config,
  282. tenant_id,
  283. []plugin_entities.PluginUniqueIdentifier{new_plugin_unique_identifier},
  284. source,
  285. meta,
  286. func(
  287. plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
  288. declaration *plugin_entities.PluginDeclaration,
  289. ) error {
  290. // uninstall the original plugin
  291. upgrade_response, err := curd.UpgradePlugin(
  292. tenant_id,
  293. original_plugin_unique_identifier,
  294. new_plugin_unique_identifier,
  295. declaration,
  296. plugin_entities.PluginRuntimeType(installation.RuntimeType),
  297. source,
  298. meta,
  299. )
  300. if err != nil {
  301. return err
  302. }
  303. if upgrade_response.IsOriginalPluginDeleted {
  304. // delete the plugin if no installation left
  305. manager := plugin_manager.Manager()
  306. if string(upgrade_response.DeletedPlugin.InstallType) == string(
  307. plugin_entities.PLUGIN_RUNTIME_TYPE_LOCAL,
  308. ) {
  309. err = manager.UninstallFromLocal(
  310. plugin_entities.PluginUniqueIdentifier(upgrade_response.DeletedPlugin.PluginUniqueIdentifier),
  311. )
  312. if err != nil {
  313. return err
  314. }
  315. }
  316. }
  317. return nil
  318. },
  319. )
  320. if err != nil {
  321. return entities.NewErrorResponse(-500, err.Error())
  322. }
  323. return entities.NewSuccessResponse(response)
  324. }
  325. func FetchPluginInstallationTasks(
  326. tenant_id string,
  327. page int,
  328. page_size int,
  329. ) *entities.Response {
  330. tasks, err := db.GetAll[models.InstallTask](
  331. db.Equal("tenant_id", tenant_id),
  332. db.OrderBy("created_at", true),
  333. db.Page(page, page_size),
  334. )
  335. if err != nil {
  336. return entities.NewErrorResponse(-500, err.Error())
  337. }
  338. return entities.NewSuccessResponse(tasks)
  339. }
  340. func FetchPluginInstallationTask(
  341. tenant_id string,
  342. task_id string,
  343. ) *entities.Response {
  344. task, err := db.GetOne[models.InstallTask](
  345. db.Equal("id", task_id),
  346. db.Equal("tenant_id", tenant_id),
  347. )
  348. if err != nil {
  349. return entities.NewErrorResponse(-500, err.Error())
  350. }
  351. return entities.NewSuccessResponse(task)
  352. }
  353. func DeletePluginInstallationTask(
  354. tenant_id string,
  355. task_id string,
  356. ) *entities.Response {
  357. err := db.DeleteByCondition(
  358. models.InstallTask{
  359. Model: models.Model{
  360. ID: task_id,
  361. },
  362. TenantID: tenant_id,
  363. },
  364. )
  365. if err != nil {
  366. return entities.NewErrorResponse(-500, err.Error())
  367. }
  368. return entities.NewSuccessResponse(true)
  369. }
  370. func DeletePluginInstallationItemFromTask(
  371. tenant_id string,
  372. task_id string,
  373. identifier plugin_entities.PluginUniqueIdentifier,
  374. ) *entities.Response {
  375. item, err := db.GetOne[models.InstallTask](
  376. db.Equal("task_id", task_id),
  377. db.Equal("tenant_id", tenant_id),
  378. )
  379. if err != nil {
  380. return entities.NewErrorResponse(-500, err.Error())
  381. }
  382. plugins := []models.InstallTaskPluginStatus{}
  383. for _, plugin := range item.Plugins {
  384. if plugin.PluginUniqueIdentifier != identifier {
  385. plugins = append(plugins, plugin)
  386. }
  387. }
  388. if len(plugins) == 0 {
  389. err = db.Delete(&item)
  390. } else {
  391. item.Plugins = plugins
  392. err = db.Update(&item)
  393. }
  394. if err != nil {
  395. return entities.NewErrorResponse(-500, err.Error())
  396. }
  397. return entities.NewSuccessResponse(true)
  398. }
  399. func FetchPluginFromIdentifier(
  400. plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
  401. ) *entities.Response {
  402. _, err := db.GetOne[models.Plugin](
  403. db.Equal("plugin_unique_identifier", plugin_unique_identifier.String()),
  404. )
  405. if err == db.ErrDatabaseNotFound {
  406. return entities.NewSuccessResponse(false)
  407. }
  408. if err != nil {
  409. return entities.NewErrorResponse(-500, err.Error())
  410. }
  411. return entities.NewSuccessResponse(true)
  412. }
  // UninstallPlugin removes the tenant's plugin installation identified by
  // plugin_installation_id.
  //
  // It returns -404 if the installation does not belong to the tenant and -500
  // on any other failure. When curd.UninstallPlugin reports that the plugin
  // itself was deleted and the installation used the local runtime type, the
  // local runtime is uninstalled too.
  func UninstallPlugin(
  	tenant_id string,
  	plugin_installation_id string,
  ) *entities.Response {
  	// make sure the installation exists for this tenant
  	installation, err := db.GetOne[models.PluginInstallation](
  		db.Equal("tenant_id", tenant_id),
  		db.Equal("id", plugin_installation_id),
  	)
  	switch {
  	case err == db.ErrDatabaseNotFound:
  		return entities.NewErrorResponse(-404, "Plugin installation not found for this tenant")
  	case err != nil:
  		return entities.NewErrorResponse(-500, err.Error())
  	}

  	identifier, err := plugin_entities.NewPluginUniqueIdentifier(installation.PluginUniqueIdentifier)
  	if err != nil {
  		return entities.NewErrorResponse(-500, fmt.Sprintf("failed to parse plugin unique identifier: %v", err))
  	}

  	removal, err := curd.UninstallPlugin(tenant_id, identifier, installation.ID)
  	if err != nil {
  		return entities.NewErrorResponse(-500, fmt.Sprintf("Failed to uninstall plugin: %s", err.Error()))
  	}

  	// the plugin is still referenced elsewhere; keep its runtime
  	if !removal.IsPluginDeleted {
  		return entities.NewSuccessResponse(true)
  	}

  	manager := plugin_manager.Manager()
  	// only local runtimes need an explicit uninstall here
  	if removal.Installation.RuntimeType != string(plugin_entities.PLUGIN_RUNTIME_TYPE_LOCAL) {
  		return entities.NewSuccessResponse(true)
  	}
  	if err := manager.UninstallFromLocal(identifier); err != nil {
  		return entities.NewErrorResponse(-500, fmt.Sprintf("Failed to uninstall plugin: %s", err.Error()))
  	}
  	return entities.NewSuccessResponse(true)
  }