pgsql.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. package db
  2. import (
  3. "fmt"
  4. "gorm.io/gorm"
  5. "gorm.io/gorm/clause"
  6. )
/*
Thin helper layer over GORM for the PostgreSQL (pgsql) database.
*/
  10. var DifyPluginDB *gorm.DB
  11. var (
  12. ErrDatabaseNotFound = gorm.ErrRecordNotFound
  13. )
  14. func Create(data any, ctx ...*gorm.DB) error {
  15. if len(ctx) > 0 {
  16. return ctx[0].Create(data).Error
  17. }
  18. return DifyPluginDB.Create(data).Error
  19. }
  20. func Update(data any, ctx ...*gorm.DB) error {
  21. if len(ctx) > 0 {
  22. return ctx[0].Save(data).Error
  23. }
  24. return DifyPluginDB.Save(data).Error
  25. }
  26. func Delete(data any, ctx ...*gorm.DB) error {
  27. if len(ctx) > 0 {
  28. return ctx[0].Delete(data).Error
  29. }
  30. return DifyPluginDB.Delete(data).Error
  31. }
  32. func ReplaceAssociation[T any, R any](source *T, field string, associations []R, ctx ...*gorm.DB) error {
  33. if len(ctx) > 0 {
  34. return ctx[0].Model(source).Association(field).Replace(associations)
  35. }
  36. return DifyPluginDB.Model(source).Association(field).Replace(associations)
  37. }
  38. func AppendAssociation[T any, R any](source *T, field string, associations R, ctx ...*gorm.DB) error {
  39. if len(ctx) > 0 {
  40. return ctx[0].Model(source).Association(field).Append(associations)
  41. }
  42. return DifyPluginDB.Model(source).Association(field).Append(associations)
  43. }
  44. type genericComparableConstraint interface {
  45. int | int8 | int16 | int32 | int64 |
  46. uint | uint8 | uint16 | uint32 | uint64 |
  47. float32 | float64 |
  48. bool
  49. }
  50. type genericEqualConstraint interface {
  51. genericComparableConstraint | string
  52. }
  53. type GenericQuery func(tx *gorm.DB) *gorm.DB
  54. func Equal[T genericEqualConstraint](field string, value T) GenericQuery {
  55. return func(tx *gorm.DB) *gorm.DB {
  56. return tx.Where(fmt.Sprintf("%s = ?", field), value)
  57. }
  58. }
  59. func EqualOr[T genericEqualConstraint](field string, value T) GenericQuery {
  60. return func(tx *gorm.DB) *gorm.DB {
  61. return tx.Or(fmt.Sprintf("%s = ?", field), value)
  62. }
  63. }
  64. func NotEqual[T genericEqualConstraint](field string, value T) GenericQuery {
  65. return func(tx *gorm.DB) *gorm.DB {
  66. return tx.Where(fmt.Sprintf("%s <> ?", field), value)
  67. }
  68. }
  69. func GreaterThan[T genericComparableConstraint](field string, value T) GenericQuery {
  70. return func(tx *gorm.DB) *gorm.DB {
  71. return tx.Where(fmt.Sprintf("%s > ?", field), value)
  72. }
  73. }
  74. func GreaterThanOrEqual[T genericComparableConstraint](field string, value T) GenericQuery {
  75. return func(tx *gorm.DB) *gorm.DB {
  76. return tx.Where(fmt.Sprintf("%s >= ?", field), value)
  77. }
  78. }
  79. func LessThan[T genericComparableConstraint](field string, value T) GenericQuery {
  80. return func(tx *gorm.DB) *gorm.DB {
  81. return tx.Where(fmt.Sprintf("%s < ?", field), value)
  82. }
  83. }
  84. func LessThanOrEqual[T genericComparableConstraint](field string, value T) GenericQuery {
  85. return func(tx *gorm.DB) *gorm.DB {
  86. return tx.Where(fmt.Sprintf("%s <= ?", field), value)
  87. }
  88. }
  89. func Like(field string, value string) GenericQuery {
  90. return func(tx *gorm.DB) *gorm.DB {
  91. return tx.Where(fmt.Sprintf("%s LIKE ?", field), "%"+value+"%")
  92. }
  93. }
  94. func Page(page int, pageSize int) GenericQuery {
  95. return func(tx *gorm.DB) *gorm.DB {
  96. return tx.Offset((page - 1) * pageSize).Limit(pageSize)
  97. }
  98. }
  99. func OrderBy(field string, desc bool) GenericQuery {
  100. return func(tx *gorm.DB) *gorm.DB {
  101. if desc {
  102. return tx.Order(fmt.Sprintf("%s DESC", field))
  103. }
  104. return tx.Order(field)
  105. }
  106. }
  107. // bitwise operation
  108. func WithBit[T genericComparableConstraint](field string, value T) GenericQuery {
  109. return func(tx *gorm.DB) *gorm.DB {
  110. return tx.Where(fmt.Sprintf("%s & ? = ?", field), value, value)
  111. }
  112. }
  113. func WithoutBit[T genericComparableConstraint](field string, value T) GenericQuery {
  114. return func(tx *gorm.DB) *gorm.DB {
  115. return tx.Where(fmt.Sprintf("%s & ~? != 0", field), value)
  116. }
  117. }
  118. func Inc[T genericComparableConstraint](field string, value T) GenericQuery {
  119. return func(tx *gorm.DB) *gorm.DB {
  120. return tx.UpdateColumn(field, gorm.Expr(fmt.Sprintf("%s + ?", field), value))
  121. }
  122. }
  123. func Dec[T genericComparableConstraint](field string, value T) GenericQuery {
  124. return func(tx *gorm.DB) *gorm.DB {
  125. return tx.UpdateColumn(field, gorm.Expr(fmt.Sprintf("%s - ?", field), value))
  126. }
  127. }
  128. func Model(model any) GenericQuery {
  129. return func(tx *gorm.DB) *gorm.DB {
  130. return tx.Model(model)
  131. }
  132. }
  133. func Fields(fields ...string) GenericQuery {
  134. return func(tx *gorm.DB) *gorm.DB {
  135. return tx.Select(fields)
  136. }
  137. }
  138. func Preload(model string, args ...interface{}) GenericQuery {
  139. return func(tx *gorm.DB) *gorm.DB {
  140. return tx.Preload(model, args...)
  141. }
  142. }
  143. func Join(field string) GenericQuery {
  144. return func(tx *gorm.DB) *gorm.DB {
  145. return tx.Joins(field)
  146. }
  147. }
  148. func WLock /* write lock */ () GenericQuery {
  149. return func(tx *gorm.DB) *gorm.DB {
  150. return tx.Clauses(clause.Locking{Strength: "UPDATE"})
  151. }
  152. }
  153. func Where[T any](model *T) GenericQuery {
  154. return func(tx *gorm.DB) *gorm.DB {
  155. return tx.Where(model)
  156. }
  157. }
  158. func WhereSQL(sql string, args ...interface{}) GenericQuery {
  159. return func(tx *gorm.DB) *gorm.DB {
  160. return tx.Where(sql, args...)
  161. }
  162. }
  163. func Action(fn func(tx *gorm.DB)) GenericQuery {
  164. return func(tx *gorm.DB) *gorm.DB {
  165. fn(tx)
  166. return tx
  167. }
  168. }
  169. /*
  170. Should be used first in query chain
  171. */
  172. func WithTransactionContext(tx *gorm.DB) GenericQuery {
  173. return func(_ *gorm.DB) *gorm.DB {
  174. return tx
  175. }
  176. }
  177. func InArray(field string, value []interface{}) GenericQuery {
  178. return func(tx *gorm.DB) *gorm.DB {
  179. return tx.Where(fmt.Sprintf("%s IN ?", field), value)
  180. }
  181. }
  182. func Run(query ...GenericQuery) error {
  183. tmp := DifyPluginDB
  184. for _, q := range query {
  185. tmp = q(tmp)
  186. }
  187. // execute query
  188. return tmp.Error
  189. }
  190. func GetAny[T any](sql string, data ...interface{}) (T /* data */, error) {
  191. var result T
  192. err := DifyPluginDB.Raw(sql, data...).Scan(&result).Error
  193. return result, err
  194. }
  195. func GetOne[T any](query ...GenericQuery) (T /* data */, error) {
  196. var data T
  197. tmp := DifyPluginDB
  198. for _, q := range query {
  199. tmp = q(tmp)
  200. }
  201. err := tmp.First(&data).Error
  202. return data, err
  203. }
  204. func GetAll[T any](query ...GenericQuery) ([]T /* data */, error) {
  205. var data []T
  206. tmp := DifyPluginDB
  207. for _, q := range query {
  208. tmp = q(tmp)
  209. }
  210. err := tmp.Find(&data).Error
  211. return data, err
  212. }
  213. func GetCount[T any](query ...GenericQuery) (int64 /* count */, error) {
  214. var model T
  215. var count int64
  216. tmp := DifyPluginDB
  217. for _, q := range query {
  218. tmp = q(tmp)
  219. }
  220. err := tmp.Model(&model).Count(&count).Error
  221. return count, err
  222. }
  223. func GetSum[T any, R genericComparableConstraint](fields string, query ...GenericQuery) (R, error) {
  224. var model T
  225. var sum R
  226. tmp := DifyPluginDB
  227. for _, q := range query {
  228. tmp = q(tmp)
  229. }
  230. err := tmp.Model(&model).Select(fmt.Sprintf("SUM(%s)", fields)).Scan(&sum).Error
  231. return sum, err
  232. }
  233. func DelAssociation[T any](field string, query ...GenericQuery) error {
  234. var model T
  235. tmp := DifyPluginDB.Model(&model)
  236. for _, q := range query {
  237. tmp = q(tmp)
  238. }
  239. return tmp.Association(field).Unscoped().Clear()
  240. }
  241. func WithTransaction(fn func(tx *gorm.DB) error, ctx ...*gorm.DB) error {
  242. // Start a transaction
  243. db := DifyPluginDB
  244. if len(ctx) > 0 {
  245. db = ctx[0]
  246. }
  247. tx := db.Begin()
  248. if tx.Error != nil {
  249. return tx.Error
  250. }
  251. err := fn(tx)
  252. if err != nil {
  253. tx.Rollback()
  254. return err
  255. }
  256. tx.Commit()
  257. return nil
  258. }
  259. // NOTE: not used in production, only for testing
  260. func DropTable(model any) error {
  261. return DifyPluginDB.Migrator().DropTable(model)
  262. }
  263. // NOTE: not used in production, only for testing
  264. func CreateDatabase(dbname string) error {
  265. return DifyPluginDB.Exec(fmt.Sprintf("CREATE DATABASE %s", dbname)).Error
  266. }
  267. // NOTE: not used in production, only for testing
  268. func CreateTable(model any) error {
  269. return DifyPluginDB.Migrator().CreateTable(model)
  270. }