This commit is contained in:
suguo 2026-03-12 17:28:19 +08:00
commit 0e9533feac
24 changed files with 1256 additions and 0 deletions

6
.gitignore vendored Normal file
View File

@ -0,0 +1,6 @@
go.sum
heritage-api
logs
*.exe
.DS_Store
__debug_bin*

23
Dockerfile Normal file
View File

@ -0,0 +1,23 @@
FROM harbor.ks.easyj.top/zt/alpine:0.1
ENV APP_DIR=/app \
MYSQL_DSN=root:SG1231@tcp(mysql:3306)/heritage?charset=utf8mb4&parseTime=True&loc=Local \
MYSQL_MAXLIFETIME=1 \
MYSQL_MAXIDLECONNS=2 \
MYSQL_MAXOPENCONNS=50 \
MYSQL_INIT=true \
GIN_MODE=release \
REDIS_DSN=redis:6379 \
REDIS_DB=1 \
REDIS_PWD=eYVX7EwVmmxKPCDmwMtyKVge8oLd2t81 \
LOGLEVEL=debug
COPY heritage ${APP_DIR}/heritage
WORKDIR ${APP_DIR}
RUN chmod +x heritage
EXPOSE 8080
CMD ["./heritage"]

116
gin/filter-gin.go Normal file
View File

@ -0,0 +1,116 @@
package gin
import (
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
"myschools.me/heritage/heritage-api/model"
"myschools.me/heritage/heritage-api/redis"
"myschools.me/heritage/heritage-api/service"
)
func auth() gin.HandlerFunc {
return func(c *gin.Context) {
token := c.GetHeader("Authorization")
if token == "" {
token = c.Query("Authorization")
}
token = strings.TrimSpace(token)
token = strings.TrimPrefix(token, "Bearer ")
token = strings.TrimPrefix(token, "bearer ")
staff, err := redis.UserTokenGet(&token)
if err != nil {
logrus.WithFields(logrus.Fields{
"func": "authUser",
}).Errorf("redis.UserTokenGet: %s", err.Error())
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"data": "无效TOKEN, 请重新登录!",
})
return
}
if staff == nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"data": "无效TOKEN, 请重新登录!",
})
return
}
c.Set("user", staff)
c.Next()
}
}
func authorize() gin.HandlerFunc {
return func(c *gin.Context) {
currentUser, ok := c.Get("user")
if !ok {
c.Next()
return
}
u, ok := currentUser.(*model.User)
if !ok || u == nil || u.ID == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"data": "无效TOKEN, 请重新登录!",
})
return
}
fullPath := c.FullPath()
if fullPath == "" {
c.Next()
return
}
permissionCode := c.Request.Method + ":" + fullPath
defined, err := service.PermissionDefined(permissionCode)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
"data": "权限校验失败",
})
return
}
if !defined {
c.Next()
return
}
roleID, found, err := service.UserRoleIDByUserID(u.ID)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
"data": "权限校验失败",
})
return
}
if !found {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"data": "无效TOKEN, 请重新登录!",
})
return
}
if roleID == nil || *roleID == "" {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
"data": "无权限访问",
})
return
}
allowed, err := service.RoleHasPermission(roleID, permissionCode)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
"data": "权限校验失败",
})
return
}
if !allowed {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
"data": "无权限访问",
})
return
}
c.Next()
}
}

57
gin/gin.go Normal file
View File

@ -0,0 +1,57 @@
package gin
import (
"fmt"
"io"
"log"
"net/http"
"os"
"strconv"
"time"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
)
var (
addr string
port int
ssl string
sslPem string
sslKey string
)
func Service() {
addr = os.Getenv("GIN_ADDR")
if addr == "" {
addr = "0.0.0.0"
}
port, _ = strconv.Atoi(os.Getenv("GIN_PORT"))
if port == 0 {
port = 8080
}
go func() {
gin.DefaultWriter = io.Discard
router := gin.New()
routerSetup(router)
s := &http.Server{
Addr: fmt.Sprintf("%s:%d", addr, port),
Handler: router,
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
MaxHeaderBytes: 1 << 20,
}
logrus.WithFields(logrus.Fields{
"func": "Service",
}).Infof("start service on %s:%d", addr, port)
if ssl == "true" {
log.Fatal(s.ListenAndServeTLS(sslPem, sslKey))
} else {
log.Fatal(s.ListenAndServe())
}
}()
}

