Browse Source

feat: enhance database initialization with fallback connection strategy (#58)

- Add support for connecting to a default database when target database connection fails
- Introduce new configuration parameter `DBDefaultDatabase` with default value "postgres"
- Modify `initDifyPluginDB` to attempt connection to default database and create target database if needed
- Update test case to include default database parameter
Yeuoly 4 months ago
parent
commit
964549d31f
4 changed files with 44 additions and 39 deletions
  1. 35 32
      internal/db/init.go
  2. 1 1
      internal/db/pg_test.go
  3. 7 6
      internal/types/app/config.go
  4. 1 0
      internal/types/app/default.go

+ 35 - 32
internal/db/init.go

@@ -11,53 +11,53 @@ import (
 	"gorm.io/gorm"
 )
 
-func initDifyPluginDB(host string, port int, db_name string, user string, pass string, sslmode string) error {
-	// create db if not exists
-	dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", host, port, user, pass, "postgres", sslmode)
+func initDifyPluginDB(host string, port int, db_name string, default_db_name string, user string, pass string, sslmode string) error {
+	// first try to connect to target database
+	dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", host, port, user, pass, db_name, sslmode)
 	db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
 	if err != nil {
-		return err
-	}
-
-	pgsqlDB, err := db.DB()
-	if err != nil {
-		return err
-	}
+		// if connection fails, try to create database
+		dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", host, port, user, pass, default_db_name, sslmode)
+		db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
+		if err != nil {
+			return err
+		}
 
-	// check if the db exists
-	rows, err := pgsqlDB.Query(fmt.Sprintf("SELECT 1 FROM pg_database WHERE datname = '%s'", db_name))
-	if err != nil {
-		return err
-	}
+		pgsqlDB, err := db.DB()
+		if err != nil {
+			return err
+		}
+		defer pgsqlDB.Close()
 
-	if !rows.Next() {
-		// create database
-		_, err = pgsqlDB.Exec(fmt.Sprintf("CREATE DATABASE %s", db_name))
+		// check if the db exists
+		rows, err := pgsqlDB.Query(fmt.Sprintf("SELECT 1 FROM pg_database WHERE datname = '%s'", db_name))
 		if err != nil {
 			return err
 		}
-	}
 
-	// close db
-	err = pgsqlDB.Close()
-	if err != nil {
-		return err
-	}
+		if !rows.Next() {
+			// create database
+			_, err = pgsqlDB.Exec(fmt.Sprintf("CREATE DATABASE %s", db_name))
+			if err != nil {
+				return err
+			}
+		}
 
-	// connect to the new db
-	dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", host, port, user, pass, db_name, sslmode)
-	db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
-	if err != nil {
-		return err
+		// connect to the new db
+		dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", host, port, user, pass, db_name, sslmode)
+		db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
+		if err != nil {
+			return err
+		}
 	}
 
-	pgsqlDB, err = db.DB()
+	pgsqlDB, err := db.DB()
 	if err != nil {
 		return err
 	}
 
 	// check if uuid-ossp extension exists
-	rows, err = pgsqlDB.Query("SELECT 1 FROM pg_extension WHERE extname = 'uuid-ossp'")
+	rows, err := pgsqlDB.Query("SELECT 1 FROM pg_extension WHERE extname = 'uuid-ossp'")
 	if err != nil {
 		return err
 	}
@@ -134,7 +134,10 @@ func Init(config *app.Config) {
 		config.DBHost,
 		int(config.DBPort),
 		config.DBDatabase,
-		config.DBUsername, config.DBPassword, config.DBSslMode,
+		"postgres",
+		config.DBUsername,
+		config.DBPassword,
+		config.DBSslMode,
 	)
 
 	if err != nil {

+ 1 - 1
internal/db/pg_test.go

@@ -8,7 +8,7 @@ import (
 )
 
 func TestTransaction(t *testing.T) {
-	if err := initDifyPluginDB("0.0.0.0", 5432, "testing", "postgres", "difyai123456", "disable"); err != nil {
+	if err := initDifyPluginDB("0.0.0.0", 5432, "testing", "postgres", "postgres", "difyai123456", "disable"); err != nil {
 		t.Fatal(err)
 	}
 	defer Close()

+ 7 - 6
internal/types/app/config.go

@@ -60,12 +60,13 @@ type Config struct {
 	RedisUseSsl bool   `envconfig:"REDIS_USE_SSL"`
 
 	// database
-	DBUsername string `envconfig:"DB_USERNAME" validate:"required"`
-	DBPassword string `envconfig:"DB_PASSWORD" validate:"required"`
-	DBHost     string `envconfig:"DB_HOST" validate:"required"`
-	DBPort     uint16 `envconfig:"DB_PORT" validate:"required"`
-	DBDatabase string `envconfig:"DB_DATABASE" validate:"required"`
-	DBSslMode  string `envconfig:"DB_SSL_MODE" validate:"required,oneof=disable require"`
+	DBUsername        string `envconfig:"DB_USERNAME" validate:"required"`
+	DBPassword        string `envconfig:"DB_PASSWORD" validate:"required"`
+	DBHost            string `envconfig:"DB_HOST" validate:"required"`
+	DBPort            uint16 `envconfig:"DB_PORT" validate:"required"`
+	DBDatabase        string `envconfig:"DB_DATABASE" validate:"required"`
+	DBDefaultDatabase string `envconfig:"DB_DEFAULT_DATABASE" validate:"required"`
+	DBSslMode         string `envconfig:"DB_SSL_MODE" validate:"required,oneof=disable require"`
 
 	// persistence storage
 	PersistenceStoragePath    string `envconfig:"PERSISTENCE_STORAGE_PATH"`

+ 1 - 0
internal/types/app/default.go

@@ -31,6 +31,7 @@ func (config *Config) SetDefault() {
 	setDefaultString(&config.PythonInterpreterPath, "/usr/bin/python3")
 	setDefaultInt(&config.PythonEnvInitTimeout, 120)
 	setDefaultBoolPtr(&config.ForceVerifyingSignature, true)
+	setDefaultString(&config.DBDefaultDatabase, "postgres")
 }
 
 func setDefaultInt[T constraints.Integer](value *T, defaultValue T) {