Browse Source

feat: add db

Yeuoly 1 year ago
parent
commit
ce4ce368e2
4 changed files with 34 additions and 7 deletions
  1. 1 1
      cmd/server/main.go
  2. 28 6
      internal/db/init.go
  3. 4 0
      internal/server/server.go
  4. 1 0
      internal/types/app/config.go

+ 1 - 1
cmd/server/main.go

@@ -41,7 +41,7 @@ func setDefault(config *app.Config) {
 	setDefaultInt(&config.PluginRemoteInstallServerEventLoopNums, 8)
 	setDefaultInt(&config.PluginRemoteInstallingMaxConn, 128)
 	settDefaultBool(&config.PluginRemoteInstallingEnabled, true)
-
+	settDefaultString(&config.DBSslMode, "disable")
 	settDefaultString(&config.ProcessCachingPath, "/tmp/dify-plugin-daemon-subprocesses")
 }
 

+ 28 - 6
internal/db/init.go

@@ -4,13 +4,15 @@ import (
 	"fmt"
 	"time"
 
+	"github.com/langgenius/dify-plugin-daemon/internal/types/app"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
 	"gorm.io/driver/postgres"
 	"gorm.io/gorm"
 )
 
-func InitDifyEnterpriseDB(host string, port int, dbname string, user string, pass string, sslmode string) error {
+func initDifyPluginDB(host string, port int, db_name string, user string, pass string, sslmode string) error {
 	// create db if not exists
-	dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", host, port, user, pass, dbname, sslmode)
+	dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", host, port, user, pass, "postgres", sslmode)
 	db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
 	if err != nil {
 		return err
@@ -22,14 +24,14 @@ func InitDifyEnterpriseDB(host string, port int, dbname string, user string, pas
 	}
 
 	// check if the db exists
-	rows, err := pgsqlDB.Query(fmt.Sprintf("SELECT 1 FROM pg_database WHERE datname = '%s'", dbname))
+	rows, err := pgsqlDB.Query(fmt.Sprintf("SELECT 1 FROM pg_database WHERE datname = '%s'", db_name))
 	if err != nil {
 		return err
 	}
 
 	if !rows.Next() {
 		// create database
-		_, err = pgsqlDB.Exec(fmt.Sprintf("CREATE DATABASE %s", dbname))
+		_, err = pgsqlDB.Exec(fmt.Sprintf("CREATE DATABASE %s", db_name))
 		if err != nil {
 			return err
 		}
@@ -42,7 +44,7 @@ func InitDifyEnterpriseDB(host string, port int, dbname string, user string, pas
 	}
 
 	// connect to the new db
-	dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", host, port, user, pass, dbname, sslmode)
+	dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", host, port, user, pass, db_name, sslmode)
 	db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
 	if err != nil {
 		return err
@@ -73,6 +75,26 @@ func InitDifyEnterpriseDB(host string, port int, dbname string, user string, pas
 	return nil
 }
 
-func AutoMigrate() error {
+func autoMigrate() error {
 	return DifyPluginDB.AutoMigrate()
 }
+
+func Init(config *app.Config) {
+	err := initDifyPluginDB(
+		config.DBHost,
+		int(config.DBPort),
+		config.DBDatabase,
+		config.DBUsername, config.DBPassword, config.DBSslMode,
+	)
+
+	if err != nil {
+		log.Panic("failed to init dify plugin db: %v", err)
+	}
+
+	err = autoMigrate()
+	if err != nil {
+		log.Panic("failed to auto migrate: %v", err)
+	}
+
+	log.Info("dify plugin db initialized")
+}

+ 4 - 0
internal/server/server.go

@@ -2,6 +2,7 @@ package server
 
 import (
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
+	"github.com/langgenius/dify-plugin-daemon/internal/db"
 	"github.com/langgenius/dify-plugin-daemon/internal/process"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/app"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
@@ -11,6 +12,9 @@ func Run(config *app.Config) {
 	// init routine pool
 	routine.InitPool(config.RoutinePoolSize)
 
+	// init db
+	db.Init(config)
+
 	// init process lifetime
 	process.Init(config)
 

+ 1 - 0
internal/types/app/config.go

@@ -34,6 +34,7 @@ type Config struct {
 	DBHost     string `envconfig:"DB_HOST" validate:"required"`
 	DBPort     int16  `envconfig:"DB_PORT" validate:"required"`
 	DBDatabase string `envconfig:"DB_DATABASE" validate:"required"`
+	DBSslMode  string `envconfig:"DB_SSL_MODE" validate:"required,oneof=disable require"`
 
 	LifetimeCollectionHeartbeatInterval int `envconfig:"LIFETIME_COLLECTION_HEARTBEAT_INTERVAL"  validate:"required"`
 	LifetimeCollectionGCInterval        int `envconfig:"LIFETIME_COLLECTION_GC_INTERVAL" validate:"required"`