pgsql.go 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. package db
  2. import (
  3. "fmt"
  4. "strings"
  5. "github.com/langgenius/dify-plugin-daemon/internal/utils/log"
  6. "gorm.io/gorm"
  7. "gorm.io/gorm/clause"
  8. )
  9. /*
  10. ORM for pgsql
  11. */
  12. var DifyPluginDB *gorm.DB
  13. var (
  14. ErrDatabaseNotFound = gorm.ErrRecordNotFound
  15. )
  16. func Create(data any, ctx ...*gorm.DB) error {
  17. if len(ctx) > 0 {
  18. return ctx[0].Create(data).Error
  19. }
  20. return DifyPluginDB.Create(data).Error
  21. }
  22. func Update(data any, ctx ...*gorm.DB) error {
  23. if len(ctx) > 0 {
  24. return ctx[0].Save(data).Error
  25. }
  26. return DifyPluginDB.Save(data).Error
  27. }
  28. func Delete(data any, ctx ...*gorm.DB) error {
  29. if len(ctx) > 0 {
  30. return ctx[0].Delete(data).Error
  31. }
  32. return DifyPluginDB.Delete(data).Error
  33. }
  34. func DeleteByCondition[T any](condition T, ctx ...*gorm.DB) error {
  35. var model T
  36. if len(ctx) > 0 {
  37. return ctx[0].Where(condition).Delete(&model).Error
  38. }
  39. return DifyPluginDB.Where(condition).Delete(&model).Error
  40. }
  41. func ReplaceAssociation[T any, R any](source *T, field string, associations []R, ctx ...*gorm.DB) error {
  42. if len(ctx) > 0 {
  43. return ctx[0].Model(source).Association(field).Replace(associations)
  44. }
  45. return DifyPluginDB.Model(source).Association(field).Replace(associations)
  46. }
  47. func AppendAssociation[T any, R any](source *T, field string, associations R, ctx ...*gorm.DB) error {
  48. if len(ctx) > 0 {
  49. return ctx[0].Model(source).Association(field).Append(associations)
  50. }
  51. return DifyPluginDB.Model(source).Association(field).Append(associations)
  52. }
  53. type genericComparableConstraint interface {
  54. int | int8 | int16 | int32 | int64 |
  55. uint | uint8 | uint16 | uint32 | uint64 |
  56. float32 | float64 |
  57. bool
  58. }
  59. type genericEqualConstraint interface {
  60. genericComparableConstraint | string
  61. }
  62. type GenericQuery func(tx *gorm.DB) *gorm.DB
  63. func Equal[T genericEqualConstraint](field string, value T) GenericQuery {
  64. return func(tx *gorm.DB) *gorm.DB {
  65. return tx.Where(fmt.Sprintf("%s = ?", field), value)
  66. }
  67. }
  68. func EqualOr[T genericEqualConstraint](field string, value T) GenericQuery {
  69. return func(tx *gorm.DB) *gorm.DB {
  70. return tx.Or(fmt.Sprintf("%s = ?", field), value)
  71. }
  72. }
  73. func NotEqual[T genericEqualConstraint](field string, value T) GenericQuery {
  74. return func(tx *gorm.DB) *gorm.DB {
  75. return tx.Where(fmt.Sprintf("%s <> ?", field), value)
  76. }
  77. }
  78. func GreaterThan[T genericComparableConstraint](field string, value T) GenericQuery {
  79. return func(tx *gorm.DB) *gorm.DB {
  80. return tx.Where(fmt.Sprintf("%s > ?", field), value)
  81. }
  82. }
  83. func GreaterThanOrEqual[T genericComparableConstraint](field string, value T) GenericQuery {
  84. return func(tx *gorm.DB) *gorm.DB {
  85. return tx.Where(fmt.Sprintf("%s >= ?", field), value)
  86. }
  87. }
  88. func LessThan[T genericComparableConstraint](field string, value T) GenericQuery {
  89. return func(tx *gorm.DB) *gorm.DB {
  90. return tx.Where(fmt.Sprintf("%s < ?", field), value)
  91. }
  92. }
  93. func LessThanOrEqual[T genericComparableConstraint](field string, value T) GenericQuery {
  94. return func(tx *gorm.DB) *gorm.DB {
  95. return tx.Where(fmt.Sprintf("%s <= ?", field), value)
  96. }
  97. }
  98. func Like(field string, value string) GenericQuery {
  99. return func(tx *gorm.DB) *gorm.DB {
  100. return tx.Where(fmt.Sprintf("%s LIKE ?", field), "%"+value+"%")
  101. }
  102. }
  103. func Page(page int, pageSize int) GenericQuery {
  104. return func(tx *gorm.DB) *gorm.DB {
  105. return tx.Offset((page - 1) * pageSize).Limit(pageSize)
  106. }
  107. }
  108. func OrderBy(field string, desc bool) GenericQuery {
  109. return func(tx *gorm.DB) *gorm.DB {
  110. if desc {
  111. return tx.Order(fmt.Sprintf("%s DESC", field))
  112. }
  113. return tx.Order(field)
  114. }
  115. }
  116. // bitwise operation
  117. func WithBit[T genericComparableConstraint](field string, value T) GenericQuery {
  118. return func(tx *gorm.DB) *gorm.DB {
  119. return tx.Where(fmt.Sprintf("%s & ? = ?", field), value, value)
  120. }
  121. }
  122. func WithoutBit[T genericComparableConstraint](field string, value T) GenericQuery {
  123. return func(tx *gorm.DB) *gorm.DB {
  124. return tx.Where(fmt.Sprintf("%s & ~? != 0", field), value)
  125. }
  126. }
  127. func Inc[T genericComparableConstraint](updates map[string]T) GenericQuery {
  128. return func(tx *gorm.DB) *gorm.DB {
  129. expressions := make([]string, 0, len(updates))
  130. values := make([]interface{}, 0, len(updates))
  131. for field, value := range updates {
  132. expressions = append(expressions, fmt.Sprintf("%s = %s + ?", field, field))
  133. values = append(values, value)
  134. }
  135. return tx.UpdateColumns(gorm.Expr(strings.Join(expressions, ", "), values...))
  136. }
  137. }
  138. func Dec[T genericComparableConstraint](updates map[string]T) GenericQuery {
  139. return func(tx *gorm.DB) *gorm.DB {
  140. expressions := make([]string, 0, len(updates))
  141. values := make([]interface{}, 0, len(updates))
  142. for field, value := range updates {
  143. expressions = append(expressions, fmt.Sprintf("%s = %s - ?", field, field))
  144. values = append(values, value)
  145. }
  146. return tx.UpdateColumns(gorm.Expr(strings.Join(expressions, ", "), values...))
  147. }
  148. }
  149. func Model(model any) GenericQuery {
  150. return func(tx *gorm.DB) *gorm.DB {
  151. return tx.Model(model)
  152. }
  153. }
  154. func Fields(fields ...string) GenericQuery {
  155. return func(tx *gorm.DB) *gorm.DB {
  156. return tx.Select(fields)
  157. }
  158. }
  159. func Preload(model string, args ...interface{}) GenericQuery {
  160. return func(tx *gorm.DB) *gorm.DB {
  161. return tx.Preload(model, args...)
  162. }
  163. }
  164. func Join(field string) GenericQuery {
  165. return func(tx *gorm.DB) *gorm.DB {
  166. return tx.Joins(field)
  167. }
  168. }
  169. func WLock /* write lock */ () GenericQuery {
  170. return func(tx *gorm.DB) *gorm.DB {
  171. return tx.Clauses(clause.Locking{Strength: "UPDATE"})
  172. }
  173. }
  174. func Where[T any](model *T) GenericQuery {
  175. return func(tx *gorm.DB) *gorm.DB {
  176. return tx.Where(model)
  177. }
  178. }
  179. func WhereSQL(sql string, args ...interface{}) GenericQuery {
  180. return func(tx *gorm.DB) *gorm.DB {
  181. return tx.Where(sql, args...)
  182. }
  183. }
  184. func Action(fn func(tx *gorm.DB)) GenericQuery {
  185. return func(tx *gorm.DB) *gorm.DB {
  186. fn(tx)
  187. return tx
  188. }
  189. }
  190. /*
  191. Should be used first in query chain
  192. */
  193. func WithTransactionContext(tx *gorm.DB) GenericQuery {
  194. return func(_ *gorm.DB) *gorm.DB {
  195. return tx
  196. }
  197. }
  198. func InArray(field string, value []interface{}) GenericQuery {
  199. return func(tx *gorm.DB) *gorm.DB {
  200. return tx.Where(fmt.Sprintf("%s IN ?", field), value)
  201. }
  202. }
  203. func Run(query ...GenericQuery) error {
  204. tmp := DifyPluginDB
  205. for _, q := range query {
  206. tmp = q(tmp)
  207. }
  208. // execute query
  209. return tmp.Error
  210. }
  211. func GetAny[T any](sql string, data ...interface{}) (T /* data */, error) {
  212. var result T
  213. err := DifyPluginDB.Raw(sql, data...).Scan(&result).Error
  214. return result, err
  215. }
  216. func GetOne[T any](query ...GenericQuery) (T /* data */, error) {
  217. var data T
  218. tmp := DifyPluginDB
  219. for _, q := range query {
  220. tmp = q(tmp)
  221. }
  222. err := tmp.First(&data).Error
  223. return data, err
  224. }
  225. func GetAll[T any](query ...GenericQuery) ([]T /* data */, error) {
  226. var data []T
  227. tmp := DifyPluginDB
  228. for _, q := range query {
  229. tmp = q(tmp)
  230. }
  231. err := tmp.Find(&data).Error
  232. return data, err
  233. }
  234. func GetCount[T any](query ...GenericQuery) (int64 /* count */, error) {
  235. var model T
  236. var count int64
  237. tmp := DifyPluginDB
  238. for _, q := range query {
  239. tmp = q(tmp)
  240. }
  241. err := tmp.Model(&model).Count(&count).Error
  242. return count, err
  243. }
  244. func GetSum[T any, R genericComparableConstraint](fields string, query ...GenericQuery) (R, error) {
  245. var model T
  246. var sum R
  247. tmp := DifyPluginDB
  248. for _, q := range query {
  249. tmp = q(tmp)
  250. }
  251. err := tmp.Model(&model).Select(fmt.Sprintf("SUM(%s)", fields)).Scan(&sum).Error
  252. return sum, err
  253. }
  254. func DelAssociation[T any](field string, query ...GenericQuery) error {
  255. var model T
  256. tmp := DifyPluginDB.Model(&model)
  257. for _, q := range query {
  258. tmp = q(tmp)
  259. }
  260. return tmp.Association(field).Unscoped().Clear()
  261. }
  262. func WithTransaction(fn func(tx *gorm.DB) error, ctx ...*gorm.DB) error {
  263. // Start a transaction
  264. db := DifyPluginDB
  265. if len(ctx) > 0 {
  266. db = ctx[0]
  267. }
  268. tx := db.Begin()
  269. if tx.Error != nil {
  270. return tx.Error
  271. }
  272. err := fn(tx)
  273. if err != nil {
  274. if err := tx.Rollback().Error; err != nil {
  275. log.Error("failed to rollback tx: %v", err)
  276. }
  277. return err
  278. }
  279. tx.Commit()
  280. return nil
  281. }
  282. // NOTE: not used in production, only for testing
  283. func DropTable(model any) error {
  284. return DifyPluginDB.Migrator().DropTable(model)
  285. }
  286. // NOTE: not used in production, only for testing
  287. func CreateDatabase(dbname string) error {
  288. return DifyPluginDB.Exec(fmt.Sprintf("CREATE DATABASE %s", dbname)).Error
  289. }
  290. // NOTE: not used in production, only for testing
  291. func CreateTable(model any) error {
  292. return DifyPluginDB.Migrator().CreateTable(model)
  293. }