install_plugin.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509
  1. package service
  2. import (
  3. "fmt"
  4. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
  5. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_packager/decoder"
  6. "github.com/langgenius/dify-plugin-daemon/internal/db"
  7. "github.com/langgenius/dify-plugin-daemon/internal/types/app"
  8. "github.com/langgenius/dify-plugin-daemon/internal/types/entities"
  9. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  10. "github.com/langgenius/dify-plugin-daemon/internal/types/models"
  11. "github.com/langgenius/dify-plugin-daemon/internal/types/models/curd"
  12. "github.com/langgenius/dify-plugin-daemon/internal/utils/log"
  13. "github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
  14. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  15. "gorm.io/gorm"
  16. )
  // InstallPluginResponse is the result of InstallPluginRuntimeToTenant.
type InstallPluginResponse struct {
	// AllInstalled is true when no background installation task was needed.
	AllInstalled bool `json:"all_installed"`
	// TaskID is the ID of the created install task; empty when AllInstalled is true.
	TaskID string `json:"task_id"`
}
  21. type InstallPluginOnDoneHandler func(
  22. plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
  23. declaration *plugin_entities.PluginDeclaration,
  24. ) error
  25. func InstallPluginRuntimeToTenant(
  26. config *app.Config,
  27. tenant_id string,
  28. plugin_unique_identifiers []plugin_entities.PluginUniqueIdentifier,
  29. source string,
  30. meta map[string]any,
  31. on_done InstallPluginOnDoneHandler, // since installing plugin is a async task, we need to call it asynchronously
  32. ) (*InstallPluginResponse, error) {
  33. response := &InstallPluginResponse{}
  34. plugins_wait_for_installation := []plugin_entities.PluginUniqueIdentifier{}
  35. task := &models.InstallTask{
  36. Status: models.InstallTaskStatusRunning,
  37. TenantID: tenant_id,
  38. TotalPlugins: len(plugin_unique_identifiers),
  39. CompletedPlugins: 0,
  40. Plugins: []models.InstallTaskPluginStatus{},
  41. }
  42. for i, plugin_unique_identifier := range plugin_unique_identifiers {
  43. // check if plugin is already installed
  44. plugin, err := db.GetOne[models.Plugin](
  45. db.Equal("plugin_unique_identifier", plugin_unique_identifier.String()),
  46. )
  47. task.Plugins = append(task.Plugins, models.InstallTaskPluginStatus{
  48. PluginUniqueIdentifier: plugin_unique_identifier,
  49. PluginID: plugin_unique_identifier.PluginID(),
  50. Status: models.InstallTaskStatusPending,
  51. Labels: plugin.Declaration.Label,
  52. Icon: plugin.Declaration.Icon,
  53. Message: "",
  54. })
  55. if err == nil {
  56. // already installed by other tenant
  57. declaration := plugin.Declaration
  58. if _, _, err := curd.InstallPlugin(
  59. tenant_id,
  60. plugin_unique_identifier,
  61. plugin.InstallType,
  62. &declaration,
  63. source,
  64. meta,
  65. ); err != nil {
  66. return nil, err
  67. }
  68. task.CompletedPlugins++
  69. task.Plugins[i].Status = models.InstallTaskStatusSuccess
  70. task.Plugins[i].Message = "Installed"
  71. continue
  72. }
  73. if err != db.ErrDatabaseNotFound {
  74. return nil, err
  75. }
  76. plugins_wait_for_installation = append(plugins_wait_for_installation, plugin_unique_identifier)
  77. }
  78. if len(plugins_wait_for_installation) == 0 {
  79. response.AllInstalled = true
  80. response.TaskID = ""
  81. return response, nil
  82. }
  83. err := db.Create(task)
  84. if err != nil {
  85. return nil, err
  86. }
  87. response.TaskID = task.ID
  88. manager := plugin_manager.Manager()
  89. tasks := []func(){}
  90. for _, plugin_unique_identifier := range plugins_wait_for_installation {
  91. // copy the variable to avoid race condition
  92. plugin_unique_identifier := plugin_unique_identifier
  93. declaration, err := manager.GetDeclaration(plugin_unique_identifier)
  94. if err != nil {
  95. return nil, err
  96. }
  97. tasks = append(tasks, func() {
  98. updateTaskStatus := func(modifier func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus)) {
  99. if err := db.WithTransaction(func(tx *gorm.DB) error {
  100. task, err := db.GetOne[models.InstallTask](
  101. db.WithTransactionContext(tx),
  102. db.Equal("id", task.ID),
  103. db.WLock(), // write lock, multiple tasks can't update the same task
  104. )
  105. if err == db.ErrDatabaseNotFound {
  106. return nil
  107. }
  108. if err != nil {
  109. return err
  110. }
  111. task_pointer := &task
  112. var plugin_status *models.InstallTaskPluginStatus
  113. for i := range task.Plugins {
  114. if task.Plugins[i].PluginUniqueIdentifier == plugin_unique_identifier {
  115. plugin_status = &task.Plugins[i]
  116. break
  117. }
  118. }
  119. if plugin_status == nil {
  120. return nil
  121. }
  122. modifier(task_pointer, plugin_status)
  123. return db.Update(task_pointer, tx)
  124. }); err != nil {
  125. log.Error("failed to update install task status %s", err.Error())
  126. }
  127. }
  128. updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) {
  129. plugin.Status = models.InstallTaskStatusRunning
  130. plugin.Message = "Installing"
  131. })
  132. var stream *stream.Stream[plugin_manager.PluginInstallResponse]
  133. if config.Platform == app.PLATFORM_AWS_LAMBDA {
  134. var zip_decoder *decoder.ZipPluginDecoder
  135. var pkg_file []byte
  136. pkg_file, err = manager.GetPackage(plugin_unique_identifier)
  137. if err != nil {
  138. updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) {
  139. task.Status = models.InstallTaskStatusFailed
  140. plugin.Status = models.InstallTaskStatusFailed
  141. plugin.Message = "Failed to read plugin package"
  142. })
  143. return
  144. }
  145. zip_decoder, err = decoder.NewZipPluginDecoder(pkg_file)
  146. if err != nil {
  147. updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) {
  148. task.Status = models.InstallTaskStatusFailed
  149. plugin.Status = models.InstallTaskStatusFailed
  150. plugin.Message = err.Error()
  151. })
  152. return
  153. }
  154. stream, err = manager.InstallToAWSFromPkg(zip_decoder, source, meta)
  155. } else if config.Platform == app.PLATFORM_LOCAL {
  156. stream, err = manager.InstallToLocal(plugin_unique_identifier, source, meta)
  157. } else {
  158. updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) {
  159. task.Status = models.InstallTaskStatusFailed
  160. plugin.Status = models.InstallTaskStatusFailed
  161. plugin.Message = "Unsupported platform"
  162. })
  163. return
  164. }
  165. if err != nil {
  166. updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) {
  167. task.Status = models.InstallTaskStatusFailed
  168. plugin.Status = models.InstallTaskStatusFailed
  169. plugin.Message = err.Error()
  170. })
  171. return
  172. }
  173. for stream.Next() {
  174. message, err := stream.Read()
  175. if err != nil {
  176. updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) {
  177. task.Status = models.InstallTaskStatusFailed
  178. plugin.Status = models.InstallTaskStatusFailed
  179. plugin.Message = err.Error()
  180. })
  181. return
  182. }
  183. if message.Event == plugin_manager.PluginInstallEventError {
  184. updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) {
  185. task.Status = models.InstallTaskStatusFailed
  186. plugin.Status = models.InstallTaskStatusFailed
  187. plugin.Message = message.Data
  188. })
  189. return
  190. }
  191. if message.Event == plugin_manager.PluginInstallEventDone {
  192. if err := on_done(plugin_unique_identifier, declaration); err != nil {
  193. updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) {
  194. task.Status = models.InstallTaskStatusFailed
  195. plugin.Status = models.InstallTaskStatusFailed
  196. plugin.Message = "Failed to create plugin"
  197. })
  198. return
  199. }
  200. }
  201. }
  202. updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) {
  203. plugin.Status = models.InstallTaskStatusSuccess
  204. plugin.Message = "Installed"
  205. task.CompletedPlugins++
  206. // check if all plugins are installed
  207. if task.CompletedPlugins == task.TotalPlugins {
  208. task.Status = models.InstallTaskStatusSuccess
  209. }
  210. })
  211. })
  212. }
  213. // submit async tasks
  214. routine.WithMaxRoutine(3, tasks)
  215. return response, nil
  216. }
  217. func InstallPluginFromIdentifiers(
  218. config *app.Config,
  219. tenant_id string,
  220. plugin_unique_identifiers []plugin_entities.PluginUniqueIdentifier,
  221. source string,
  222. meta map[string]any,
  223. ) *entities.Response {
  224. response, err := InstallPluginRuntimeToTenant(config, tenant_id, plugin_unique_identifiers, source, meta, func(
  225. plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
  226. declaration *plugin_entities.PluginDeclaration,
  227. ) error {
  228. runtime_type := plugin_entities.PluginRuntimeType("")
  229. switch config.Platform {
  230. case app.PLATFORM_AWS_LAMBDA:
  231. runtime_type = plugin_entities.PLUGIN_RUNTIME_TYPE_AWS
  232. case app.PLATFORM_LOCAL:
  233. runtime_type = plugin_entities.PLUGIN_RUNTIME_TYPE_LOCAL
  234. default:
  235. return fmt.Errorf("unsupported platform: %s", config.Platform)
  236. }
  237. _, _, err := curd.InstallPlugin(
  238. tenant_id,
  239. plugin_unique_identifier,
  240. runtime_type,
  241. declaration,
  242. source,
  243. meta,
  244. )
  245. return err
  246. })
  247. if err != nil {
  248. return entities.NewErrorResponse(-500, err.Error())
  249. }
  250. return entities.NewSuccessResponse(response)
  251. }
  252. func UpgradePlugin(
  253. config *app.Config,
  254. tenant_id string,
  255. source string,
  256. meta map[string]any,
  257. original_plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
  258. new_plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
  259. ) *entities.Response {
  260. if original_plugin_unique_identifier == new_plugin_unique_identifier {
  261. return entities.NewErrorResponse(-400, "original and new plugin unique identifier are the same")
  262. }
  263. if original_plugin_unique_identifier.PluginID() != new_plugin_unique_identifier.PluginID() {
  264. return entities.NewErrorResponse(-400, "original and new plugin id are different")
  265. }
  266. // uninstall the original plugin
  267. installation, err := db.GetOne[models.PluginInstallation](
  268. db.Equal("tenant_id", tenant_id),
  269. db.Equal("plugin_unique_identifier", original_plugin_unique_identifier.String()),
  270. db.Equal("source", source),
  271. )
  272. if err == db.ErrDatabaseNotFound {
  273. return entities.NewErrorResponse(-404, "Plugin installation not found for this tenant")
  274. }
  275. if err != nil {
  276. return entities.NewErrorResponse(-500, err.Error())
  277. }
  278. // install the new plugin runtime
  279. response, err := InstallPluginRuntimeToTenant(
  280. config,
  281. tenant_id,
  282. []plugin_entities.PluginUniqueIdentifier{new_plugin_unique_identifier},
  283. source,
  284. meta,
  285. func(
  286. plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
  287. declaration *plugin_entities.PluginDeclaration,
  288. ) error {
  289. // uninstall the original plugin
  290. err = curd.UpgradePlugin(
  291. tenant_id,
  292. original_plugin_unique_identifier,
  293. new_plugin_unique_identifier,
  294. declaration,
  295. plugin_entities.PluginRuntimeType(installation.RuntimeType),
  296. source,
  297. meta,
  298. )
  299. if err != nil {
  300. return err
  301. }
  302. return nil
  303. },
  304. )
  305. if err != nil {
  306. return entities.NewErrorResponse(-500, err.Error())
  307. }
  308. return entities.NewSuccessResponse(response)
  309. }
  310. func FetchPluginInstallationTasks(
  311. tenant_id string,
  312. page int,
  313. page_size int,
  314. ) *entities.Response {
  315. tasks, err := db.GetAll[models.InstallTask](
  316. db.Equal("tenant_id", tenant_id),
  317. db.OrderBy("created_at", true),
  318. db.Page(page, page_size),
  319. )
  320. if err != nil {
  321. return entities.NewErrorResponse(-500, err.Error())
  322. }
  323. return entities.NewSuccessResponse(tasks)
  324. }
  325. func FetchPluginInstallationTask(
  326. tenant_id string,
  327. task_id string,
  328. ) *entities.Response {
  329. task, err := db.GetOne[models.InstallTask](
  330. db.Equal("id", task_id),
  331. db.Equal("tenant_id", tenant_id),
  332. )
  333. if err != nil {
  334. return entities.NewErrorResponse(-500, err.Error())
  335. }
  336. return entities.NewSuccessResponse(task)
  337. }
  338. func DeletePluginInstallationTask(
  339. tenant_id string,
  340. task_id string,
  341. ) *entities.Response {
  342. err := db.DeleteByCondition(
  343. models.InstallTask{
  344. Model: models.Model{
  345. ID: task_id,
  346. },
  347. TenantID: tenant_id,
  348. },
  349. )
  350. if err != nil {
  351. return entities.NewErrorResponse(-500, err.Error())
  352. }
  353. return entities.NewSuccessResponse(true)
  354. }
  355. func DeletePluginInstallationItemFromTask(
  356. tenant_id string,
  357. task_id string,
  358. identifier plugin_entities.PluginUniqueIdentifier,
  359. ) *entities.Response {
  360. item, err := db.GetOne[models.InstallTask](
  361. db.Equal("task_id", task_id),
  362. db.Equal("tenant_id", tenant_id),
  363. )
  364. if err != nil {
  365. return entities.NewErrorResponse(-500, err.Error())
  366. }
  367. plugins := []models.InstallTaskPluginStatus{}
  368. for _, plugin := range item.Plugins {
  369. if plugin.PluginUniqueIdentifier != identifier {
  370. plugins = append(plugins, plugin)
  371. }
  372. }
  373. if len(plugins) == 0 {
  374. err = db.Delete(&item)
  375. } else {
  376. item.Plugins = plugins
  377. err = db.Update(&item)
  378. }
  379. if err != nil {
  380. return entities.NewErrorResponse(-500, err.Error())
  381. }
  382. return entities.NewSuccessResponse(true)
  383. }
  384. func FetchPluginFromIdentifier(
  385. plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
  386. ) *entities.Response {
  387. _, err := db.GetOne[models.Plugin](
  388. db.Equal("plugin_unique_identifier", plugin_unique_identifier.String()),
  389. )
  390. if err == db.ErrDatabaseNotFound {
  391. return entities.NewSuccessResponse(false)
  392. }
  393. if err != nil {
  394. return entities.NewErrorResponse(-500, err.Error())
  395. }
  396. return entities.NewSuccessResponse(true)
  397. }
  398. func UninstallPlugin(
  399. tenant_id string,
  400. plugin_installation_id string,
  401. ) *entities.Response {
  402. // Check if the plugin exists for the tenant
  403. installation, err := db.GetOne[models.PluginInstallation](
  404. db.Equal("tenant_id", tenant_id),
  405. db.Equal("id", plugin_installation_id),
  406. )
  407. if err == db.ErrDatabaseNotFound {
  408. return entities.NewErrorResponse(-404, "Plugin installation not found for this tenant")
  409. }
  410. if err != nil {
  411. return entities.NewErrorResponse(-500, err.Error())
  412. }
  413. plugin_unique_identifier, err := plugin_entities.NewPluginUniqueIdentifier(installation.PluginUniqueIdentifier)
  414. if err != nil {
  415. return entities.NewErrorResponse(-500, fmt.Sprintf("failed to parse plugin unique identifier: %v", err))
  416. }
  417. // Uninstall the plugin
  418. delete_response, err := curd.UninstallPlugin(
  419. tenant_id,
  420. plugin_unique_identifier,
  421. installation.ID,
  422. )
  423. if err != nil {
  424. return entities.NewErrorResponse(-500, fmt.Sprintf("Failed to uninstall plugin: %s", err.Error()))
  425. }
  426. if delete_response.IsPluginDeleted {
  427. // delete the plugin if no installation left
  428. manager := plugin_manager.Manager()
  429. if delete_response.Installation.RuntimeType == string(
  430. plugin_entities.PLUGIN_RUNTIME_TYPE_LOCAL,
  431. ) {
  432. err = manager.UninstallFromLocal(plugin_unique_identifier)
  433. if err != nil {
  434. return entities.NewErrorResponse(-500, fmt.Sprintf("Failed to uninstall plugin: %s", err.Error()))
  435. }
  436. }
  437. }
  438. return entities.NewSuccessResponse(true)
  439. }