You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

317 lines
8.0 KiB
Go

1 month ago
package auth
import (
	"context"
	"crypto/rand"
	"errors"
	"fmt"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v4"
	"github.com/redis/go-redis/v9"

	"cls/pkg/logger"
)
// JWTClaimsKey is the gin context key under which Handle stores the caller's
// claims, as a map[string]interface{} holding either "guest_id" or "username".
const JWTClaimsKey = "JWT_CLAIMS"
// Claims is the JWT payload shared by guest tokens and logged-in user tokens.
// Field order and JSON tags define the encoded token payload.
type Claims struct {
	Username string `json:"username"` // encrypted phone number; empty means the token belongs to a guest
	GuestId  string `json:"guest_id"` // guest ID identifying a visitor who is not logged in
	jwt.RegisteredClaims              // standard exp/iat/nbf/iss/sub/aud claims
}
// AuthMiddleware 认证中间件
type AuthMiddleware struct {
config *Config
secretKey []byte
log logger.Logger
redis redis.Cmdable
1 month ago
}
// NewAuthMiddleware 创建认证中间件
func NewAuthMiddleware(secretKey string, config *Config, redis redis.Cmdable, log logger.New) *AuthMiddleware {
1 month ago
if config == nil {
config = DefaultConfig()
}
return &AuthMiddleware{
config: config,
secretKey: []byte(secretKey),
log: log("cls:infrastructure:middleware"),
redis: redis,
1 month ago
}
}
// Handle returns the gin middleware that resolves the caller's identity.
//
// Requests are never rejected for missing credentials. The middleware only
// decides which identity to attach:
//   - No token, or a token rejected by validateToken: a fresh guest token is
//     minted and returned in the m.config.TokenKey response header as
//     "Bearer <token>", and the request is marked as a guest.
//   - A valid token with an empty Username: the existing guest identity is kept.
//   - A valid token with a Username: the request is marked as a logged-in user.
//
// Context keys set:
//   - JWTClaimsKey: map[string]interface{}.
//   - "is_guest": bool.
//   - Either "guest_id" or "username": string.
//
// The request is aborted only when a guest token cannot be signed; it then
// responds with HTTP 500.
//
// NOTE(review): the previous version checked m.shouldSkip(path) but called
// c.Next() on both branches, so SkipPaths had no effect. That no-op check has
// been dropped. If skip paths were meant to bypass guest-token issuance, or
// non-skip paths were meant to reject guests, that logic still has to be written.
func (m *AuthMiddleware) Handle() gin.HandlerFunc {
	return func(c *gin.Context) {
		if token := m.ExtractToken(c); token != "" {
			if claims, err := m.validateToken(token); err == nil {
				setClaimsContext(c, claims)
				c.Next()
				return
			}
		}
		// Missing or invalid token: fall back to a freshly issued guest identity.
		if !m.issueGuestToken(c) {
			return
		}
		c.Next()
	}
}

// issueGuestToken mints a guest token, writes it to the response header and
// stores the guest identity in the context. On a signing failure it responds
// with HTTP 500, aborts the handler chain and returns false.
func (m *AuthMiddleware) issueGuestToken(c *gin.Context) bool {
	newToken, guestID, err := m.generateGuestToken()
	if err != nil {
		c.JSON(500, gin.H{"error": "生成临时token失败"})
		c.Abort()
		return false
	}
	c.Header(m.config.TokenKey, "Bearer "+newToken)
	c.Set(JWTClaimsKey, map[string]interface{}{"guest_id": guestID})
	c.Set("is_guest", true)
	c.Set("guest_id", guestID)
	return true
}