19
gin/router-gin.go Normal file
View File

@ -0,0 +1,19 @@
package gin
import (
"github.com/gin-gonic/gin"
"myschools.me/heritage/heritage-api/handler"
)
// 路由配置
func routerSetup(router *gin.Engine) {
router.Use(gin.Recovery())
api := router.Group("/api")
api.POST("/login", handler.Login)
protected := router.Group("/api")
protected.Use(auth(), authorize())
protected.POST("/logout", handler.Logout)
protected.GET("/me", handler.Me)
protected.POST("/password", handler.ChangePassword)
}

54
go.mod Normal file
View File

@ -0,0 +1,54 @@
module myschools.me/heritage/heritage-api
go 1.26.1
require (
github.com/gin-gonic/gin v1.12.0
github.com/gomodule/redigo v1.9.3
github.com/google/uuid v1.6.0
github.com/lestrrat-go/file-rotatelogs v2.4.0+incompatible
github.com/rifflock/lfshook v0.0.0-20180920164130-b9218ef580f5
github.com/sirupsen/logrus v1.9.4
gorm.io/driver/mysql v1.6.0
gorm.io/gorm v1.31.1
gorm.io/plugin/dbresolver v1.6.2
)
require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/bytedance/gopkg v0.1.3 // indirect
github.com/bytedance/sonic v1.15.0 // indirect
github.com/bytedance/sonic/loader v0.5.0 // indirect
github.com/cloudwego/base64x v0.1.6 // indirect
github.com/gabriel-vasile/mimetype v1.4.12 // indirect
github.com/gin-contrib/sse v1.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.30.1 // indirect
github.com/go-sql-driver/mysql v1.8.1 // indirect
github.com/goccy/go-json v0.10.5 // indirect
github.com/goccy/go-yaml v1.19.2 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/jonboulle/clockwork v0.5.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/lestrrat-go/strftime v1.1.1 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/quic-go/qpack v0.6.0 // indirect
github.com/quic-go/quic-go v0.59.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.3.1 // indirect
go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect
golang.org/x/arch v0.22.0 // indirect
golang.org/x/crypto v0.48.0 // indirect
golang.org/x/net v0.51.0 // indirect
golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.34.0 // indirect
google.golang.org/protobuf v1.36.10 // indirect
)

147
handler/auth-handler.go Normal file
View File

@ -0,0 +1,147 @@
package handler
import (
"net/http"
"os"
"strings"
"github.com/gin-gonic/gin"
"myschools.me/heritage/heritage-api/model"
"myschools.me/heritage/heritage-api/service"
)
type loginRequest struct {
UserName string `json:"userName"`
Username string `json:"username"`
Password string `json:"password"`
}
type loginResponse struct {
Token string `json:"token"`
User *model.User `json:"user"`
}
type changePasswordRequest struct {
OldPassword string `json:"oldPassword"`
NewPassword string `json:"newPassword"`
}
func Login(c *gin.Context) {
var req loginRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
"data": "参数错误",
})
return
}
if req.UserName == "" {
req.UserName = req.Username
}
req.UserName = strings.TrimSpace(req.UserName)
if req.UserName == "" || req.Password == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
"data": "用户名或密码不能为空",
})
return
}
token, safeUser, err := service.Login(req.UserName, req.Password)
if err != nil {
if err == service.ErrInvalidCredentials {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"data": "用户名或密码错误",
})
return
}
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
"data": "登录失败",
})
return
}
c.JSON(http.StatusOK, loginResponse{
Token: token,
User: safeUser,
})
}
func Logout(c *gin.Context) {
token := c.GetHeader("Authorization")
if token == "" {
token = c.Query("Authorization")
}
_ = service.Logout(token)
c.JSON(http.StatusOK, gin.H{
"data": "ok",
})
}
func Me(c *gin.Context) {
usr := currentUser(c)
if usr == nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"data": "无效TOKEN, 请重新登录!",
})
return
}
c.JSON(http.StatusOK, gin.H{
"user": usr,
})
}
func ChangePassword(c *gin.Context) {
usr := currentUser(c)
if usr == nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"data": "无效TOKEN, 请重新登录!",
})
return
}
var req changePasswordRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
"data": "参数错误",
})
return
}
req.OldPassword = strings.TrimSpace(req.OldPassword)
req.NewPassword = strings.TrimSpace(req.NewPassword)
err := service.ChangePassword(usr.ID, req.OldPassword, req.NewPassword)
if err != nil {
switch err {
case service.ErrInvalidArgument:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
"data": "密码不能为空",
})
case service.ErrNewPasswordShort:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
"data": "新密码至少6位",
})
case service.ErrOldPasswordWrong:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
"data": "旧密码错误",
})
case service.ErrUserNotFound:
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"data": "无效用户, 请重新登录!",
})
default:
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
"data": changePasswordInternalError(err),
})
}
return
}
c.JSON(http.StatusOK, gin.H{
"data": "ok",
})
}
func changePasswordInternalError(err error) string {
if os.Getenv("DEBUG") == "true" && err != nil {
return "修改失败: " + err.Error()
}
return "修改失败"
}

