pgsql.go 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
  1. package db
  2. import (
  3. "fmt"
  4. "github.com/langgenius/dify-plugin-daemon/internal/utils/log"
  5. "gorm.io/gorm"
  6. "gorm.io/gorm/clause"
  7. )
  8. /*
  9. ORM for pgsql
  10. */
  11. var DifyPluginDB *gorm.DB
  12. var (
  13. ErrDatabaseNotFound = gorm.ErrRecordNotFound
  14. )
  15. func Create(data any, ctx ...*gorm.DB) error {
  16. if len(ctx) > 0 {
  17. return ctx[0].Create(data).Error
  18. }
  19. return DifyPluginDB.Create(data).Error
  20. }
  21. func Update(data any, ctx ...*gorm.DB) error {
  22. if len(ctx) > 0 {
  23. return ctx[0].Save(data).Error
  24. }
  25. return DifyPluginDB.Save(data).Error
  26. }
  27. func Delete(data any, ctx ...*gorm.DB) error {
  28. if len(ctx) > 0 {
  29. return ctx[0].Delete(data).Error
  30. }
  31. return DifyPluginDB.Delete(data).Error
  32. }
  33. func DeleteByCondition[T any](condition T, ctx ...*gorm.DB) error {
  34. var model T
  35. if len(ctx) > 0 {
  36. return ctx[0].Where(condition).Delete(&model).Error
  37. }
  38. return DifyPluginDB.Where(condition).Delete(&model).Error
  39. }
  40. func ReplaceAssociation[T any, R any](source *T, field string, associations []R, ctx ...*gorm.DB) error {
  41. if len(ctx) > 0 {
  42. return ctx[0].Model(source).Association(field).Replace(associations)
  43. }
  44. return DifyPluginDB.Model(source).Association(field).Replace(associations)
  45. }
  46. func AppendAssociation[T any, R any](source *T, field string, associations R, ctx ...*gorm.DB) error {
  47. if len(ctx) > 0 {
  48. return ctx[0].Model(source).Association(field).Append(associations)
  49. }
  50. return DifyPluginDB.Model(source).Association(field).Append(associations)
  51. }
  52. type genericComparableConstraint interface {
  53. int | int8 | int16 | int32 | int64 |
  54. uint | uint8 | uint16 | uint32 | uint64 |
  55. float32 | float64 |
  56. bool
  57. }
  58. type genericEqualConstraint interface {
  59. genericComparableConstraint | string
  60. }
  61. type GenericQuery func(tx *gorm.DB) *gorm.DB
  62. func Equal[T genericEqualConstraint](field string, value T) GenericQuery {
  63. return func(tx *gorm.DB) *gorm.DB {
  64. return tx.Where(fmt.Sprintf("%s = ?", field), value)
  65. }
  66. }
  67. func EqualOr[T genericEqualConstraint](field string, value T) GenericQuery {
  68. return func(tx *gorm.DB) *gorm.DB {
  69. return tx.Or(fmt.Sprintf("%s = ?", field), value)
  70. }
  71. }
  72. func NotEqual[T genericEqualConstraint](field string, value T) GenericQuery {
  73. return func(tx *gorm.DB) *gorm.DB {
  74. return tx.Where(fmt.Sprintf("%s <> ?", field), value)
  75. }
  76. }
  77. func GreaterThan[T genericComparableConstraint](field string, value T) GenericQuery {
  78. return func(tx *gorm.DB) *gorm.DB {
  79. return tx.Where(fmt.Sprintf("%s > ?", field), value)
  80. }
  81. }
  82. func GreaterThanOrEqual[T genericComparableConstraint](field string, value T) GenericQuery {
  83. return func(tx *gorm.DB) *gorm.DB {
  84. return tx.Where(fmt.Sprintf("%s >= ?", field), value)
  85. }
  86. }
  87. func LessThan[T genericComparableConstraint](field string, value T) GenericQuery {
  88. return func(tx *gorm.DB) *gorm.DB {
  89. return tx.Where(fmt.Sprintf("%s < ?", field), value)
  90. }
  91. }
  92. func LessThanOrEqual[T genericComparableConstraint](field string, value T) GenericQuery {
  93. return func(tx *gorm.DB) *gorm.DB {
  94. return tx.Where(fmt.Sprintf("%s <= ?", field), value)
  95. }
  96. }
  97. func Like(field string, value string) GenericQuery {
  98. return func(tx *gorm.DB) *gorm.DB {
  99. return tx.Where(fmt.Sprintf("%s LIKE ?", field), "%"+value+"%")
  100. }
  101. }
  102. func Page(page int, pageSize int) GenericQuery {
  103. return func(tx *gorm.DB) *gorm.DB {
  104. return tx.Offset((page - 1) * pageSize).Limit(pageSize)
  105. }
  106. }
  107. func OrderBy(field string, desc bool) GenericQuery {
  108. return func(tx *gorm.DB) *gorm.DB {
  109. if desc {
  110. return tx.Order(fmt.Sprintf("%s DESC", field))
  111. }
  112. return tx.Order(field)
  113. }
  114. }
  115. // bitwise operation
  116. func WithBit[T genericComparableConstraint](field string, value T) GenericQuery {
  117. return func(tx *gorm.DB) *gorm.DB {
  118. return tx.Where(fmt.Sprintf("%s & ? = ?", field), value, value)
  119. }
  120. }
  121. func WithoutBit[T genericComparableConstraint](field string, value T) GenericQuery {
  122. return func(tx *gorm.DB) *gorm.DB {
  123. return tx.Where(fmt.Sprintf("%s & ~? != 0", field), value)
  124. }
  125. }
  126. func Inc[T genericComparableConstraint](updates map[string]T) GenericQuery {
  127. return func(tx *gorm.DB) *gorm.DB {
  128. m := make(map[string]any)
  129. for field, value := range updates {
  130. m[field] = gorm.Expr(fmt.Sprintf("%s + ?", field), value)
  131. }
  132. return tx.UpdateColumns(m)
  133. }
  134. }
  135. func Dec[T genericComparableConstraint](updates map[string]T) GenericQuery {
  136. return func(tx *gorm.DB) *gorm.DB {
  137. m := make(map[string]any)
  138. for field, value := range updates {
  139. m[field] = gorm.Expr(fmt.Sprintf("%s - ?", field), value)
  140. }
  141. return tx.UpdateColumns(m)
  142. }
  143. }
  144. func Model(model any) GenericQuery {
  145. return func(tx *gorm.DB) *gorm.DB {
  146. return tx.Model(model)
  147. }
  148. }
  149. func Fields(fields ...string) GenericQuery {
  150. return func(tx *gorm.DB) *gorm.DB {
  151. return tx.Select(fields)
  152. }
  153. }
  154. func Preload(model string, args ...interface{}) GenericQuery {
  155. return func(tx *gorm.DB) *gorm.DB {
  156. return tx.Preload(model, args...)
  157. }
  158. }
  159. func Join(field string) GenericQuery {
  160. return func(tx *gorm.DB) *gorm.DB {
  161. return tx.Joins(field)
  162. }
  163. }
  164. func WLock /* write lock */ () GenericQuery {
  165. return func(tx *gorm.DB) *gorm.DB {
  166. return tx.Clauses(clause.Locking{Strength: "UPDATE"})
  167. }
  168. }
  169. func Where[T any](model *T) GenericQuery {
  170. return func(tx *gorm.DB) *gorm.DB {
  171. return tx.Where(model)
  172. }
  173. }
  174. func WhereSQL(sql string, args ...interface{}) GenericQuery {
  175. return func(tx *gorm.DB) *gorm.DB {
  176. return tx.Where(sql, args...)
  177. }
  178. }
  179. func Action(fn func(tx *gorm.DB)) GenericQuery {
  180. return func(tx *gorm.DB) *gorm.DB {
  181. fn(tx)
  182. return tx
  183. }
  184. }
  185. /*
  186. Should be used first in query chain
  187. */
  188. func WithTransactionContext(tx *gorm.DB) GenericQuery {
  189. return func(_ *gorm.DB) *gorm.DB {
  190. return tx
  191. }
  192. }
  193. func InArray(field string, value []interface{}) GenericQuery {
  194. return func(tx *gorm.DB) *gorm.DB {
  195. return tx.Where(fmt.Sprintf("%s IN ?", field), value)
  196. }
  197. }
  198. func Run(query ...GenericQuery) error {
  199. tmp := DifyPluginDB
  200. for _, q := range query {
  201. tmp = q(tmp)
  202. }
  203. // execute query
  204. return tmp.Error
  205. }
  206. func GetAny[T any](sql string, data ...interface{}) (T /* data */, error) {
  207. var result T
  208. err := DifyPluginDB.Raw(sql, data...).Scan(&result).Error
  209. return result, err
  210. }
  211. func GetOne[T any](query ...GenericQuery) (T /* data */, error) {
  212. var data T
  213. tmp := DifyPluginDB
  214. for _, q := range query {
  215. tmp = q(tmp)
  216. }
  217. err := tmp.First(&data).Error
  218. return data, err
  219. }
  220. func GetAll[T any](query ...GenericQuery) ([]T /* data */, error) {
  221. var data []T
  222. tmp := DifyPluginDB
  223. for _, q := range query {
  224. tmp = q(tmp)
  225. }
  226. err := tmp.Find(&data).Error
  227. return data, err
  228. }
  229. func GetCount[T any](query ...GenericQuery) (int64 /* count */, error) {
  230. var model T
  231. var count int64
  232. tmp := DifyPluginDB
  233. for _, q := range query {
  234. tmp = q(tmp)
  235. }
  236. err := tmp.Model(&model).Count(&count).Error
  237. return count, err
  238. }
  239. func GetSum[T any, R genericComparableConstraint](fields string, query ...GenericQuery) (R, error) {
  240. var model T
  241. var sum R
  242. tmp := DifyPluginDB
  243. for _, q := range query {
  244. tmp = q(tmp)
  245. }
  246. err := tmp.Model(&model).Select(fmt.Sprintf("SUM(%s)", fields)).Scan(&sum).Error
  247. return sum, err
  248. }
  249. func DelAssociation[T any](field string, query ...GenericQuery) error {
  250. var model T
  251. tmp := DifyPluginDB.Model(&model)
  252. for _, q := range query {
  253. tmp = q(tmp)
  254. }
  255. return tmp.Association(field).Unscoped().Clear()
  256. }
  257. func WithTransaction(fn func(tx *gorm.DB) error, ctx ...*gorm.DB) error {
  258. // Start a transaction
  259. db := DifyPluginDB
  260. if len(ctx) > 0 {
  261. db = ctx[0]
  262. }
  263. tx := db.Begin()
  264. if tx.Error != nil {
  265. return tx.Error
  266. }
  267. err := fn(tx)
  268. if err != nil {
  269. if err := tx.Rollback().Error; err != nil {
  270. log.Error("failed to rollback tx: %v", err)
  271. }
  272. return err
  273. }
  274. tx.Commit()
  275. return nil
  276. }
  277. // NOTE: not used in production, only for testing
  278. func DropTable(model any) error {
  279. return DifyPluginDB.Migrator().DropTable(model)
  280. }
  281. // NOTE: not used in production, only for testing
  282. func CreateDatabase(dbname string) error {
  283. return DifyPluginDB.Exec(fmt.Sprintf("CREATE DATABASE %s", dbname)).Error
  284. }
  285. // NOTE: not used in production, only for testing
  286. func CreateTable(model any) error {
  287. return DifyPluginDB.Migrator().CreateTable(model)
  288. }