// setClaimsContext stores the identity carried by already-validated claims in
// the context. An empty Username marks the caller as a guest.
func setClaimsContext(c *gin.Context, claims *Claims) {
	isGuest := claims.Username == ""
	mapClaims := map[string]interface{}{}
	if isGuest {
		mapClaims["guest_id"] = claims.GuestId
		c.Set("guest_id", claims.GuestId)
	} else {
		mapClaims["username"] = claims.Username
		c.Set("username", claims.Username)
	}
	c.Set(JWTClaimsKey, mapClaims)
	c.Set("is_guest", isGuest)
}
// generateGuestToken 生成游客token返回token字符串和游客ID
func (m *AuthMiddleware) generateGuestToken() (string, string, error) {
guestId := generateGuestId()
claims := &Claims{
Username: "",
GuestId: guestId,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)),
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
Issuer: "cls-system",
Subject: "guest-access",
Audience: []string{"cls-api"},
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString(m.secretKey)
if err != nil {
return "", "", err
}
return tokenString, guestId, nil
}
// GenerateUserToken 生成用户token
func (m *AuthMiddleware) GenerateUserToken(encryptedPhone string) (string, error) {
claims := &Claims{
Username: encryptedPhone,
GuestId: "", // 登录用户不需要游客ID
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(30 * 24 * time.Hour)), // 30天过期
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
Issuer: "cls-system",
Subject: "user-access",
Audience: []string{"cls-api"},
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenStr, err := token.SignedString(m.secretKey)
if err != nil {
return "", err
}
err = m.redis.Set(context.Background(), encryptedPhone, tokenStr, time.Now().Add(30*24*time.Hour).Sub(time.Now())).Err()
if err != nil {
return "", err
}
return tokenStr, nil
1 month ago
}
// validateToken parses tokenString and returns its Claims if the token can be
// trusted.
//
// Checks performed:
//   - The signature must be HMAC, verified with m.secretKey.
//   - The jwt library's registered-claim checks (exp/iat/nbf) must pass.
//   - For user tokens (non-empty Username), the token must also equal the value
//     stored in Redis under Username. A missing key, a Redis failure or a
//     mismatch all cause rejection.
//
// Guest tokens are accepted on the signature and time checks alone. Any
// returned error means the token must not be used; the parse and Redis causes
// are wrapped with %w.
func (m *AuthMiddleware) validateToken(tokenString string) (*Claims, error) {
	token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(t *jwt.Token) (interface{}, error) {
		// Refuse any algorithm other than HMAC so a token cannot choose how it is verified.
		if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
		}
		return m.secretKey, nil
	})
	if err != nil {
		return nil, fmt.Errorf("解析token失败: %w", err)
	}
	claims, ok := token.Claims.(*Claims)
	if !ok || !token.Valid {
		return nil, errors.New("无效的token")
	}
	if claims.Username == "" {
		return claims, nil
	}
	stored, err := m.redis.Get(context.Background(), claims.Username).Result()
	if err != nil {
		// Covers both "no such key" and connectivity failures. Fail closed, but
		// keep the cause instead of discarding it as the previous version did.
		return nil, fmt.Errorf("用户token失效: %w", err)
	}
	if stored != tokenString {
		return nil, errors.New("用户token失效")
	}
	return claims, nil
}
// ExtractToken returns the token carried in the request header named by
// m.config.TokenKey, with its "Bearer " prefix removed.
//
// It returns "" when the header is absent or does not start with exactly
// "Bearer " (case-sensitive). It also returns "" when nothing follows the prefix.
func (m *AuthMiddleware) ExtractToken(c *gin.Context) string {
	const scheme = "Bearer "
	header := c.GetHeader(m.config.TokenKey)
	if len(header) > len(scheme) && header[:len(scheme)] == scheme {
		return header[len(scheme):]
	}
	return ""
}
// shouldSkip reports whether path exactly equals one of m.config.SkipPaths.
// Only plain string equality is used; there is no prefix or pattern matching.
func (m *AuthMiddleware) shouldSkip(path string) bool {
	for _, candidate := range m.config.SkipPaths {
		if candidate != path {
			continue
		}
		return true
	}
	return false
}
// generateGuestId builds a 19-character guest identifier with three parts:
//   - the letter "G";
//   - the current local time formatted as yyMMddHHmmss;
//   - 6 characters from randomString.
//
// Example: "G250101120000aB3xYz".
func generateGuestId() string {
	stamp := time.Now().Format("060102150405")
	suffix := randomString(6)
	return "G" + stamp + suffix
}
// randomString returns n characters drawn uniformly at random from
// [a-zA-Z0-9] using crypto/rand. It returns "" when n <= 0.
//
// The previous version indexed the alphabet with time.Now().UnixNano() on each
// iteration. That produces runs of identical or correlated characters and is
// trivially predictable, which makes guest IDs prone to collision.
func randomString(n int) string {
	const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
	// Bytes >= maxUnbiased are discarded so every letter is equally likely.
	// 248 is the largest multiple of len(letters) (62) that fits in a byte.
	const maxUnbiased = 256 - 256%len(letters)

	if n <= 0 {
		return ""
	}
	out := make([]byte, 0, n)
	buf := make([]byte, n)
	for len(out) < n {
		if _, err := rand.Read(buf); err != nil {
			// crypto/rand fails only when the OS entropy source is broken, and
			// Go 1.24+ crashes the program itself in that case. Treat it as fatal.
			panic("auth: crypto/rand failed: " + err.Error())
		}
		for _, b := range buf {
			if int(b) >= maxUnbiased {
				continue
			}
			out = append(out, letters[int(b)%len(letters)])
			if len(out) == n {
				break
			}
		}
	}
	return string(out)
}
// DecodeToken extracts the bearer token from the request and returns its
// Username claim, the encrypted phone number.
//
// The token is read from the header named by m.config.TokenKey. Its signature
// must be HMAC, verified with m.secretKey, and the jwt library's exp/iat/nbf
// checks apply.
//
// A valid guest token yields ("", nil), because guest tokens have an empty
// Username.
//
// Unlike validateToken, this does NOT consult Redis. A user token that has
// been superseded by a newer login therefore still decodes successfully.
// NOTE(review): confirm callers do not rely on this for authorization.
//
// Parse failures are wrapped with %w so callers can inspect them with
// errors.Is / errors.As.
func (m *AuthMiddleware) DecodeToken(c *gin.Context) (string, error) {
	tokenString := m.ExtractToken(c)
	if tokenString == "" {
		return "", errors.New("未找到有效的 token")
	}
	token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(t *jwt.Token) (interface{}, error) {
		// Only accept HMAC-signed tokens; never let the token pick its own algorithm.
		if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
		}
		return m.secretKey, nil
	})
	if err != nil {
		return "", fmt.Errorf("解析 token 失败: %w", err)
	}
	claims, ok := token.Claims.(*Claims)
	if !ok || !token.Valid {
		return "", errors.New("无效的 token")
	}
	return claims.Username, nil
}
// GetSecretKey returns a copy of the HMAC key used to sign and verify tokens.
// The copy suits DecodeTokenStatic, which takes the key explicitly.
//
// A copy is returned, rather than the internal slice, so callers cannot mutate
// the middleware's key. Such a mutation would silently change how all
// subsequent tokens are signed and verified.
func (m *AuthMiddleware) GetSecretKey() []byte {
	key := make([]byte, len(m.secretKey))
	copy(key, m.secretKey)
	return key
}
// DecodeTokenStatic 静态方法,不需要 AuthMiddleware 实例,直接解析 token 字符串
func DecodeTokenStatic(tokenString string, secretKey []byte) (string, error) {
// 移除 Bearer 前缀(如果有)
if len(tokenString) > 7 && tokenString[:7] == "Bearer " {
tokenString = tokenString[7:]
}
// 解析 token
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return secretKey, nil
})
if err != nil {
return "", fmt.Errorf("解析 token 失败: %v", err)
}
// 验证 token 并获取 claims
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
// 返回原始的加密手机号
return claims.Username, nil
}
return "", fmt.Errorf("无效的 token")
}