19
handler/base-handler.go Normal file
View File

@ -0,0 +1,19 @@
package handler
import (
"github.com/gin-gonic/gin"
"myschools.me/heritage/heritage-api/model"
)
// 获取当前用户
func currentUser(c *gin.Context) *model.User {
usr, ok := c.Get("user")
if !ok {
return nil
}
u, ok := usr.(*model.User)
if !ok || u == nil || u.ID == "" {
return nil
}
return u
}

1
handler/user-handler.go Normal file
View File

@ -0,0 +1 @@
package handler

77
logrus.go Normal file
View File

@ -0,0 +1,77 @@
package main
import (
"fmt"
"os"
"time"
rotatelogs "github.com/lestrrat-go/file-rotatelogs"
"github.com/rifflock/lfshook"
"github.com/sirupsen/logrus"
)
/*
注意当天文件放项目根目录下
env LOGLEVEL=debug
*/
func init() {
//日志初始化
level := os.Getenv("LOGLEVEL")
switch level {
case "debug":
logrus.SetLevel(logrus.DebugLevel)
case "info":
logrus.SetLevel(logrus.InfoLevel)
case "warn":
logrus.SetLevel(logrus.WarnLevel)
case "error":
logrus.SetLevel(logrus.ErrorLevel)
case "fatal":
logrus.SetLevel(logrus.FatalLevel)
default:
logrus.SetLevel(logrus.PanicLevel)
}
logrus.AddHook(newLfsHook(72))
}
func newLfsHook(maxRemainCnt uint) logrus.Hook {
//检查与创建日志文件夹
_, err := os.Stat("logs")
if os.IsNotExist(err) {
os.Mkdir("logs", 0755)
}
logName := fmt.Sprintf(`logs/%s`, "heritage")
writer, err := rotatelogs.New(
logName+"%Y%m%d.log",
// WithLinkName为最新的日志建立软连接以方便随着找到当前日志文件
rotatelogs.WithLinkName(logName),
// WithRotationTime设置日志分割的时间这里设置为一小时分割一次
rotatelogs.WithRotationTime(24*time.Hour),
// WithMaxAge和WithRotationCount二者只能设置一个
// WithMaxAge设置文件清理前的最长保存时间
// WithRotationCount设置文件清理前最多保存的个数。
//rotatelogs.WithMaxAge(time.Hour*24),
rotatelogs.WithRotationCount(maxRemainCnt),
)
if err != nil {
panic("config local file system for logger error: " + err.Error())
}
lfsHook := lfshook.NewHook(lfshook.WriterMap{
logrus.DebugLevel: writer,
logrus.InfoLevel: writer,
logrus.WarnLevel: writer,
logrus.ErrorLevel: writer,
logrus.FatalLevel: writer,
logrus.PanicLevel: writer,
}, &logrus.TextFormatter{DisableColors: true})
return lfsHook
}

19
main.go Normal file
View File

@ -0,0 +1,19 @@
package main
import (
"os"
"os/signal"
"github.com/sirupsen/logrus"
"myschools.me/heritage/heritage-api/gin"
)
func main() {
gin.Service()
// 等待服务关闭信号,使用通道
quit := make(chan os.Signal, 1)
signal.Notify(quit, os.Interrupt)
<-quit
logrus.Info("Shutdown Server ...")
}

View File

@ -0,0 +1,8 @@
package model
type Permission struct {
ID string `gorm:"type:varchar(32);primaryKey"`
RoleID string `gorm:"type:varchar(32);not null;index"`
Code string `gorm:"type:varchar(128);not null;index"`
Name string `gorm:"type:varchar(128);not null"`
}

