pgsql.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. package db
  2. import (
  3. "fmt"
  4. "gorm.io/gorm"
  5. "gorm.io/gorm/clause"
  6. )
  7. /*
  8. ORM for pgsql
  9. */
  10. var DifyPluginDB *gorm.DB
  11. var (
  12. ErrDatabaseNotFound = gorm.ErrRecordNotFound
  13. )
  14. func Create(data any, ctx ...*gorm.DB) error {
  15. if len(ctx) > 0 {
  16. return ctx[0].Create(data).Error
  17. }
  18. return DifyPluginDB.Create(data).Error
  19. }
  20. func Update(data any, ctx ...*gorm.DB) error {
  21. if len(ctx) > 0 {
  22. return ctx[0].Save(data).Error
  23. }
  24. return DifyPluginDB.Save(data).Error
  25. }
  26. func Delete(data any, ctx ...*gorm.DB) error {
  27. if len(ctx) > 0 {
  28. return ctx[0].Delete(data).Error
  29. }
  30. return DifyPluginDB.Delete(data).Error
  31. }
  32. func DeleteByCondition[T any](condition T, ctx ...*gorm.DB) error {
  33. var model T
  34. if len(ctx) > 0 {
  35. return ctx[0].Where(condition).Delete(&model).Error
  36. }
  37. return DifyPluginDB.Where(condition).Delete(&model).Error
  38. }
  39. func ReplaceAssociation[T any, R any](source *T, field string, associations []R, ctx ...*gorm.DB) error {
  40. if len(ctx) > 0 {
  41. return ctx[0].Model(source).Association(field).Replace(associations)
  42. }
  43. return DifyPluginDB.Model(source).Association(field).Replace(associations)
  44. }
  45. func AppendAssociation[T any, R any](source *T, field string, associations R, ctx ...*gorm.DB) error {
  46. if len(ctx) > 0 {
  47. return ctx[0].Model(source).Association(field).Append(associations)
  48. }
  49. return DifyPluginDB.Model(source).Association(field).Append(associations)
  50. }
  51. type genericComparableConstraint interface {
  52. int | int8 | int16 | int32 | int64 |
  53. uint | uint8 | uint16 | uint32 | uint64 |
  54. float32 | float64 |
  55. bool
  56. }
  57. type genericEqualConstraint interface {
  58. genericComparableConstraint | string
  59. }
  60. type GenericQuery func(tx *gorm.DB) *gorm.DB
  61. func Equal[T genericEqualConstraint](field string, value T) GenericQuery {
  62. return func(tx *gorm.DB) *gorm.DB {
  63. return tx.Where(fmt.Sprintf("%s = ?", field), value)
  64. }
  65. }
  66. func EqualOr[T genericEqualConstraint](field string, value T) GenericQuery {
  67. return func(tx *gorm.DB) *gorm.DB {
  68. return tx.Or(fmt.Sprintf("%s = ?", field), value)
  69. }
  70. }
  71. func NotEqual[T genericEqualConstraint](field string, value T) GenericQuery {
  72. return func(tx *gorm.DB) *gorm.DB {
  73. return tx.Where(fmt.Sprintf("%s <> ?", field), value)
  74. }
  75. }
  76. func GreaterThan[T genericComparableConstraint](field string, value T) GenericQuery {
  77. return func(tx *gorm.DB) *gorm.DB {
  78. return tx.Where(fmt.Sprintf("%s > ?", field), value)
  79. }
  80. }
  81. func GreaterThanOrEqual[T genericComparableConstraint](field string, value T) GenericQuery {
  82. return func(tx *gorm.DB) *gorm.DB {
  83. return tx.Where(fmt.Sprintf("%s >= ?", field), value)
  84. }
  85. }
  86. func LessThan[T genericComparableConstraint](field string, value T) GenericQuery {
  87. return func(tx *gorm.DB) *gorm.DB {
  88. return tx.Where(fmt.Sprintf("%s < ?", field), value)
  89. }
  90. }
  91. func LessThanOrEqual[T genericComparableConstraint](field string, value T) GenericQuery {
  92. return func(tx *gorm.DB) *gorm.DB {
  93. return tx.Where(fmt.Sprintf("%s <= ?", field), value)
  94. }
  95. }
  96. func Like(field string, value string) GenericQuery {
  97. return func(tx *gorm.DB) *gorm.DB {
  98. return tx.Where(fmt.Sprintf("%s LIKE ?", field), "%"+value+"%")
  99. }
  100. }
  101. func Page(page int, pageSize int) GenericQuery {
  102. return func(tx *gorm.DB) *gorm.DB {
  103. return tx.Offset((page - 1) * pageSize).Limit(pageSize)
  104. }
  105. }
  106. func OrderBy(field string, desc bool) GenericQuery {
  107. return func(tx *gorm.DB) *gorm.DB {
  108. if desc {
  109. return tx.Order(fmt.Sprintf("%s DESC", field))
  110. }
  111. return tx.Order(field)
  112. }
  113. }
  114. // bitwise operation
  115. func WithBit[T genericComparableConstraint](field string, value T) GenericQuery {
  116. return func(tx *gorm.DB) *gorm.DB {
  117. return tx.Where(fmt.Sprintf("%s & ? = ?", field), value, value)
  118. }
  119. }
  120. func WithoutBit[T genericComparableConstraint](field string, value T) GenericQuery {
  121. return func(tx *gorm.DB) *gorm.DB {
  122. return tx.Where(fmt.Sprintf("%s & ~? != 0", field), value)
  123. }
  124. }
  125. func Inc[T genericComparableConstraint](field string, value T) GenericQuery {
  126. return func(tx *gorm.DB) *gorm.DB {
  127. return tx.UpdateColumn(field, gorm.Expr(fmt.Sprintf("%s + ?", field), value))
  128. }
  129. }
  130. func Dec[T genericComparableConstraint](field string, value T) GenericQuery {
  131. return func(tx *gorm.DB) *gorm.DB {
  132. return tx.UpdateColumn(field, gorm.Expr(fmt.Sprintf("%s - ?", field), value))
  133. }
  134. }
  135. func Model(model any) GenericQuery {
  136. return func(tx *gorm.DB) *gorm.DB {
  137. return tx.Model(model)
  138. }
  139. }
  140. func Fields(fields ...string) GenericQuery {
  141. return func(tx *gorm.DB) *gorm.DB {
  142. return tx.Select(fields)
  143. }
  144. }
  145. func Preload(model string, args ...interface{}) GenericQuery {
  146. return func(tx *gorm.DB) *gorm.DB {
  147. return tx.Preload(model, args...)
  148. }
  149. }
  150. func Join(field string) GenericQuery {
  151. return func(tx *gorm.DB) *gorm.DB {
  152. return tx.Joins(field)
  153. }
  154. }
  155. func WLock /* write lock */ () GenericQuery {
  156. return func(tx *gorm.DB) *gorm.DB {
  157. return tx.Clauses(clause.Locking{Strength: "UPDATE"})
  158. }
  159. }
  160. func Where[T any](model *T) GenericQuery {
  161. return func(tx *gorm.DB) *gorm.DB {
  162. return tx.Where(model)
  163. }
  164. }
  165. func WhereSQL(sql string, args ...interface{}) GenericQuery {
  166. return func(tx *gorm.DB) *gorm.DB {
  167. return tx.Where(sql, args...)
  168. }
  169. }
  170. func Action(fn func(tx *gorm.DB)) GenericQuery {
  171. return func(tx *gorm.DB) *gorm.DB {
  172. fn(tx)
  173. return tx
  174. }
  175. }
  176. /*
  177. Should be used first in query chain
  178. */
  179. func WithTransactionContext(tx *gorm.DB) GenericQuery {
  180. return func(_ *gorm.DB) *gorm.DB {
  181. return tx
  182. }
  183. }
  184. func InArray(field string, value []interface{}) GenericQuery {
  185. return func(tx *gorm.DB) *gorm.DB {
  186. return tx.Where(fmt.Sprintf("%s IN ?", field), value)
  187. }
  188. }
  189. func Run(query ...GenericQuery) error {
  190. tmp := DifyPluginDB
  191. for _, q := range query {
  192. tmp = q(tmp)
  193. }
  194. // execute query
  195. return tmp.Error
  196. }
  197. func GetAny[T any](sql string, data ...interface{}) (T /* data */, error) {
  198. var result T
  199. err := DifyPluginDB.Raw(sql, data...).Scan(&result).Error
  200. return result, err
  201. }
  202. func GetOne[T any](query ...GenericQuery) (T /* data */, error) {
  203. var data T
  204. tmp := DifyPluginDB
  205. for _, q := range query {
  206. tmp = q(tmp)
  207. }
  208. err := tmp.First(&data).Error
  209. return data, err
  210. }
  211. func GetAll[T any](query ...GenericQuery) ([]T /* data */, error) {
  212. var data []T
  213. tmp := DifyPluginDB
  214. for _, q := range query {
  215. tmp = q(tmp)
  216. }
  217. err := tmp.Find(&data).Error
  218. return data, err
  219. }
  220. func GetCount[T any](query ...GenericQuery) (int64 /* count */, error) {
  221. var model T
  222. var count int64
  223. tmp := DifyPluginDB
  224. for _, q := range query {
  225. tmp = q(tmp)
  226. }
  227. err := tmp.Model(&model).Count(&count).Error
  228. return count, err
  229. }
  230. func GetSum[T any, R genericComparableConstraint](fields string, query ...GenericQuery) (R, error) {
  231. var model T
  232. var sum R
  233. tmp := DifyPluginDB
  234. for _, q := range query {
  235. tmp = q(tmp)
  236. }
  237. err := tmp.Model(&model).Select(fmt.Sprintf("SUM(%s)", fields)).Scan(&sum).Error
  238. return sum, err
  239. }
  240. func DelAssociation[T any](field string, query ...GenericQuery) error {
  241. var model T
  242. tmp := DifyPluginDB.Model(&model)
  243. for _, q := range query {
  244. tmp = q(tmp)
  245. }
  246. return tmp.Association(field).Unscoped().Clear()
  247. }
  248. func WithTransaction(fn func(tx *gorm.DB) error, ctx ...*gorm.DB) error {
  249. // Start a transaction
  250. db := DifyPluginDB
  251. if len(ctx) > 0 {
  252. db = ctx[0]
  253. }
  254. tx := db.Begin()
  255. if tx.Error != nil {
  256. return tx.Error
  257. }
  258. err := fn(tx)
  259. if err != nil {
  260. tx.Rollback()
  261. return err
  262. }
  263. tx.Commit()
  264. return nil
  265. }
  266. // NOTE: not used in production, only for testing
  267. func DropTable(model any) error {
  268. return DifyPluginDB.Migrator().DropTable(model)
  269. }
  270. // NOTE: not used in production, only for testing
  271. func CreateDatabase(dbname string) error {
  272. return DifyPluginDB.Exec(fmt.Sprintf("CREATE DATABASE %s", dbname)).Error
  273. }
  274. // NOTE: not used in production, only for testing
  275. func CreateTable(model any) error {
  276. return DifyPluginDB.Migrator().CreateTable(model)
  277. }