|
@@ -11,53 +11,53 @@ import (
|
|
|
"gorm.io/gorm"
|
|
|
)
|
|
|
|
|
|
-func initDifyPluginDB(host string, port int, db_name string, user string, pass string, sslmode string) error {
|
|
|
- // create db if not exists
|
|
|
- dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", host, port, user, pass, "postgres", sslmode)
|
|
|
+func initDifyPluginDB(host string, port int, db_name string, default_db_name string, user string, pass string, sslmode string) error {
|
|
|
+ // first try to connect to target database
|
|
|
+ dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", host, port, user, pass, db_name, sslmode)
|
|
|
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
|
|
if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
-
|
|
|
- pgsqlDB, err := db.DB()
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
+ // if connection fails, try to create database
|
|
|
+ dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", host, port, user, pass, default_db_name, sslmode)
|
|
|
+ db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
|
|
|
- // check if the db exists
|
|
|
- rows, err := pgsqlDB.Query(fmt.Sprintf("SELECT 1 FROM pg_database WHERE datname = '%s'", db_name))
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
+ pgsqlDB, err := db.DB()
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ defer pgsqlDB.Close()
|
|
|
|
|
|
- if !rows.Next() {
|
|
|
- // create database
|
|
|
- _, err = pgsqlDB.Exec(fmt.Sprintf("CREATE DATABASE %s", db_name))
|
|
|
+ // check if the db exists
|
|
|
+ rows, err := pgsqlDB.Query(fmt.Sprintf("SELECT 1 FROM pg_database WHERE datname = '%s'", db_name))
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
- }
|
|
|
|
|
|
- // close db
|
|
|
- err = pgsqlDB.Close()
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
+ if !rows.Next() {
|
|
|
+ // create database
|
|
|
+ _, err = pgsqlDB.Exec(fmt.Sprintf("CREATE DATABASE %s", db_name))
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
- // connect to the new db
|
|
|
- dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", host, port, user, pass, db_name, sslmode)
|
|
|
- db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
+ // connect to the new db
|
|
|
+ dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", host, port, user, pass, db_name, sslmode)
|
|
|
+ db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
- pgsqlDB, err = db.DB()
|
|
|
+ pgsqlDB, err := db.DB()
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
// check if uuid-ossp extension exists
|
|
|
- rows, err = pgsqlDB.Query("SELECT 1 FROM pg_extension WHERE extname = 'uuid-ossp'")
|
|
|
+ rows, err := pgsqlDB.Query("SELECT 1 FROM pg_extension WHERE extname = 'uuid-ossp'")
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
@@ -134,7 +134,10 @@ func Init(config *app.Config) {
|
|
|
config.DBHost,
|
|
|
int(config.DBPort),
|
|
|
config.DBDatabase,
|
|
|
- config.DBUsername, config.DBPassword, config.DBSslMode,
|
|
|
+ "postgres",
|
|
|
+ config.DBUsername,
|
|
|
+ config.DBPassword,
|
|
|
+ config.DBSslMode,
|
|
|
)
|
|
|
|
|
|
if err != nil {
|