7
model/role-model.go Normal file
View File

@ -0,0 +1,7 @@
package model
type Role struct {
ID string `gorm:"type:varchar(32);primaryKey"`
Code string `gorm:"type:varchar(64);not null;uniqueIndex"`
Name string `gorm:"type:varchar(64);not null"`
}

9
model/user-model.go Normal file
View File

@ -0,0 +1,9 @@
package model
type User struct {
ID string `json:"id" gorm:"type:varchar(32);primaryKey"`
UserName string `json:"userName" gorm:"type:varchar(20);not null;uniqueIndex"`
PasswordHash string `json:"-" gorm:"type:varchar(128);default:''"`
RoleID string `json:"roleId" gorm:"type:varchar(32);index"`
Role Role `json:"role" gorm:"foreignKey:RoleID;references:ID"`
}

105
mysql/mysql.go Normal file
View File

@ -0,0 +1,105 @@
package mysql
import (
"os"
"strconv"
"time"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
"gorm.io/plugin/dbresolver"
)
/*
docker配置
env MYSQL_DSN=root:root@tcp(mysql:3306)/sample?charset=utf8mb4&parseTime=True&loc=Local
env MYSQL_MAXLIFETIME=2
env MYSQL_MAXIDLECONNS=2
env MYSQL_MAXOPENCONNS=200
env MYSQL_INIT=true
*/
var (
_db *gorm.DB
)
// 创建实例
func newDB() (*gorm.DB, error) {
if _db != nil {
return _db, nil
}
dsn := os.Getenv("MYSQL_DSN")
if dsn == "" {
dsn = "root:root@tcp(127.0.0.1:3306)/mysql?charset=utf8&parseTime=True&loc=Local"
}
maxLifetime := func() int {
c := os.Getenv("MYSQL_MAXLIFETIME")
cc, err := strconv.Atoi(c)
if err != nil {
return 1
}
if cc <= 0 {
return 1
}
if cc >= 1000 {
cc = 1000
}
return cc
}()
maxIdleConns := func() int {
c := os.Getenv("MYSQL_MAXIDLECONNS")
cc, err := strconv.Atoi(c)
if err != nil {
return 1
}
if cc < 0 {
return 0
}
if cc >= 1000 {
cc = 1000
}
return cc
}()
maxOpenConns := func() int {
c := os.Getenv("MYSQL_MAXOPENCONNS")
cc, err := strconv.Atoi(c)
if err != nil {
return 1
}
if cc < 0 {
return 0
}
if cc >= 1000 {
cc = 1000
}
return cc
}()
var err error
_db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{
SkipDefaultTransaction: true,
Logger: logger.Default.LogMode(logger.Silent),
NamingStrategy: schema.NamingStrategy{
SingularTable: true,
},
})
if err != nil {
return nil, err
}
_db.Use(
dbresolver.Register(dbresolver.Config{
Sources: []gorm.Dialector{mysql.Open(dsn)},
Replicas: []gorm.Dialector{mysql.Open(dsn)},
Policy: dbresolver.RandomPolicy{},
}).SetConnMaxIdleTime(time.Hour).
SetConnMaxLifetime(time.Duration(maxLifetime) * time.Hour).
SetMaxIdleConns(maxIdleConns).
SetMaxOpenConns(maxOpenConns))
return _db, nil
}

39
mysql/permission-mysql.go Normal file
View File

@ -0,0 +1,39 @@
package mysql
import "myschools.me/heritage/heritage-api/model"
func permissionDefined(permissionCode *string) (bool, error) {
db, err := newDB()
if err != nil {
return false, err
}
var count int64
if err := db.Model(&model.Permission{}).Where("code = ?", *permissionCode).Count(&count).Error; err != nil {
return false, err
}
return count > 0, nil
}
func roleHasPermission(roleID, permissionCode *string) (bool, error) {
db, err := newDB()
if err != nil {
return false, err
}
var count int64
if err := db.Model(&model.Permission{}).
Where("role_id = ? AND (code = ? OR code = ?)", *roleID, *permissionCode, "*").
Count(&count).Error; err != nil {
return false, err
}
return count > 0, nil
}
func PermissionDefined(permissionCode *string) (bool, error) {
return permissionDefined(permissionCode)
}
func RoleHasPermission(roleID, permissionCode *string) (bool, error) {
return roleHasPermission(roleID, permissionCode)
}

