nunu-layout-admin/internal/repository/repository.go

161 lines
3.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package repository
import (
"context"
"fmt"
"github.com/casbin/casbin/v2"
"github.com/casbin/casbin/v2/model"
gormadapter "github.com/casbin/gorm-adapter/v3"
"github.com/glebarez/sqlite"
"github.com/redis/go-redis/v9"
"github.com/spf13/viper"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"nunu-layout-admin/pkg/log"
"nunu-layout-admin/pkg/zapgorm2"
"time"
)
// txCtxKey is a private type for the context keys defined by this package.
// Using a dedicated type instead of a plain string guarantees the key cannot
// collide with a value another package stored under the string "TxKey".
type txCtxKey string

// ctxTxKey is the context key under which the active *gorm.DB transaction
// handle is stored (see Transaction) and looked up (see DB).
const ctxTxKey txCtxKey = "TxKey"
// Repository bundles the shared dependencies handed to NewRepository: the
// GORM database handle, the Casbin enforcer and the logger. It also
// implements the Transaction interface.
type Repository struct {
// db is the base GORM handle; DB and Transaction derive their
// context-bound handles from it.
db *gorm.DB
// e is the synchronized Casbin enforcer passed to NewRepository.
e *casbin.SyncedEnforcer
// The Redis client field is currently disabled.
//rdb *redis.Client
// logger is the logger passed to NewRepository.
logger *log.Logger
}
// NewRepository assembles a Repository from its dependencies.
//
// The Redis client parameter is currently disabled; re-enable it here and in
// the Repository struct together.
func NewRepository(
	logger *log.Logger,
	db *gorm.DB,
	e *casbin.SyncedEnforcer,
	// rdb *redis.Client,
) *Repository {
	repo := new(Repository)
	repo.logger = logger
	repo.db = db
	repo.e = e
	//repo.rdb = rdb
	return repo
}
// Transaction abstracts running a unit of work inside a database transaction.
type Transaction interface {
// Transaction invokes fn with a context derived from ctx and returns an
// error. In the Repository implementation the context handed to fn carries
// the open transaction, so DB called with that context returns it.
Transaction(ctx context.Context, fn func(ctx context.Context) error) error
}
// NewTransaction exposes r through the narrower Transaction interface.
func NewTransaction(r *Repository) Transaction {
	var tx Transaction = r
	return tx
}
// DB returns the GORM handle that should be used for ctx.
//
// If ctx carries a transaction handle (as placed there by Transaction), that
// handle is returned so the caller's statements run inside the transaction.
// Otherwise the base connection bound to ctx is returned. Code that wants to
// take part in Transaction(ctx, fn) must therefore obtain its handle through
// DB(ctx) using the context passed to fn.
func (r *Repository) DB(ctx context.Context) *gorm.DB {
	// A failed assertion (missing key or unexpected type) falls through to
	// the non-transactional handle.
	if tx, inTx := ctx.Value(ctxTxKey).(*gorm.DB); inTx {
		return tx
	}
	return r.db.WithContext(ctx)
}
// Transaction runs fn inside a GORM transaction started on the base
// connection. The context given to fn carries the transaction handle under
// ctxTxKey, so DB called with that context returns the transaction. The error
// returned by fn is handed back to GORM's Transaction helper, whose result is
// returned to the caller.
func (r *Repository) Transaction(ctx context.Context, fn func(ctx context.Context) error) error {
	run := func(tx *gorm.DB) error {
		txCtx := context.WithValue(ctx, ctxTxKey, tx)
		return fn(txCtx)
	}
	return r.db.WithContext(ctx).Transaction(run)
}
// NewDB opens the GORM connection described by the "data.db.user" section of
// conf and configures its connection pool.
//
// "data.db.user.driver" selects the backend ("mysql", "postgres" or "sqlite")
// and "data.db.user.dsn" is passed to that driver. SQL logging for every
// backend goes through the zap-backed GORM logger built from l.
//
// NewDB panics if the driver is unknown, if the connection cannot be opened,
// or if the underlying *sql.DB cannot be obtained.
func NewDB(conf *viper.Viper, l *log.Logger) *gorm.DB {
	var (
		db  *gorm.DB
		err error
	)

	driver := conf.GetString("data.db.user.driver")
	dsn := conf.GetString("data.db.user.dsn")

	// One shared config so that all drivers, not only mysql, log through the
	// application's logger.
	gormConf := &gorm.Config{
		Logger: zapgorm2.New(l.Logger),
	}

	// GORM doc: https://gorm.io/docs/connecting_to_the_database.html
	switch driver {
	case "mysql":
		db, err = gorm.Open(mysql.Open(dsn), gormConf)
	case "postgres":
		db, err = gorm.Open(postgres.New(postgres.Config{
			DSN:                  dsn,
			PreferSimpleProtocol: true, // disables implicit prepared statement usage
		}), gormConf)
	case "sqlite":
		db, err = gorm.Open(sqlite.Open(dsn), gormConf)
	default:
		panic("unknown db driver")
	}
	if err != nil {
		panic(err)
	}

	// NOTE(review): Debug() is applied unconditionally, so statement logging
	// is always on; consider gating this on configuration.
	db = db.Debug()

	// Connection Pool config
	sqlDB, err := db.DB()
	if err != nil {
		panic(err)
	}
	sqlDB.SetMaxIdleConns(10)
	sqlDB.SetMaxOpenConns(100)
	sqlDB.SetConnMaxLifetime(time.Hour)

	return db
}
// NewCasbinEnforcer builds a synchronized Casbin enforcer whose policies are
// stored through a GORM adapter on db, using an inline RBAC model
// (subject, object, action with role inheritance via g).
//
// The enforcer reloads its policy every 10 seconds and has auto-save enabled.
// conf and l are accepted but not used by this function.
//
// NewCasbinEnforcer panics if the adapter, the model or the enforcer cannot
// be created.
func NewCasbinEnforcer(conf *viper.Viper, l *log.Logger, db *gorm.DB) *casbin.SyncedEnforcer {
	a, err := gormadapter.NewAdapterByDB(db)
	if err != nil {
		panic(err)
	}

	m, err := model.NewModelFromString(`
[request_definition]
r = sub, obj, act
[policy_definition]
p = sub, obj, act
[role_definition]
g = _, _
[policy_effect]
e = some(where (p.eft == allow))
[matchers]
m = g(r.sub, p.sub) && r.obj == p.obj && r.act == p.act
`)
	if err != nil {
		panic(err)
	}

	// Without this check a construction failure would leave e nil and the
	// calls below would crash with a nil pointer dereference instead of the
	// real cause.
	e, err := casbin.NewSyncedEnforcer(m, a)
	if err != nil {
		panic(err)
	}

	// Reload the policy every 10 seconds so that policies do not drift apart
	// when several service processes are running.
	e.StartAutoLoadPolicy(10 * time.Second)

	// Enable Logger, decide whether to show it in terminal
	//e.EnableLog(true)

	// Save the policy back to DB.
	e.EnableAutoSave(true)

	return e
}
// NewRedis creates a Redis client from the "data.redis" section of conf
// (addr, password, db) and verifies connectivity with a PING bounded by a
// five-second timeout. It panics if the ping fails.
func NewRedis(conf *viper.Viper) *redis.Client {
	opts := &redis.Options{
		Addr:     conf.GetString("data.redis.addr"),
		Password: conf.GetString("data.redis.password"),
		DB:       conf.GetInt("data.redis.db"),
	}
	rdb := redis.NewClient(opts)

	pingCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	if _, err := rdb.Ping(pingCtx).Result(); err != nil {
		panic(fmt.Sprintf("redis error: %s", err.Error()))
	}
	return rdb
}