123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333 |
- package db
- import (
- "fmt"
- "github.com/langgenius/dify-plugin-daemon/internal/utils/log"
- "gorm.io/gorm"
- "gorm.io/gorm/clause"
- )
- /*
- ORM for pgsql
- */
- var DifyPluginDB *gorm.DB
- var (
- ErrDatabaseNotFound = gorm.ErrRecordNotFound
- )
- func Create(data any, ctx ...*gorm.DB) error {
- if len(ctx) > 0 {
- return ctx[0].Create(data).Error
- }
- return DifyPluginDB.Create(data).Error
- }
- func Update(data any, ctx ...*gorm.DB) error {
- if len(ctx) > 0 {
- return ctx[0].Save(data).Error
- }
- return DifyPluginDB.Save(data).Error
- }
- func Delete(data any, ctx ...*gorm.DB) error {
- if len(ctx) > 0 {
- return ctx[0].Delete(data).Error
- }
- return DifyPluginDB.Delete(data).Error
- }
- func DeleteByCondition[T any](condition T, ctx ...*gorm.DB) error {
- var model T
- if len(ctx) > 0 {
- return ctx[0].Where(condition).Delete(&model).Error
- }
- return DifyPluginDB.Where(condition).Delete(&model).Error
- }
- func ReplaceAssociation[T any, R any](source *T, field string, associations []R, ctx ...*gorm.DB) error {
- if len(ctx) > 0 {
- return ctx[0].Model(source).Association(field).Replace(associations)
- }
- return DifyPluginDB.Model(source).Association(field).Replace(associations)
- }
- func AppendAssociation[T any, R any](source *T, field string, associations R, ctx ...*gorm.DB) error {
- if len(ctx) > 0 {
- return ctx[0].Model(source).Association(field).Append(associations)
- }
- return DifyPluginDB.Model(source).Association(field).Append(associations)
- }
// genericComparableConstraint is the set of built-in types accepted by the
// comparison, bitwise and arithmetic query helpers: all sized and unsized
// integers, both float widths, and bool. The terms carry no "~", so only the
// exact built-in types satisfy it, not named types derived from them.
type genericComparableConstraint interface {
	bool |
		float32 | float64 |
		int | int8 | int16 | int32 | int64 |
		uint | uint8 | uint16 | uint32 | uint64
}