89
mysql/tables-mysql.go Normal file
View File

@ -0,0 +1,89 @@
package mysql
import (
"os"
"strings"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
"myschools.me/heritage/heritage-api/model"
)
func Bootstrap() {
if os.Getenv("MYSQL_INIT") != "true" {
return
}
db, err := newDB()
if err != nil {
panic(err)
}
if err := db.AutoMigrate(&model.Role{}, &model.Permission{}, &model.User{}); err != nil {
panic(err)
}
var roleCount int64
if err := db.Model(&model.Role{}).Count(&roleCount).Error; err != nil {
panic(err)
}
var permissionCount int64
if err := db.Model(&model.Permission{}).Count(&permissionCount).Error; err != nil {
panic(err)
}
var userCount int64
if err := db.Model(&model.User{}).Count(&userCount).Error; err != nil {
panic(err)
}
var defaultRole model.Role
if roleCount == 0 {
defaultRole = model.Role{
ID: newID(),
Code: "admin",
Name: "管理员",
}
if err := db.Create(&defaultRole).Error; err != nil {
panic(err)
}
} else {
if err := db.Where("code = ?", "admin").First(&defaultRole).Error; err != nil {
if err := db.First(&defaultRole).Error; err != nil {
panic(err)
}
}
}
if permissionCount == 0 {
p := model.Permission{
ID: newID(),
RoleID: defaultRole.ID,
Code: "*",
Name: "全部权限",
}
if err := db.Create(&p).Error; err != nil {
panic(err)
}
}
if userCount == 0 {
defaultPwd := "admin"
h, err := bcrypt.GenerateFromPassword([]byte(defaultPwd), bcrypt.DefaultCost)
if err != nil {
panic(err)
}
u := model.User{
ID: newID(),
UserName: "admin",
PasswordHash: string(h),
RoleID: defaultRole.ID,
}
if err := db.Create(&u).Error; err != nil {
panic(err)
}
}
}
func newID() string {
id := uuid.Must(uuid.NewV7()).String()
return strings.ReplaceAll(id, "-", "")
}

84
mysql/user-mysql.go Normal file
View File

@ -0,0 +1,84 @@
package mysql
import (
"errors"
"gorm.io/gorm"
"myschools.me/heritage/heritage-api/model"
)
func userByUserName(userName *string) (*model.User, bool, error) {
db, err := newDB()
if err != nil {
return nil, false, err
}
var u model.User
if err := db.Where("user_name = ?", *userName).First(&u).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, false, nil
}
return nil, false, err
}
return &u, true, nil
}
func userByID(userID *string) (*model.User, bool, error) {
db, err := newDB()
if err != nil {
return nil, false, err
}
var u model.User
if err := db.Where("id = ?", *userID).First(&u).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, false, nil
}
return nil, false, err
}
return &u, true, nil
}
func updateUserPasswordHash(userID, passwordHash *string) (bool, error) {
db, err := newDB()
if err != nil {
return false, err
}
tx := db.Model(&model.User{}).Where("id = ?", *userID).Update("password_hash", *passwordHash)
if tx.Error != nil {
return false, tx.Error
}
return tx.RowsAffected > 0, nil
}
func userRoleIDByUserID(userID *string) (*string, bool, error) {
db, err := newDB()
if err != nil {
return nil, false, err
}
var u model.User
if err := db.Select("id", "role_id").Where("id = ?", *userID).First(&u).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, false, nil
}
return nil, false, err
}
return &u.RoleID, true, nil
}
func UserRoleIDByUserID(userID *string) (*string, bool, error) {
return userRoleIDByUserID(userID)
}
func UserByUserName(userName *string) (*model.User, bool, error) {
return userByUserName(userName)
}
func UserByID(userID *string) (*model.User, bool, error) {
return userByID(userID)
}
func UpdateUserPasswordHash(userID, passwordHash *string) (bool, error) {
return updateUserPasswordHash(userID, passwordHash)
}

169
redis/redis.go Normal file
View File

