161 lines
3.5 KiB
Go
161 lines
3.5 KiB
Go
package repository
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"github.com/casbin/casbin/v2"
|
||
"github.com/casbin/casbin/v2/model"
|
||
gormadapter "github.com/casbin/gorm-adapter/v3"
|
||
"github.com/glebarez/sqlite"
|
||
"github.com/redis/go-redis/v9"
|
||
"github.com/spf13/viper"
|
||
"gorm.io/driver/mysql"
|
||
"gorm.io/driver/postgres"
|
||
"gorm.io/gorm"
|
||
"nunu-layout-admin/pkg/log"
|
||
"nunu-layout-admin/pkg/zapgorm2"
|
||
"time"
|
||
)
|
||
|
||
// txContextKey is the unexported type of ctxTxKey. A dedicated key type (rather
// than a bare string) means no code outside this package can read or overwrite
// the value stored under it, per the context.WithValue guidance.
type txContextKey string

// ctxTxKey is the context key under which Transaction stores the active
// *gorm.DB transaction handle, and from which DB reads it back.
const ctxTxKey txContextKey = "TxKey"
|
||
|
||
type Repository struct {
|
||
db *gorm.DB
|
||
e *casbin.SyncedEnforcer
|
||
//rdb *redis.Client
|
||
logger *log.Logger
|
||
}
|
||
|
||
func NewRepository(
|
||
logger *log.Logger,
|
||
db *gorm.DB,
|
||
e *casbin.SyncedEnforcer,
|
||
// rdb *redis.Client,
|
||
) *Repository {
|
||
return &Repository{
|
||
db: db,
|
||
e: e,
|
||
//rdb: rdb,
|
||
logger: logger,
|
||
}
|
||
}
|
||
|
||
// Transaction is implemented by types that can run a callback inside a
// database transaction.
//
// The callback receives a derived context; in this package's Repository
// implementation that context carries the transaction, and queries must be
// issued with it in order to join the transaction.
type Transaction interface {
	// Transaction executes fn transactionally and returns fn's error (or the
	// error from committing/rolling back).
	Transaction(ctx context.Context, fn func(ctx context.Context) error) error
}
|
||
|
||
func NewTransaction(r *Repository) Transaction {
|
||
return r
|
||
}
|
||
|
||
// DB return tx
|
||
// If you need to create a Transaction, you must call DB(ctx) and Transaction(ctx,fn)
|
||
func (r *Repository) DB(ctx context.Context) *gorm.DB {
|
||
v := ctx.Value(ctxTxKey)
|
||
if v != nil {
|
||
if tx, ok := v.(*gorm.DB); ok {
|
||
return tx
|
||
}
|
||
}
|
||
return r.db.WithContext(ctx)
|
||
}
|
||
|
||
func (r *Repository) Transaction(ctx context.Context, fn func(ctx context.Context) error) error {
|
||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||
ctx = context.WithValue(ctx, ctxTxKey, tx)
|
||
return fn(ctx)
|
||
})
|
||
}
|
||
|
||
func NewDB(conf *viper.Viper, l *log.Logger) *gorm.DB {
|
||
var (
|
||
db *gorm.DB
|
||
err error
|
||
)
|
||
|
||
logger := zapgorm2.New(l.Logger)
|
||
driver := conf.GetString("data.db.user.driver")
|
||
dsn := conf.GetString("data.db.user.dsn")
|
||
|
||
// GORM doc: https://gorm.io/docs/connecting_to_the_database.html
|
||
switch driver {
|
||
case "mysql":
|
||
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{
|
||
Logger: logger,
|
||
})
|
||
case "postgres":
|
||
db, err = gorm.Open(postgres.New(postgres.Config{
|
||
DSN: dsn,
|
||
PreferSimpleProtocol: true, // disables implicit prepared statement usage
|
||
}), &gorm.Config{})
|
||
case "sqlite":
|
||
db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
||
default:
|
||
panic("unknown db driver")
|
||
}
|
||
if err != nil {
|
||
panic(err)
|
||
}
|
||
db = db.Debug()
|
||
|
||
// Connection Pool config
|
||
sqlDB, err := db.DB()
|
||
if err != nil {
|
||
panic(err)
|
||
}
|
||
sqlDB.SetMaxIdleConns(10)
|
||
sqlDB.SetMaxOpenConns(100)
|
||
sqlDB.SetConnMaxLifetime(time.Hour)
|
||
return db
|
||
}
|
||
func NewCasbinEnforcer(conf *viper.Viper, l *log.Logger, db *gorm.DB) *casbin.SyncedEnforcer {
|
||
a, _ := gormadapter.NewAdapterByDB(db)
|
||
m, err := model.NewModelFromString(`
|
||
[request_definition]
|
||
r = sub, obj, act
|
||
|
||
[policy_definition]
|
||
p = sub, obj, act
|
||
|
||
[role_definition]
|
||
g = _, _
|
||
|
||
[policy_effect]
|
||
e = some(where (p.eft == allow))
|
||
|
||
[matchers]
|
||
m = g(r.sub, p.sub) && r.obj == p.obj && r.act == p.act
|
||
`)
|
||
|
||
if err != nil {
|
||
panic(err)
|
||
}
|
||
e, _ := casbin.NewSyncedEnforcer(m, a)
|
||
e.StartAutoLoadPolicy(10 * time.Second) // 每10秒自动加载策略,防止启动多服务进程策略不一致
|
||
|
||
// Enable Logger, decide whether to show it in terminal
|
||
//e.EnableLog(true)
|
||
|
||
// Save the policy back to DB.
|
||
e.EnableAutoSave(true)
|
||
|
||
return e
|
||
}
|
||
func NewRedis(conf *viper.Viper) *redis.Client {
|
||
rdb := redis.NewClient(&redis.Options{
|
||
Addr: conf.GetString("data.redis.addr"),
|
||
Password: conf.GetString("data.redis.password"),
|
||
DB: conf.GetInt("data.redis.db"),
|
||
})
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||
defer cancel()
|
||
|
||
_, err := rdb.Ping(ctx).Result()
|
||
if err != nil {
|
||
panic(fmt.Sprintf("redis error: %s", err.Error()))
|
||
}
|
||
|
||
return rdb
|
||
}
|