Bladeren bron

support to use mysql as plugin db

He Wang 7 maanden geleden
bovenliggende
commit
d4aefed92c

+ 2 - 0
go.mod

@@ -22,6 +22,7 @@ require (
 	github.com/tencentyun/cos-go-sdk-v5 v0.7.62
 	github.com/vmihailenco/msgpack/v5 v5.4.1
 	github.com/xeipuuv/gojsonschema v1.2.0
+	gorm.io/driver/mysql v1.5.7
 	gorm.io/gorm v1.25.11
 )
 
@@ -50,6 +51,7 @@ require (
 	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
 	github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
 	github.com/fsnotify/fsnotify v1.7.0 // indirect
+	github.com/go-sql-driver/mysql v1.7.0 // indirect
 	github.com/google/go-querystring v1.1.0 // indirect
 	github.com/hashicorp/hcl v1.0.0 // indirect
 	github.com/inconshreveable/mousetrap v1.1.0 // indirect

+ 5 - 0
go.sum

@@ -109,6 +109,8 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
 github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
 github.com/go-playground/validator/v10 v10.22.0 h1:k6HsTZ0sTnROkhS//R0O+55JgM8C4Bx7ia+JlgcnOao=
 github.com/go-playground/validator/v10 v10.22.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
+github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
+github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
 github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA=
 github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
 github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
@@ -338,8 +340,11 @@ gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRN
 gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
 gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
 gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo=
+gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM=
 gorm.io/driver/postgres v1.5.9 h1:DkegyItji119OlcaLjqN11kHoUgZ/j13E0jkJZgD6A8=
 gorm.io/driver/postgres v1.5.9/go.mod h1:DX3GReXH+3FPWGrrgffdvCk3DQ1dwDPdmbenSkweRGI=
+gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
 gorm.io/gorm v1.25.11 h1:/Wfyg1B/je1hnDx3sMkX+gAlxrlZpn6X0BXRlwXlvHg=
 gorm.io/gorm v1.25.11/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
 nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=

internal/db/pgsql.go → internal/db/executor.go


+ 4 - 2
internal/db/pg_test.go

@@ -4,11 +4,13 @@ import (
 	"errors"
 	"testing"
 
+	"github.com/langgenius/dify-plugin-daemon/internal/db/pg"
 	"gorm.io/gorm"
 )
 
 func TestTransaction(t *testing.T) {
-	if err := initDifyPluginDB("0.0.0.0", 5432, "testing", "postgres", "postgres", "difyai123456", "disable"); err != nil {
+	var err error
+	if DifyPluginDB, err = pg.InitPluginDB("0.0.0.0", 5432, "testing", "postgres", "postgres", "difyai123456", "disable"); err != nil {
 		t.Fatal(err)
 	}
 	defer Close()
@@ -17,7 +19,7 @@ func TestTransaction(t *testing.T) {
 		gorm.Model
 	}
 
-	err := CreateTable(&TestTable{})
+	err = CreateTable(&TestTable{})
 	if err != nil {
 		t.Fatal(err)
 	}

+ 26 - 79
internal/db/init.go

@@ -1,81 +1,13 @@
 package db
 
 import (
-	"fmt"
-	"time"
-
+	"github.com/langgenius/dify-plugin-daemon/internal/db/mysql"
+	"github.com/langgenius/dify-plugin-daemon/internal/db/pg"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/app"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/models"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
-	"gorm.io/driver/postgres"
-	"gorm.io/gorm"
 )
 
-func initDifyPluginDB(host string, port int, db_name string, default_db_name string, user string, pass string, sslmode string) error {
-	// first try to connect to target database
-	dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", host, port, user, pass, db_name, sslmode)
-	db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
-	if err != nil {
-		// if connection fails, try to create database
-		dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", host, port, user, pass, default_db_name, sslmode)
-		db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
-		if err != nil {
-			return err
-		}
-
-		pgsqlDB, err := db.DB()
-		if err != nil {
-			return err
-		}
-		defer pgsqlDB.Close()
-
-		// check if the db exists
-		rows, err := pgsqlDB.Query(fmt.Sprintf("SELECT 1 FROM pg_database WHERE datname = '%s'", db_name))
-		if err != nil {
-			return err
-		}
-
-		if !rows.Next() {
-			// create database
-			_, err = pgsqlDB.Exec(fmt.Sprintf("CREATE DATABASE %s", db_name))
-			if err != nil {
-				return err
-			}
-		}
-
-		// connect to the new db
-		dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", host, port, user, pass, db_name, sslmode)
-		db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
-		if err != nil {
-			return err
-		}
-	}
-
-	pgsqlDB, err := db.DB()
-	if err != nil {
-		return err
-	}
-
-	// check if uuid-ossp extension exists
-	rows, err := pgsqlDB.Query("SELECT 1 FROM pg_extension WHERE extname = 'uuid-ossp'")
-	if err != nil {
-		return err
-	}
-
-	if !rows.Next() {
-		// create the uuid-ossp extension
-		_, err = pgsqlDB.Exec("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"")
-		if err != nil {
-			return err
-		}
-	}
-
-	pgsqlDB.SetConnMaxIdleTime(time.Minute * 1)
-	DifyPluginDB = db
-
-	return nil
-}
-
 func autoMigrate() error {
 	err := DifyPluginDB.AutoMigrate(
 		models.Plugin{},
@@ -130,15 +62,30 @@ func autoMigrate() error {
 }
 
 func Init(config *app.Config) {
-	err := initDifyPluginDB(
-		config.DBHost,
-		int(config.DBPort),
-		config.DBDatabase,
-		config.DBDefaultDatabase,
-		config.DBUsername,
-		config.DBPassword,
-		config.DBSslMode,
-	)
+	var err error
+	if config.DBType == "postgresql" {
+		DifyPluginDB, err = pg.InitPluginDB(
+			config.DBHost,
+			int(config.DBPort),
+			config.DBDatabase,
+			config.DBDefaultDatabase,
+			config.DBUsername,
+			config.DBPassword,
+			config.DBSslMode,
+		)
+	} else if config.DBType == "mysql" {
+		DifyPluginDB, err = mysql.InitPluginDB(
+			config.DBHost,
+			int(config.DBPort),
+			config.DBDatabase,
+			config.DBDefaultDatabase,
+			config.DBUsername,
+			config.DBPassword,
+			config.DBSslMode,
+		)
+	} else {
+		log.Panic("unsupported database type: %v", config.DBType)
+	}
 
 	if err != nil {
 		log.Panic("failed to init dify plugin db: %v", err)

+ 45 - 0
internal/db/mysql/dialector.go

@@ -0,0 +1,45 @@
+package mysql
+
+import (
+	"gorm.io/driver/mysql"
+	"gorm.io/gorm"
+	"gorm.io/gorm/clause"
+	"gorm.io/gorm/schema"
+)
+
+type myDialector struct {
+	*mysql.Dialector
+}
+
+func (dialector myDialector) Migrator(db *gorm.DB) gorm.Migrator {
+	return myMigrator{dialector.Dialector.Migrator(db).(mysql.Migrator)}
+}
+
+func (dialector myDialector) DataTypeOf(field *schema.Field) string {
+	dataType := dialector.Dialector.DataTypeOf(field)
+	switch dataType {
+	case "uuid":
+		return "char(36)"
+	case "text":
+		return "longtext"
+	default:
+		return dataType
+	}
+}
+
+type myMigrator struct {
+	mysql.Migrator
+}
+
+func (migrator myMigrator) FullDataTypeOf(field *schema.Field) clause.Expr {
+	if field.DataType == "uuid" {
+		field.DataType = "char(36)"
+		if field.HasDefaultValue && field.DefaultValue == "uuid_generate_v4()" {
+			field.HasDefaultValue = false
+			field.DefaultValue = ""
+		}
+	} else if field.DataType == "text" {
+		field.DataType = "longtext"
+	}
+	return migrator.Migrator.FullDataTypeOf(field)
+}

+ 82 - 0
internal/db/mysql/mysql.go

@@ -0,0 +1,82 @@
+package mysql
+
+import (
+	"fmt"
+	"time"
+
+	"gorm.io/driver/mysql"
+	"gorm.io/gorm"
+)
+
+func InitPluginDB(host string, port int, dbName string, defaultDbName string, user string, password string, sslMode string) (*gorm.DB, error) {
+	initializer := mysqlDbInitializer{
+		host:     host,
+		port:     port,
+		user:     user,
+		password: password,
+		sslMode:  sslMode,
+	}
+
+	// first try to connect to target database
+	db, err := initializer.connect(dbName)
+	if err != nil {
+		// if connection fails, try to create database
+		db, err = initializer.connect(defaultDbName)
+		if err != nil {
+			return nil, err
+		}
+
+		err = initializer.createDatabaseIfNotExists(db, dbName)
+		if err != nil {
+			return nil, err
+		}
+
+		// connect to the new db
+		db, err = initializer.connect(dbName)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	pool, err := db.DB()
+	if err != nil {
+		return nil, err
+	}
+
+	pool.SetConnMaxIdleTime(time.Minute * 1)
+
+	return db, nil
+}
+
+// mysqlDbInitializer initializes database for MySQL.
+type mysqlDbInitializer struct {
+	host     string
+	port     int
+	user     string
+	password string
+	sslMode  string
+}
+
+func (m *mysqlDbInitializer) connect(dbName string) (*gorm.DB, error) {
+	dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&tls=%v", m.user, m.password, m.host, m.port, dbName, m.sslMode == "require")
+	return gorm.Open(myDialector{Dialector: mysql.Open(dsn).(*mysql.Dialector)}, &gorm.Config{})
+}
+
+func (m *mysqlDbInitializer) createDatabaseIfNotExists(db *gorm.DB, dbName string) error {
+	pool, err := db.DB()
+	if err != nil {
+		return err
+	}
+	defer pool.Close()
+
+	rows, err := pool.Query(fmt.Sprintf("SELECT 1 FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = '%s'", dbName))
+	if err != nil {
+		return err
+	}
+
+	if !rows.Next() {
+		// create database
+		_, err = pool.Exec(fmt.Sprintf("CREATE DATABASE %s", dbName))
+	}
+	return err
+}

+ 73 - 0
internal/db/pg/pg.go

@@ -0,0 +1,73 @@
+package pg
+
+import (
+	"fmt"
+	"time"
+
+	"gorm.io/driver/postgres"
+	"gorm.io/gorm"
+)
+
+func InitPluginDB(host string, port int, db_name string, default_db_name string, user string, pass string, sslmode string) (*gorm.DB, error) {
+	// first try to connect to target database
+	dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", host, port, user, pass, db_name, sslmode)
+	db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
+	if err != nil {
+		// if connection fails, try to create database
+		dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", host, port, user, pass, default_db_name, sslmode)
+		db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
+		if err != nil {
+			return nil, err
+		}
+
+		pgsqlDB, err := db.DB()
+		if err != nil {
+			return nil, err
+		}
+		defer pgsqlDB.Close()
+
+		// check if the db exists
+		rows, err := pgsqlDB.Query(fmt.Sprintf("SELECT 1 FROM pg_database WHERE datname = '%s'", db_name))
+		if err != nil {
+			return nil, err
+		}
+
+		if !rows.Next() {
+			// create database
+			_, err = pgsqlDB.Exec(fmt.Sprintf("CREATE DATABASE %s", db_name))
+			if err != nil {
+				return nil, err
+			}
+		}
+
+		// connect to the new db
+		dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", host, port, user, pass, db_name, sslmode)
+		db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	pgsqlDB, err := db.DB()
+	if err != nil {
+		return nil, err
+	}
+
+	// check if uuid-ossp extension exists
+	rows, err := pgsqlDB.Query("SELECT 1 FROM pg_extension WHERE extname = 'uuid-ossp'")
+	if err != nil {
+		return nil, err
+	}
+
+	if !rows.Next() {
+		// create the uuid-ossp extension
+		_, err = pgsqlDB.Exec("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"")
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	pgsqlDB.SetConnMaxIdleTime(time.Minute * 1)
+
+	return db, nil
+}

+ 1 - 0
internal/types/app/config.go

@@ -66,6 +66,7 @@ type Config struct {
 	RedisUseSsl bool   `envconfig:"REDIS_USE_SSL"`
 
 	// database
+	DBType            string `envconfig:"DB_TYPE" default:"postgresql"`
 	DBUsername        string `envconfig:"DB_USERNAME" validate:"required"`
 	DBPassword        string `envconfig:"DB_PASSWORD" validate:"required"`
 	DBHost            string `envconfig:"DB_HOST" validate:"required"`

+ 10 - 0
internal/types/models/base.go

@@ -2,6 +2,9 @@ package models
 
 import (
 	"time"
+
+	"github.com/google/uuid"
+	"gorm.io/gorm"
 )
 
 type Model struct {
@@ -9,3 +12,10 @@ type Model struct {
 	CreatedAt time.Time `json:"created_at"`
 	UpdatedAt time.Time `json:"updated_at"`
 }
+
+func (m *Model) BeforeCreate(tx *gorm.DB) (err error) {
+	if tx.Dialector.Name() == "mysql" && m.ID == "" {
+		m.ID = uuid.New().String()
+	}
+	return
+}