pgsql.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. package db
  2. import (
  3. "fmt"
  4. "strings"
  5. "github.com/langgenius/dify-plugin-daemon/internal/utils/log"
  6. "gorm.io/gorm"
  7. "gorm.io/gorm/clause"
  8. )
  9. /*
  10. ORM for pgsql
  11. */
  12. var DifyPluginDB *gorm.DB
  13. var (
  14. ErrDatabaseNotFound = gorm.ErrRecordNotFound
  15. )
  16. func Create(data any, ctx ...*gorm.DB) error {
  17. if len(ctx) > 0 {
  18. return ctx[0].Create(data).Error
  19. }
  20. return DifyPluginDB.Create(data).Error
  21. }
  22. func Update(data any, ctx ...*gorm.DB) error {
  23. if len(ctx) > 0 {
  24. return ctx[0].Save(data).Error
  25. }
  26. return DifyPluginDB.Save(data).Error
  27. }
  28. func Delete(data any, ctx ...*gorm.DB) error {
  29. if len(ctx) > 0 {
  30. return ctx[0].Delete(data).Error
  31. }
  32. return DifyPluginDB.Delete(data).Error
  33. }
  34. func DeleteByCondition[T any](condition T, ctx ...*gorm.DB) error {
  35. var model T
  36. if len(ctx) > 0 {
  37. return ctx[0].Where(condition).Delete(&model).Error
  38. }
  39. return DifyPluginDB.Where(condition).Delete(&model).Error
  40. }
  41. func ReplaceAssociation[T any, R any](source *T, field string, associations []R, ctx ...*gorm.DB) error {
  42. if len(ctx) > 0 {
  43. return ctx[0].Model(source).Association(field).Replace(associations)
  44. }
  45. return DifyPluginDB.Model(source).Association(field).Replace(associations)
  46. }
  47. func AppendAssociation[T any, R any](source *T, field string, associations R, ctx ...*gorm.DB) error {
  48. if len(ctx) > 0 {
  49. return ctx[0].Model(source).Association(field).Append(associations)
  50. }
  51. return DifyPluginDB.Model(source).Association(field).Append(associations)
  52. }
  // genericComparableConstraint is the exact set of built-in boolean, float,
  // signed and unsigned integer types (no ~ approximation, so named types
  // built on them are excluded). It is used by the ordering, bitwise,
  // Inc/Dec and GetSum helpers.
  type genericComparableConstraint interface {
  	bool |
  		float32 | float64 |
  		int | int8 | int16 | int32 | int64 |
  		uint | uint8 | uint16 | uint32 | uint64
  }

  // genericEqualConstraint adds string to genericComparableConstraint. It is
  // used by the equality helpers (Equal, EqualOr, NotEqual).
  type genericEqualConstraint interface {
  	string | genericComparableConstraint
  }
  62. type GenericQuery func(tx *gorm.DB) *gorm.DB
  63. func Equal[T genericEqualConstraint](field string, value T) GenericQuery {
  64. return func(tx *gorm.DB) *gorm.DB {
  65. return tx.Where(fmt.Sprintf("%s = ?", field), value)
  66. }
  67. }
  68. func EqualOr[T genericEqualConstraint](field string, value T) GenericQuery {
  69. return func(tx *gorm.DB) *gorm.DB {
  70. return tx.Or(fmt.Sprintf("%s = ?", field), value)
  71. }
  72. }
  73. func NotEqual[T genericEqualConstraint](field string, value T) GenericQuery {
  74. return func(tx *gorm.DB) *gorm.DB {
  75. return tx.Where(fmt.Sprintf("%s <> ?", field), value)
  76. }
  77. }
  78. func GreaterThan[T genericComparableConstraint](field string, value T) GenericQuery {
  79. return func(tx *gorm.DB) *gorm.DB {
  80. return tx.Where(fmt.Sprintf("%s > ?", field), value)
  81. }
  82. }
  83. func GreaterThanOrEqual[T genericComparableConstraint](field string, value T) GenericQuery {
  84. return func(tx *gorm.DB) *gorm.DB {
  85. return tx.Where(fmt.Sprintf("%s >= ?", field), value)
  86. }
  87. }
  88. func LessThan[T genericComparableConstraint](field string, value T) GenericQuery {
  89. return func(tx *gorm.DB) *gorm.DB {
  90. return tx.Where(fmt.Sprintf("%s < ?", field), value)
  91. }
  92. }
  93. func LessThanOrEqual[T genericComparableConstraint](field string, value T) GenericQuery {
  94. return func(tx *gorm.DB) *gorm.DB {
  95. return tx.Where(fmt.Sprintf("%s <= ?", field), value)
  96. }
  97. }
  98. func Like(field string, value string) GenericQuery {
  99. return func(tx *gorm.DB) *gorm.DB {
  100. return tx.Where(fmt.Sprintf("%s LIKE ?", field), "%"+value+"%")
  101. }
  102. }
  103. func Page(page int, pageSize int) GenericQuery {
  104. return func(tx *gorm.DB) *gorm.DB {
  105. return tx.Offset((page - 1) * pageSize).Limit(pageSize)
  106. }
  107. }
  108. func OrderBy(field string, desc bool) GenericQuery {
  109. return func(tx *gorm.DB) *gorm.DB {
  110. if desc {
  111. return tx.Order(fmt.Sprintf("%s DESC", field))
  112. }
  113. return tx.Order(field)
  114. }
  115. }
  116. // bitwise operation
  117. func WithBit[T genericComparableConstraint](field string, value T) GenericQuery {
  118. return func(tx *gorm.DB) *gorm.DB {
  119. return tx.Where(fmt.Sprintf("%s & ? = ?", field), value, value)
  120. }
  121. }
  122. func WithoutBit[T genericComparableConstraint](field string, value T) GenericQuery {
  123. return func(tx *gorm.DB) *gorm.DB {
  124. return tx.Where(fmt.Sprintf("%s & ~? != 0", field), value)
  125. }
  126. }
  127. func Inc[T genericComparableConstraint](updates map[string]T) GenericQuery {
  128. return func(tx *gorm.DB) *gorm.DB {
  129. m := make(map[string]any)
  130. for field, value := range updates {
  131. m[field] = gorm.Expr(fmt.Sprintf("%s + ?", field), value)
  132. }
  133. return tx.UpdateColumns(m)
  134. }
  135. }
  // Dec is meant to decrement each column named in updates by its mapped
  // amount.
  //
  // It joins one "<col> = <col> - ?" fragment per entry into a single
  // comma-separated gorm.Expr and passes that Expr to UpdateColumns. Each
  // value is appended in the same loop iteration as its fragment, so values
  // stay paired with their placeholders. Map iteration order is random, so
  // the order of the fragments varies between calls. Column names are
  // interpolated verbatim and must be trusted.
  //
  // NOTE(review): Inc passes UpdateColumns a map of column -> gorm.Expr,
  // while this passes a bare gorm.Expr. gorm's update path normally expects a
  // struct or map, so this call may not generate a SET clause (a silent no-op
  // or an error). Confirm against gorm's update callback, and consider
  // mirroring Inc. That change would leave the file's "strings" import
  // unused, so it must be removed at the same time.
  136. func Dec[T genericComparableConstraint](updates map[string]T) GenericQuery {
  137. return func(tx *gorm.DB) *gorm.DB {
  138. expressions := make([]string, 0, len(updates))
  139. values := make([]interface{}, 0, len(updates))
  140. for field, value := range updates {
  141. expressions = append(expressions, fmt.Sprintf("%s = %s - ?", field, field))
  142. values = append(values, value)
  143. }
  144. return tx.UpdateColumns(gorm.Expr(strings.Join(expressions, ", "), values...))
  145. }
  146. }
  147. func Model(model any) GenericQuery {
  148. return func(tx *gorm.DB) *gorm.DB {
  149. return tx.Model(model)
  150. }
  151. }
  152. func Fields(fields ...string) GenericQuery {
  153. return func(tx *gorm.DB) *gorm.DB {
  154. return tx.Select(fields)
  155. }
  156. }
  157. func Preload(model string, args ...interface{}) GenericQuery {
  158. return func(tx *gorm.DB) *gorm.DB {
  159. return tx.Preload(model, args...)
  160. }
  161. }
  162. func Join(field string) GenericQuery {
  163. return func(tx *gorm.DB) *gorm.DB {
  164. return tx.Joins(field)
  165. }
  166. }
  167. func WLock /* write lock */ () GenericQuery {
  168. return func(tx *gorm.DB) *gorm.DB {
  169. return tx.Clauses(clause.Locking{Strength: "UPDATE"})
  170. }
  171. }
  172. func Where[T any](model *T) GenericQuery {
  173. return func(tx *gorm.DB) *gorm.DB {
  174. return tx.Where(model)
  175. }
  176. }
  177. func WhereSQL(sql string, args ...interface{}) GenericQuery {
  178. return func(tx *gorm.DB) *gorm.DB {
  179. return tx.Where(sql, args...)
  180. }
  181. }
  182. func Action(fn func(tx *gorm.DB)) GenericQuery {
  183. return func(tx *gorm.DB) *gorm.DB {
  184. fn(tx)
  185. return tx
  186. }
  187. }
  188. /*
  189. Should be used first in query chain
  190. */
  191. func WithTransactionContext(tx *gorm.DB) GenericQuery {
  192. return func(_ *gorm.DB) *gorm.DB {
  193. return tx
  194. }
  195. }
  196. func InArray(field string, value []interface{}) GenericQuery {
  197. return func(tx *gorm.DB) *gorm.DB {
  198. return tx.Where(fmt.Sprintf("%s IN ?", field), value)
  199. }
  200. }
  201. func Run(query ...GenericQuery) error {
  202. tmp := DifyPluginDB
  203. for _, q := range query {
  204. tmp = q(tmp)
  205. }
  206. // execute query
  207. return tmp.Error
  208. }
  209. func GetAny[T any](sql string, data ...interface{}) (T /* data */, error) {
  210. var result T
  211. err := DifyPluginDB.Raw(sql, data...).Scan(&result).Error
  212. return result, err
  213. }
  214. func GetOne[T any](query ...GenericQuery) (T /* data */, error) {
  215. var data T
  216. tmp := DifyPluginDB
  217. for _, q := range query {
  218. tmp = q(tmp)
  219. }
  220. err := tmp.First(&data).Error
  221. return data, err
  222. }
  223. func GetAll[T any](query ...GenericQuery) ([]T /* data */, error) {
  224. var data []T
  225. tmp := DifyPluginDB
  226. for _, q := range query {
  227. tmp = q(tmp)
  228. }
  229. err := tmp.Find(&data).Error
  230. return data, err
  231. }
  232. func GetCount[T any](query ...GenericQuery) (int64 /* count */, error) {
  233. var model T
  234. var count int64
  235. tmp := DifyPluginDB
  236. for _, q := range query {
  237. tmp = q(tmp)
  238. }
  239. err := tmp.Model(&model).Count(&count).Error
  240. return count, err
  241. }
  242. func GetSum[T any, R genericComparableConstraint](fields string, query ...GenericQuery) (R, error) {
  243. var model T
  244. var sum R
  245. tmp := DifyPluginDB
  246. for _, q := range query {
  247. tmp = q(tmp)
  248. }
  249. err := tmp.Model(&model).Select(fmt.Sprintf("SUM(%s)", fields)).Scan(&sum).Error
  250. return sum, err
  251. }
  252. func DelAssociation[T any](field string, query ...GenericQuery) error {
  253. var model T
  254. tmp := DifyPluginDB.Model(&model)
  255. for _, q := range query {
  256. tmp = q(tmp)
  257. }
  258. return tmp.Association(field).Unscoped().Clear()
  259. }
  260. func WithTransaction(fn func(tx *gorm.DB) error, ctx ...*gorm.DB) error {
  261. // Start a transaction
  262. db := DifyPluginDB
  263. if len(ctx) > 0 {
  264. db = ctx[0]
  265. }
  266. tx := db.Begin()
  267. if tx.Error != nil {
  268. return tx.Error
  269. }
  270. err := fn(tx)
  271. if err != nil {
  272. if err := tx.Rollback().Error; err != nil {
  273. log.Error("failed to rollback tx: %v", err)
  274. }
  275. return err
  276. }
  277. tx.Commit()
  278. return nil
  279. }
  280. // NOTE: not used in production, only for testing
  281. func DropTable(model any) error {
  282. return DifyPluginDB.Migrator().DropTable(model)
  283. }
  284. // NOTE: not used in production, only for testing
  285. func CreateDatabase(dbname string) error {
  286. return DifyPluginDB.Exec(fmt.Sprintf("CREATE DATABASE %s", dbname)).Error
  287. }
  288. // NOTE: not used in production, only for testing
  289. func CreateTable(model any) error {
  290. return DifyPluginDB.Migrator().CreateTable(model)
  291. }