161 lines
3.5 KiB
Go
161 lines
3.5 KiB
Go
|
|
package repository
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"context"
|
|||
|
|
"fmt"
|
|||
|
|
"github.com/casbin/casbin/v2"
|
|||
|
|
"github.com/casbin/casbin/v2/model"
|
|||
|
|
gormadapter "github.com/casbin/gorm-adapter/v3"
|
|||
|
|
"github.com/glebarez/sqlite"
|
|||
|
|
"github.com/redis/go-redis/v9"
|
|||
|
|
"github.com/spf13/viper"
|
|||
|
|
"gorm.io/driver/mysql"
|
|||
|
|
"gorm.io/driver/postgres"
|
|||
|
|
"gorm.io/gorm"
|
|||
|
|
"nunu-layout-admin/pkg/log"
|
|||
|
|
"nunu-layout-admin/pkg/zapgorm2"
|
|||
|
|
"time"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// txCtxKey is an unexported key type for context values owned by this
// package. Using a dedicated type (instead of a bare string) guarantees that
// no other package can read or overwrite the value by passing the same string.
type txCtxKey string

// ctxTxKey is the context key under which Repository.Transaction stores the
// active *gorm.DB transaction and from which Repository.DB retrieves it.
const ctxTxKey txCtxKey = "TxKey"
|
|||
|
|
|
|||
|
|
// Repository bundles the shared data-access dependencies: the GORM handle,
// the Casbin enforcer and the logger. Presumably embedded or wrapped by the
// concrete repositories of this package — confirm against callers.
type Repository struct {
	// db is the root GORM connection. Query through DB(ctx) so that an
	// active transaction stored in the context is picked up.
	db *gorm.DB
	// e is the Casbin enforcer built by NewCasbinEnforcer.
	e *casbin.SyncedEnforcer
	// Redis support is currently disabled; NewRedis still builds a client.
	//rdb *redis.Client
	// logger is the application logger.
	logger *log.Logger
}
|
|||
|
|
|
|||
|
|
func NewRepository(
|
|||
|
|
logger *log.Logger,
|
|||
|
|
db *gorm.DB,
|
|||
|
|
e *casbin.SyncedEnforcer,
|
|||
|
|
// rdb *redis.Client,
|
|||
|
|
) *Repository {
|
|||
|
|
return &Repository{
|
|||
|
|
db: db,
|
|||
|
|
e: e,
|
|||
|
|
//rdb: rdb,
|
|||
|
|
logger: logger,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Transaction runs a unit of work atomically.
//
// The implementation calls fn with a context that carries the transaction.
// Repository's implementation delegates to gorm's DB.Transaction, which
// commits when fn returns nil and rolls back when it returns an error.
type Transaction interface {
	Transaction(ctx context.Context, fn func(txCtx context.Context) error) error
}
|
|||
|
|
|
|||
|
|
func NewTransaction(r *Repository) Transaction {
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// DB return tx
|
|||
|
|
// If you need to create a Transaction, you must call DB(ctx) and Transaction(ctx,fn)
|
|||
|
|
func (r *Repository) DB(ctx context.Context) *gorm.DB {
|
|||
|
|
v := ctx.Value(ctxTxKey)
|
|||
|
|
if v != nil {
|
|||
|
|
if tx, ok := v.(*gorm.DB); ok {
|
|||
|
|
return tx
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
return r.db.WithContext(ctx)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (r *Repository) Transaction(ctx context.Context, fn func(ctx context.Context) error) error {
|
|||
|
|
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|||
|
|
ctx = context.WithValue(ctx, ctxTxKey, tx)
|
|||
|
|
return fn(ctx)
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func NewDB(conf *viper.Viper, l *log.Logger) *gorm.DB {
|
|||
|
|
var (
|
|||
|
|
db *gorm.DB
|
|||
|
|
err error
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
logger := zapgorm2.New(l.Logger)
|
|||
|
|
driver := conf.GetString("data.db.user.driver")
|
|||
|
|
dsn := conf.GetString("data.db.user.dsn")
|
|||
|
|
|
|||
|
|
// GORM doc: https://gorm.io/docs/connecting_to_the_database.html
|
|||
|
|
switch driver {
|
|||
|
|
case "mysql":
|
|||
|
|
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{
|
|||
|
|
Logger: logger,
|
|||
|
|
})
|
|||
|
|
case "postgres":
|
|||
|
|
db, err = gorm.Open(postgres.New(postgres.Config{
|
|||
|
|
DSN: dsn,
|
|||
|
|
PreferSimpleProtocol: true, // disables implicit prepared statement usage
|
|||
|
|
}), &gorm.Config{})
|
|||
|
|
case "sqlite":
|
|||
|
|
db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
|||
|
|
default:
|
|||
|
|
panic("unknown db driver")
|
|||
|
|
}
|
|||
|
|
if err != nil {
|
|||
|
|
panic(err)
|
|||
|
|
}
|
|||
|
|
db = db.Debug()
|
|||
|
|
|
|||
|
|
// Connection Pool config
|
|||
|
|
sqlDB, err := db.DB()
|
|||
|
|
if err != nil {
|
|||
|
|
panic(err)
|
|||
|
|
}
|
|||
|
|
sqlDB.SetMaxIdleConns(10)
|
|||
|
|
sqlDB.SetMaxOpenConns(100)
|
|||
|
|
sqlDB.SetConnMaxLifetime(time.Hour)
|
|||
|
|
return db
|
|||
|
|
}
|
|||
|
|
func NewCasbinEnforcer(conf *viper.Viper, l *log.Logger, db *gorm.DB) *casbin.SyncedEnforcer {
|
|||
|
|
a, _ := gormadapter.NewAdapterByDB(db)
|
|||
|
|
m, err := model.NewModelFromString(`
|
|||
|
|
[request_definition]
|
|||
|
|
r = sub, obj, act
|
|||
|
|
|
|||
|
|
[policy_definition]
|
|||
|
|
p = sub, obj, act
|
|||
|
|
|
|||
|
|
[role_definition]
|
|||
|
|
g = _, _
|
|||
|
|
|
|||
|
|
[policy_effect]
|
|||
|
|
e = some(where (p.eft == allow))
|
|||
|
|
|
|||
|
|
[matchers]
|
|||
|
|
m = g(r.sub, p.sub) && r.obj == p.obj && r.act == p.act
|
|||
|
|
`)
|
|||
|
|
|
|||
|
|
if err != nil {
|
|||
|
|
panic(err)
|
|||
|
|
}
|
|||
|
|
e, _ := casbin.NewSyncedEnforcer(m, a)
|
|||
|
|
e.StartAutoLoadPolicy(10 * time.Second) // 每10秒自动加载策略,防止启动多服务进程策略不一致
|
|||
|
|
|
|||
|
|
// Enable Logger, decide whether to show it in terminal
|
|||
|
|
//e.EnableLog(true)
|
|||
|
|
|
|||
|
|
// Save the policy back to DB.
|
|||
|
|
e.EnableAutoSave(true)
|
|||
|
|
|
|||
|
|
return e
|
|||
|
|
}
|
|||
|
|
func NewRedis(conf *viper.Viper) *redis.Client {
|
|||
|
|
rdb := redis.NewClient(&redis.Options{
|
|||
|
|
Addr: conf.GetString("data.redis.addr"),
|
|||
|
|
Password: conf.GetString("data.redis.password"),
|
|||
|
|
DB: conf.GetInt("data.redis.db"),
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|||
|
|
defer cancel()
|
|||
|
|
|
|||
|
|
_, err := rdb.Ping(ctx).Result()
|
|||
|
|
if err != nil {
|
|||
|
|
panic(fmt.Sprintf("redis error: %s", err.Error()))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return rdb
|
|||
|
|
}
|