// genericEqualConstraint extends genericComparableConstraint with string and
// is used by the equality helpers.
type genericEqualConstraint interface {
	string | genericComparableConstraint
}
- type GenericQuery func(tx *gorm.DB) *gorm.DB
- func Equal[T genericEqualConstraint](field string, value T) GenericQuery {
- return func(tx *gorm.DB) *gorm.DB {
- return tx.Where(fmt.Sprintf("%s = ?", field), value)
- }
- }
- func EqualOr[T genericEqualConstraint](field string, value T) GenericQuery {
- return func(tx *gorm.DB) *gorm.DB {
- return tx.Or(fmt.Sprintf("%s = ?", field), value)
- }
- }
- func NotEqual[T genericEqualConstraint](field string, value T) GenericQuery {
- return func(tx *gorm.DB) *gorm.DB {
- return tx.Where(fmt.Sprintf("%s <> ?", field), value)
- }
- }
- func GreaterThan[T genericComparableConstraint](field string, value T) GenericQuery {
- return func(tx *gorm.DB) *gorm.DB {
- return tx.Where(fmt.Sprintf("%s > ?", field), value)
- }
- }
- func GreaterThanOrEqual[T genericComparableConstraint](field string, value T) GenericQuery {
- return func(tx *gorm.DB) *gorm.DB {
- return tx.Where(fmt.Sprintf("%s >= ?", field), value)
- }
- }
- func LessThan[T genericComparableConstraint](field string, value T) GenericQuery {
- return func(tx *gorm.DB) *gorm.DB {
- return tx.Where(fmt.Sprintf("%s < ?", field), value)
- }
- }
- func LessThanOrEqual[T genericComparableConstraint](field string, value T) GenericQuery {
- return func(tx *gorm.DB) *gorm.DB {
- return tx.Where(fmt.Sprintf("%s <= ?", field), value)
- }
- }
- func Like(field string, value string) GenericQuery {
- return func(tx *gorm.DB) *gorm.DB {
- return tx.Where(fmt.Sprintf("%s LIKE ?", field), "%"+value+"%")
- }
- }
- func Page(page int, pageSize int) GenericQuery {
- return func(tx *gorm.DB) *gorm.DB {
- return tx.Offset((page - 1) * pageSize).Limit(pageSize)
- }
- }
- func OrderBy(field string, desc bool) GenericQuery {
- return func(tx *gorm.DB) *gorm.DB {
- if desc {
- return tx.Order(fmt.Sprintf("%s DESC", field))
- }
- return tx.Order(field)
- }
- }
- // bitwise operation
- func WithBit[T genericComparableConstraint](field string, value T) GenericQuery {
- return func(tx *gorm.DB) *gorm.DB {
- return tx.Where(fmt.Sprintf("%s & ? = ?", field), value, value)
- }
- }
- func WithoutBit[T genericComparableConstraint](field string, value T) GenericQuery {
- return func(tx *gorm.DB) *gorm.DB {
- return tx.Where(fmt.Sprintf("%s & ~? != 0", field), value)
- }
- }
- func Inc[T genericComparableConstraint](field string, value T) GenericQuery {
- return func(tx *gorm.DB) *gorm.DB {
- return tx.UpdateColumn(field, gorm.Expr(fmt.Sprintf("%s + ?", field), value))
- }
- }
- func Dec[T genericComparableConstraint](field string, value T) GenericQuery {
- return func(tx *gorm.DB) *gorm.DB {
- return tx.UpdateColumn(field, gorm.Expr(fmt.Sprintf("%s - ?", field), value))
- }
- }
- func Model(model any) GenericQuery {
- return func(tx *gorm.DB) *gorm.DB {
- return tx.Model(model)
- }
- }
- func Fields(fields ...string) GenericQuery {
- return func(tx *gorm.DB) *gorm.DB {
- return tx.Select(fields)
- }
- }
- func Preload(model string, args ...interface{}) GenericQuery {
- return func(tx *gorm.DB) *gorm.DB {
- return tx.Preload(model, args...)
- }
- }
- func Join(field string) GenericQuery {
- return func(tx *gorm.DB) *gorm.DB {
- return tx.Joins(field)
- }
- }
- func WLock /* write lock */ () GenericQuery {
- return func(tx *gorm.DB) *gorm.DB {
- return tx.Clauses(clause.Locking{Strength: "UPDATE"})
- }
- }
- func Where[T any](model *T) GenericQuery {
- return func(tx *gorm.DB) *gorm.DB {
- return tx.Where(model)
- }
- }
- func WhereSQL(sql string, args ...interface{}) GenericQuery {
- return func(tx *gorm.DB) *gorm.DB {
- return tx.Where(sql, args...)
- }
- }
- func Action(fn func(tx *gorm.DB)) GenericQuery {
- return func(tx *gorm.DB) *gorm.DB {
- fn(tx)
- return tx
- }
- }
- /*
- Should be used first in query chain
- */
- func WithTransactionContext(tx *gorm.DB) GenericQuery {
- return func(_ *gorm.DB) *gorm.DB {
- return tx
- }
- }
- func InArray(field string, value []interface{}) GenericQuery {
- return func(tx *gorm.DB) *gorm.DB {
- return tx.Where(fmt.Sprintf("%s IN ?", field), value)
- }
- }
- func Run(query ...GenericQuery) error {
- tmp := DifyPluginDB
- for _, q := range query {
- tmp = q(tmp)
- }
- // execute query
- return tmp.Error
- }
- func GetAny[T any](sql string, data ...interface{}) (T /* data */, error) {
- var result T
- err := DifyPluginDB.Raw(sql, data...).Scan(&result).Error
- return result, err
- }
- func GetOne[T any](query ...GenericQuery) (T /* data */, error) {
- var data T
- tmp := DifyPluginDB
- for _, q := range query {
- tmp = q(tmp)
- }
- err := tmp.First(&data).Error
- return data, err
- }
- func GetAll[T any](query ...GenericQuery) ([]T /* data */, error) {
- var data []T
- tmp := DifyPluginDB
- for _, q := range query {
- tmp = q(tmp)
- }
- err := tmp.Find(&data).Error
- return data, err
- }
- func GetCount[T any](query ...GenericQuery) (int64 /* count */, error) {
- var model T
- var count int64
- tmp := DifyPluginDB
- for _, q := range query {
- tmp = q(tmp)
- }
- err := tmp.Model(&model).Count(&count).Error
- return count, err
- }
- func GetSum[T any, R genericComparableConstraint](fields string, query ...GenericQuery) (R, error) {
- var model T
- var sum R
- tmp := DifyPluginDB
- for _, q := range query {
- tmp = q(tmp)
- }
- err := tmp.Model(&model).Select(fmt.Sprintf("SUM(%s)", fields)).Scan(&sum).Error
- return sum, err
- }
- func DelAssociation[T any](field string, query ...GenericQuery) error {
- var model T
- tmp := DifyPluginDB.Model(&model)
- for _, q := range query {
- tmp = q(tmp)
- }
- return tmp.Association(field).Unscoped().Clear()
- }
- func WithTransaction(fn func(tx *gorm.DB) error, ctx ...*gorm.DB) error {
- // Start a transaction
- db := DifyPluginDB
- if len(ctx) > 0 {
- db = ctx[0]
- }
- tx := db.Begin()
- if tx.Error != nil {
- return tx.Error
- }
- err := fn(tx)
- if err != nil {
- if err := tx.Rollback().Error; err != nil {
- log.Error("failed to rollback tx: %v", err)
- }
- return err
- }
- tx.Commit()
- return nil
- }
- // NOTE: not used in production, only for testing
- func DropTable(model any) error {
- return DifyPluginDB.Migrator().DropTable(model)
- }
- // NOTE: not used in production, only for testing
- func CreateDatabase(dbname string) error {
- return DifyPluginDB.Exec(fmt.Sprintf("CREATE DATABASE %s", dbname)).Error
- }
- // NOTE: not used in production, only for testing
- func CreateTable(model any) error {
- return DifyPluginDB.Migrator().CreateTable(model)
- }
|