mysql.go 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. package mysql
import (
	"fmt"
	"strings"
	"time"

	"gorm.io/driver/mysql"
	"gorm.io/gorm"
)
  8. func InitPluginDB(host string, port int, dbName string, defaultDbName string, user string, password string, sslMode string) (*gorm.DB, error) {
  9. initializer := mysqlDbInitializer{
  10. host: host,
  11. port: port,
  12. user: user,
  13. password: password,
  14. sslMode: sslMode,
  15. }
  16. // first try to connect to target database
  17. db, err := initializer.connect(dbName)
  18. if err != nil {
  19. // if connection fails, try to create database
  20. db, err = initializer.connect(defaultDbName)
  21. if err != nil {
  22. return nil, err
  23. }
  24. err = initializer.createDatabaseIfNotExists(db, dbName)
  25. if err != nil {
  26. return nil, err
  27. }
  28. // connect to the new db
  29. db, err = initializer.connect(dbName)
  30. if err != nil {
  31. return nil, err
  32. }
  33. }
  34. pool, err := db.DB()
  35. if err != nil {
  36. return nil, err
  37. }
  38. pool.SetConnMaxIdleTime(time.Minute * 1)
  39. return db, nil
  40. }
  41. // mysqlDbInitializer initializes database for MySQL.
  42. type mysqlDbInitializer struct {
  43. host string
  44. port int
  45. user string
  46. password string
  47. sslMode string
  48. }
  49. func (m *mysqlDbInitializer) connect(dbName string) (*gorm.DB, error) {
  50. dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&tls=%v", m.user, m.password, m.host, m.port, dbName, m.sslMode == "require")
  51. return gorm.Open(myDialector{Dialector: mysql.Open(dsn).(*mysql.Dialector)}, &gorm.Config{})
  52. }
  53. func (m *mysqlDbInitializer) createDatabaseIfNotExists(db *gorm.DB, dbName string) error {
  54. pool, err := db.DB()
  55. if err != nil {
  56. return err
  57. }
  58. defer pool.Close()
  59. rows, err := pool.Query(fmt.Sprintf("SELECT 1 FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = '%s'", dbName))
  60. if err != nil {
  61. return err
  62. }
  63. if !rows.Next() {
  64. // create database
  65. _, err = pool.Exec(fmt.Sprintf("CREATE DATABASE %s", dbName))
  66. }
  67. return err
  68. }