@ -0,0 +1,169 @@
package redis
import (
"encoding/json"
"os"
"strconv"
"time"
redigo "github.com/gomodule/redigo/redis"
)
/*
env REDIS_DSN=127.0.0.1:6379
env REDIS_PWD=
env REDIS_DB=0
*/
var pool *redigo.Pool
func init() {
dbNumber := func() int {
db := os.Getenv("REDIS_DB")
database, err := strconv.Atoi(db)
if err != nil {
return 0
}
if database < 0 {
database = 0
}
if database > 16 {
database = 0
}
return database
}()
pool = &redigo.Pool{
MaxActive: 100,
MaxIdle: 1,
IdleTimeout: time.Second * time.Duration(60),
Dial: func() (redigo.Conn, error) {
return redigo.Dial("tcp", os.Getenv("REDIS_DSN"),
redigo.DialDatabase(dbNumber),
redigo.DialPassword(os.Getenv("REDIS_PWD")),
)
},
TestOnBorrow: func(conn redigo.Conn, t time.Time) error {
if time.Since(t) < time.Minute {
return nil
}
_, err := conn.Do("PING")
return err
},
}
}
// GetBytes 获取一个字节数组值
func getBytes(key *string) (*[]byte, error) {
conn := pool.Get()
defer conn.Close()
data, err := redigo.Bytes(conn.Do("GET", *key))
return &data, err
}
// Get 获取一个值
func get(key string) interface{} {
conn := pool.Get()
defer conn.Close()
var data []byte
var err error
if data, err = redigo.Bytes(conn.Do("GET", key)); err != nil {
return nil
}
var reply interface{}
if err = json.Unmarshal(data, &reply); err != nil {
return nil
}
return reply
}
// 集合Set增加元素
func setAdd(key string, data ...interface{}) error {
conn := pool.Get()
defer conn.Close()
var err error
for _, d := range data {
_, e := conn.Do("SADD", key, d)
if e != nil {
err = e
break
}
}
return err
}
// 集合Set删除元素
func setRem(key string, data ...interface{}) error {
conn := pool.Get()
defer conn.Close()
var err error
for _, d := range data {
_, e := conn.Do("SREM", key, d)
if e != nil {
err = e
break
}
}
return err
}
// 集合Set判断是否存在成员member结果.(int64)==1表示存在
func setIsMember(key string, member interface{}) (interface{}, error) {
conn := pool.Get()
defer conn.Close()
return conn.Do("SISMEMBER", key, member)
}
// 设置一个值
func set(key *string, val interface{}, timeout time.Duration) error {
data, err := json.Marshal(val)
if err != nil {
return err
}
return setBytes(key, &data, timeout)
}
func setBytes(key *string, data *[]byte, timeout time.Duration) error {
conn := pool.Get()
defer conn.Close()
_, err := conn.Do("SETEX", *key, int64(timeout/time.Second), *data)
return err
}
// IsExist 判断key是否存在
func isExist(key string) bool {
conn := pool.Get()
defer conn.Close()
a, _ := conn.Do("EXISTS", key)
i := a.(int64)
return i > 0
}
// Delete 删除
func delete(key string) error {
conn := pool.Get()
defer conn.Close()
if _, err := conn.Do("DEL", key); err != nil {
return err
}
return nil
}
// Expire 失效时间配置
func expire(key string, t int64) error {
conn := pool.Get()
defer conn.Close()
if _, err := conn.Do("expire", key, t); err != nil {
return err
}
return nil
}

46
redis/user-redis.go Normal file
View File

@ -0,0 +1,46 @@
package redis
import (
"encoding/json"
"errors"
"time"
"myschools.me/heritage/heritage-api/model"
"github.com/sirupsen/logrus"
)
// 存储用户的Token
func UserTokenSet(key *string, usr *model.User) error {
err := set(key, usr, 7210*time.Second)
if err != nil {
logrus.WithFields(logrus.Fields{
"func": "UserTokenSet",
}).Warnf("Set: %s", err.Error())
return err
}
return nil
}
// 从redis中获取用户信息最佳实践经验建议把此代码放service层
func UserTokenGet(token *string) (*model.User, error) {
b, err := getBytes(token)
if err != nil {
return nil, err
}
if b == nil {
return nil, errors.New("无效token,请重新登录!")
}
user := new(model.User)
if err := json.Unmarshal(*b, user); err != nil {
return nil, err
}
return user, nil
}
func UserTokenDel(token *string) error {
return delete(*token)
}

115
service/auth-service.go Normal file
View File

