pg.go 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. package pg
  import (
	"fmt"
	"strings"
	"time"

	"gorm.io/driver/postgres"
	"gorm.io/gorm"
)
  8. func InitPluginDB(host string, port int, db_name string, default_db_name string, user string, pass string, sslmode string) (*gorm.DB, error) {
  9. // first try to connect to target database
  10. dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", host, port, user, pass, db_name, sslmode)
  11. db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
  12. if err != nil {
  13. // if connection fails, try to create database
  14. dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", host, port, user, pass, default_db_name, sslmode)
  15. db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
  16. if err != nil {
  17. return nil, err
  18. }
  19. pgsqlDB, err := db.DB()
  20. if err != nil {
  21. return nil, err
  22. }
  23. defer pgsqlDB.Close()
  24. // check if the db exists
  25. rows, err := pgsqlDB.Query(fmt.Sprintf("SELECT 1 FROM pg_database WHERE datname = '%s'", db_name))
  26. if err != nil {
  27. return nil, err
  28. }
  29. if !rows.Next() {
  30. // create database
  31. _, err = pgsqlDB.Exec(fmt.Sprintf("CREATE DATABASE %s", db_name))
  32. if err != nil {
  33. return nil, err
  34. }
  35. }
  36. // connect to the new db
  37. dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", host, port, user, pass, db_name, sslmode)
  38. db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
  39. if err != nil {
  40. return nil, err
  41. }
  42. }
  43. pgsqlDB, err := db.DB()
  44. if err != nil {
  45. return nil, err
  46. }
  47. // check if uuid-ossp extension exists
  48. rows, err := pgsqlDB.Query("SELECT 1 FROM pg_extension WHERE extname = 'uuid-ossp'")
  49. if err != nil {
  50. return nil, err
  51. }
  52. if !rows.Next() {
  53. // create the uuid-ossp extension
  54. _, err = pgsqlDB.Exec("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"")
  55. if err != nil {
  56. return nil, err
  57. }
  58. }
  59. pgsqlDB.SetConnMaxIdleTime(time.Minute * 1)
  60. return db, nil
  61. }