pgsql.go 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. package db
  2. import (
  3. "fmt"
  4. "github.com/langgenius/dify-plugin-daemon/internal/utils/log"
  5. "gorm.io/gorm"
  6. "gorm.io/gorm/clause"
  7. )
  8. /*
  9. ORM for pgsql
  10. */
  11. var DifyPluginDB *gorm.DB
  12. var (
  13. ErrDatabaseNotFound = gorm.ErrRecordNotFound
  14. )
  15. func Create(data any, ctx ...*gorm.DB) error {
  16. if len(ctx) > 0 {
  17. return ctx[0].Create(data).Error
  18. }
  19. return DifyPluginDB.Create(data).Error
  20. }
  21. func Update(data any, ctx ...*gorm.DB) error {
  22. if len(ctx) > 0 {
  23. return ctx[0].Save(data).Error
  24. }
  25. return DifyPluginDB.Save(data).Error
  26. }
  27. func Delete(data any, ctx ...*gorm.DB) error {
  28. if len(ctx) > 0 {
  29. return ctx[0].Delete(data).Error
  30. }
  31. return DifyPluginDB.Delete(data).Error
  32. }
  33. func DeleteByCondition[T any](condition T, ctx ...*gorm.DB) error {
  34. var model T
  35. if len(ctx) > 0 {
  36. return ctx[0].Where(condition).Delete(&model).Error
  37. }
  38. return DifyPluginDB.Where(condition).Delete(&model).Error
  39. }
  40. func ReplaceAssociation[T any, R any](source *T, field string, associations []R, ctx ...*gorm.DB) error {
  41. if len(ctx) > 0 {
  42. return ctx[0].Model(source).Association(field).Replace(associations)
  43. }
  44. return DifyPluginDB.Model(source).Association(field).Replace(associations)
  45. }
  46. func AppendAssociation[T any, R any](source *T, field string, associations R, ctx ...*gorm.DB) error {
  47. if len(ctx) > 0 {
  48. return ctx[0].Model(source).Association(field).Append(associations)
  49. }
  50. return DifyPluginDB.Model(source).Association(field).Append(associations)
  51. }
  52. type genericComparableConstraint interface {
  53. int | int8 | int16 | int32 | int64 |
  54. uint | uint8 | uint16 | uint32 | uint64 |
  55. float32 | float64 |
  56. bool
  57. }
  58. type genericEqualConstraint interface {
  59. genericComparableConstraint | string
  60. }
  61. type GenericQuery func(tx *gorm.DB) *gorm.DB
  62. func Equal[T genericEqualConstraint](field string, value T) GenericQuery {
  63. return func(tx *gorm.DB) *gorm.DB {
  64. return tx.Where(fmt.Sprintf("%s = ?", field), value)
  65. }
  66. }
  67. func EqualOr[T genericEqualConstraint](field string, value T) GenericQuery {
  68. return func(tx *gorm.DB) *gorm.DB {
  69. return tx.Or(fmt.Sprintf("%s = ?", field), value)
  70. }
  71. }
  72. func NotEqual[T genericEqualConstraint](field string, value T) GenericQuery {
  73. return func(tx *gorm.DB) *gorm.DB {
  74. return tx.Where(fmt.Sprintf("%s <> ?", field), value)
  75. }
  76. }
  77. func GreaterThan[T genericComparableConstraint](field string, value T) GenericQuery {
  78. return func(tx *gorm.DB) *gorm.DB {
  79. return tx.Where(fmt.Sprintf("%s > ?", field), value)
  80. }
  81. }
  82. func GreaterThanOrEqual[T genericComparableConstraint](field string, value T) GenericQuery {
  83. return func(tx *gorm.DB) *gorm.DB {
  84. return tx.Where(fmt.Sprintf("%s >= ?", field), value)
  85. }
  86. }
  87. func LessThan[T genericComparableConstraint](field string, value T) GenericQuery {
  88. return func(tx *gorm.DB) *gorm.DB {
  89. return tx.Where(fmt.Sprintf("%s < ?", field), value)
  90. }
  91. }
  92. func LessThanOrEqual[T genericComparableConstraint](field string, value T) GenericQuery {
  93. return func(tx *gorm.DB) *gorm.DB {
  94. return tx.Where(fmt.Sprintf("%s <= ?", field), value)
  95. }
  96. }
  97. func Like(field string, value string) GenericQuery {
  98. return func(tx *gorm.DB) *gorm.DB {
  99. return tx.Where(fmt.Sprintf("%s LIKE ?", field), "%"+value+"%")
  100. }
  101. }
  102. func Page(page int, pageSize int) GenericQuery {
  103. return func(tx *gorm.DB) *gorm.DB {
  104. return tx.Offset((page - 1) * pageSize).Limit(pageSize)
  105. }
  106. }
  107. func OrderBy(field string, desc bool) GenericQuery {
  108. return func(tx *gorm.DB) *gorm.DB {
  109. if desc {
  110. return tx.Order(fmt.Sprintf("%s DESC", field))
  111. }
  112. return tx.Order(field)
  113. }
  114. }
  115. // bitwise operation
  116. func WithBit[T genericComparableConstraint](field string, value T) GenericQuery {
  117. return func(tx *gorm.DB) *gorm.DB {
  118. return tx.Where(fmt.Sprintf("%s & ? = ?", field), value, value)
  119. }
  120. }
  121. func WithoutBit[T genericComparableConstraint](field string, value T) GenericQuery {
  122. return func(tx *gorm.DB) *gorm.DB {
  123. return tx.Where(fmt.Sprintf("%s & ~? != 0", field), value)
  124. }
  125. }
  126. func Inc[T genericComparableConstraint](field string, value T) GenericQuery {
  127. return func(tx *gorm.DB) *gorm.DB {
  128. return tx.UpdateColumn(field, gorm.Expr(fmt.Sprintf("%s + ?", field), value))
  129. }
  130. }
  131. func Dec[T genericComparableConstraint](field string, value T) GenericQuery {
  132. return func(tx *gorm.DB) *gorm.DB {
  133. return tx.UpdateColumn(field, gorm.Expr(fmt.Sprintf("%s - ?", field), value))
  134. }
  135. }
  136. func Model(model any) GenericQuery {
  137. return func(tx *gorm.DB) *gorm.DB {
  138. return tx.Model(model)
  139. }
  140. }
  141. func Fields(fields ...string) GenericQuery {
  142. return func(tx *gorm.DB) *gorm.DB {
  143. return tx.Select(fields)
  144. }
  145. }
  146. func Preload(model string, args ...interface{}) GenericQuery {
  147. return func(tx *gorm.DB) *gorm.DB {
  148. return tx.Preload(model, args...)
  149. }
  150. }
  151. func Join(field string) GenericQuery {
  152. return func(tx *gorm.DB) *gorm.DB {
  153. return tx.Joins(field)
  154. }
  155. }
  156. func WLock /* write lock */ () GenericQuery {
  157. return func(tx *gorm.DB) *gorm.DB {
  158. return tx.Clauses(clause.Locking{Strength: "UPDATE"})
  159. }
  160. }
  161. func Where[T any](model *T) GenericQuery {
  162. return func(tx *gorm.DB) *gorm.DB {
  163. return tx.Where(model)
  164. }
  165. }
  166. func WhereSQL(sql string, args ...interface{}) GenericQuery {
  167. return func(tx *gorm.DB) *gorm.DB {
  168. return tx.Where(sql, args...)
  169. }
  170. }
  171. func Action(fn func(tx *gorm.DB)) GenericQuery {
  172. return func(tx *gorm.DB) *gorm.DB {
  173. fn(tx)
  174. return tx
  175. }
  176. }
  177. /*
  178. Should be used first in query chain
  179. */
  180. func WithTransactionContext(tx *gorm.DB) GenericQuery {
  181. return func(_ *gorm.DB) *gorm.DB {
  182. return tx
  183. }
  184. }
  185. func InArray(field string, value []interface{}) GenericQuery {
  186. return func(tx *gorm.DB) *gorm.DB {
  187. return tx.Where(fmt.Sprintf("%s IN ?", field), value)
  188. }
  189. }
  190. func Run(query ...GenericQuery) error {
  191. tmp := DifyPluginDB
  192. for _, q := range query {
  193. tmp = q(tmp)
  194. }
  195. // execute query
  196. return tmp.Error
  197. }
  198. func GetAny[T any](sql string, data ...interface{}) (T /* data */, error) {
  199. var result T
  200. err := DifyPluginDB.Raw(sql, data...).Scan(&result).Error
  201. return result, err
  202. }
  203. func GetOne[T any](query ...GenericQuery) (T /* data */, error) {
  204. var data T
  205. tmp := DifyPluginDB
  206. for _, q := range query {
  207. tmp = q(tmp)
  208. }
  209. err := tmp.First(&data).Error
  210. return data, err
  211. }
  212. func GetAll[T any](query ...GenericQuery) ([]T /* data */, error) {
  213. var data []T
  214. tmp := DifyPluginDB
  215. for _, q := range query {
  216. tmp = q(tmp)
  217. }
  218. err := tmp.Find(&data).Error
  219. return data, err
  220. }
  221. func GetCount[T any](query ...GenericQuery) (int64 /* count */, error) {
  222. var model T
  223. var count int64
  224. tmp := DifyPluginDB
  225. for _, q := range query {
  226. tmp = q(tmp)
  227. }
  228. err := tmp.Model(&model).Count(&count).Error
  229. return count, err
  230. }
  231. func GetSum[T any, R genericComparableConstraint](fields string, query ...GenericQuery) (R, error) {
  232. var model T
  233. var sum R
  234. tmp := DifyPluginDB
  235. for _, q := range query {
  236. tmp = q(tmp)
  237. }
  238. err := tmp.Model(&model).Select(fmt.Sprintf("SUM(%s)", fields)).Scan(&sum).Error
  239. return sum, err
  240. }
  241. func DelAssociation[T any](field string, query ...GenericQuery) error {
  242. var model T
  243. tmp := DifyPluginDB.Model(&model)
  244. for _, q := range query {
  245. tmp = q(tmp)
  246. }
  247. return tmp.Association(field).Unscoped().Clear()
  248. }
  249. func WithTransaction(fn func(tx *gorm.DB) error, ctx ...*gorm.DB) error {
  250. // Start a transaction
  251. db := DifyPluginDB
  252. if len(ctx) > 0 {
  253. db = ctx[0]
  254. }
  255. tx := db.Begin()
  256. if tx.Error != nil {
  257. return tx.Error
  258. }
  259. err := fn(tx)
  260. if err != nil {
  261. if err := tx.Rollback().Error; err != nil {
  262. log.Error("failed to rollback tx: %v", err)
  263. }
  264. return err
  265. }
  266. tx.Commit()
  267. return nil
  268. }
  269. // NOTE: not used in production, only for testing
  270. func DropTable(model any) error {
  271. return DifyPluginDB.Migrator().DropTable(model)
  272. }
  273. // NOTE: not used in production, only for testing
  274. func CreateDatabase(dbname string) error {
  275. return DifyPluginDB.Exec(fmt.Sprintf("CREATE DATABASE %s", dbname)).Error
  276. }
  277. // NOTE: not used in production, only for testing
  278. func CreateTable(model any) error {
  279. return DifyPluginDB.Migrator().CreateTable(model)
  280. }