@ -0,0 +1,115 @@
package service
import (
"errors"
"strings"
"github.com/sirupsen/logrus"
"myschools.me/heritage/heritage-api/model"
"myschools.me/heritage/heritage-api/mysql"
"myschools.me/heritage/heritage-api/redis"
)
var (
ErrInvalidCredentials = errors.New("invalid credentials")
ErrInvalidArgument = errors.New("invalid argument")
ErrOldPasswordWrong = errors.New("old password wrong")
ErrNewPasswordShort = errors.New("new password too short")
ErrUserNotFound = errors.New("user not found")
)
func Login(userName, plainPassword string) (string, *model.User, error) {
userName = strings.TrimSpace(userName)
if userName == "" || plainPassword == "" {
return "", nil, ErrInvalidCredentials
}
u, found, err := mysql.UserByUserName(&userName)
if err != nil {
logrus.WithFields(logrus.Fields{
"func": "service.Login",
"userName": userName,
}).Errorf("mysql.UserByUserName: %s", err.Error())
return "", nil, err
}
if !found || u == nil || u.PasswordHash == "" {
return "", nil, ErrInvalidCredentials
}
if !VerifyPassword(u.PasswordHash, plainPassword) {
return "", nil, ErrInvalidCredentials
}
token := newToken()
safeUser := &model.User{
ID: u.ID,
UserName: u.UserName,
RoleID: u.RoleID,
}
if err := redis.UserTokenSet(&token, safeUser); err != nil {
logrus.WithFields(logrus.Fields{
"func": "service.Login",
"userID": u.ID,
}).Errorf("redis.UserTokenSet: %s", err.Error())
return "", nil, err
}
return token, safeUser, nil
}
func Logout(token string) error {
token = strings.TrimSpace(token)
token = strings.TrimPrefix(token, "Bearer ")
token = strings.TrimPrefix(token, "bearer ")
if token == "" {
return nil
}
return redis.UserTokenDel(&token)
}
func ChangePassword(userID, oldPassword, newPassword string) error {
oldPassword = strings.TrimSpace(oldPassword)
newPassword = strings.TrimSpace(newPassword)
if oldPassword == "" || newPassword == "" {
return ErrInvalidArgument
}
if len(newPassword) < 6 {
return ErrNewPasswordShort
}
dbUser, found, err := mysql.UserByID(&userID)
if err != nil {
logrus.WithFields(logrus.Fields{
"func": "service.ChangePassword",
"userID": userID,
}).Errorf("mysql.UserByID: %s", err.Error())
return err
}
if !found || dbUser == nil || dbUser.PasswordHash == "" {
return ErrUserNotFound
}
if !VerifyPassword(dbUser.PasswordHash, oldPassword) {
return ErrOldPasswordWrong
}
hash, err := HashPassword(newPassword)
if err != nil {
logrus.WithFields(logrus.Fields{
"func": "service.ChangePassword",
"userID": userID,
}).Errorf("password.HashPassword: %s", err.Error())
return err
}
updated, err := mysql.UpdateUserPasswordHash(&userID, &hash)
if err != nil {
logrus.WithFields(logrus.Fields{
"func": "service.ChangePassword",
"userID": userID,
}).Errorf("mysql.UpdateUserPasswordHash: %s", err.Error())
return err
}
if !updated {
return ErrUserNotFound
}
return nil
}

17
service/base-service.go Normal file
View File

@ -0,0 +1,17 @@
package service
import (
"strings"
"github.com/google/uuid"
)
func NewID() string {
i := uuid.Must(uuid.NewV7()).String()
return strings.ReplaceAll(i, "-", "")
}
func newToken() string {
i := uuid.Must(uuid.NewV7()).String()
return strings.ReplaceAll(i, "-", "")
}

View File

@ -0,0 +1,15 @@
package service
import "golang.org/x/crypto/bcrypt"
func HashPassword(password string) (string, error) {
h, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return "", err
}
return string(h), nil
}
func VerifyPassword(hash, password string) bool {
return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) == nil
}

15
service/rbac-service.go Normal file
View File

@ -0,0 +1,15 @@
package service
import "myschools.me/heritage/heritage-api/mysql"
func PermissionDefined(permissionCode string) (bool, error) {
return mysql.PermissionDefined(&permissionCode)
}
func UserRoleIDByUserID(userID string) (*string, bool, error) {
return mysql.UserRoleIDByUserID(&userID)
}
func RoleHasPermission(roleID *string, permissionCode string) (bool, error) {
return mysql.RoleHasPermission(roleID, &permissionCode)
}