// Package auth implements a gin middleware that validates JWT bearer tokens,
// plus helpers that issue and decode guest and user tokens.
package auth

import (
	"context"
	"crypto/rand"
	"errors"
	"fmt"
	"net/http"
	"strings"
	"time"

	"cls/pkg/logger"

	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v4"
	"github.com/redis/go-redis/v9"
)

const (
	// JWTClaimsKey is the gin context key under which Handle stores the
	// claims map of a validated token.
	JWTClaimsKey = "JWT_CLAIMS"

	// bearerPrefix is the scheme prefix expected in front of the token.
	bearerPrefix = "Bearer "

	// guestTokenTTL is the lifetime of a guest token.
	guestTokenTTL = 24 * time.Hour

	// userTokenTTL is the lifetime of a user token and of its Redis entry.
	userTokenTTL = 30 * 24 * time.Hour
)

// errInvalidClaims reports a token that parsed but whose claims are not a
// valid *Claims value.
var errInvalidClaims = errors.New("invalid token claims")

// Claims is the JWT payload used by this package.
type Claims struct {
	Username string `json:"username"` // Encrypted phone number; empty for guests.
	GuestId  string `json:"guest_id"` // Guest ID identifying a user who is not logged in.
	Openid   string `json:"openid"`
	jwt.RegisteredClaims
}

// AuthMiddleware is the authentication middleware.
type AuthMiddleware struct {
	config    *Config
	secretKey []byte
	log       logger.Logger
	redis     redis.Cmdable
}

// NewAuthMiddleware builds an AuthMiddleware. A nil config is replaced by
// DefaultConfig().
func NewAuthMiddleware(secretKey string, config *Config, redis redis.Cmdable, log logger.New) *AuthMiddleware {
	if config == nil {
		config = DefaultConfig()
	}
	return &AuthMiddleware{
		config:    config,
		secretKey: []byte(secretKey),
		log:       log("cls:infrastructure:middleware"),
		redis:     redis,
	}
}

// HandleCallback issues a guest token carrying openid and returns it. If the
// token cannot be signed, the error is logged and "" is returned.
func (m *AuthMiddleware) HandleCallback(openid string) string {
	token, _, err := m.generateGuestToken(openid)
	if err != nil {
		m.log.Error(err)
		return ""
	}
	return token
}

// Handle returns the gin handler that authenticates a request.
//
// When the request carries a valid token, the handler stores the following in
// the gin context and continues the chain: "openid", "is_guest", either
// "guest_id" (guest) or "username" (logged-in user), and the same values as a
// map under JWTClaimsKey. When the token is missing or invalid, the outcome
// depends on config.SkipPaths (see handleUnauthenticated).
func (m *AuthMiddleware) Handle() gin.HandlerFunc {
	return func(c *gin.Context) {
		path := c.Request.URL.Path

		token := m.ExtractToken(c)
		if token == "" {
			m.handleUnauthenticated(c, path)
			return
		}

		claims, err := m.validateToken(token)
		if err != nil {
			m.handleUnauthenticated(c, path)
			return
		}

		// An empty Username marks a guest token.
		isGuest := claims.Username == ""

		mapClaims := map[string]interface{}{"openid": claims.Openid}
		if isGuest {
			mapClaims["guest_id"] = claims.GuestId
			c.Set("guest_id", claims.GuestId)
		} else {
			mapClaims["username"] = claims.Username
			c.Set("username", claims.Username)
		}
		c.Set("openid", claims.Openid)
		c.Set(JWTClaimsKey, mapClaims)
		c.Set("is_guest", isGuest)

		c.Next()
	}
}

// handleUnauthenticated decides what happens to a request without a usable
// token: paths listed in config.SkipPaths are rejected with 401, every other
// path continues down the chain with no claims set.
//
// NOTE(review): this is the behaviour of the original code, but it reads as
// inverted relative to the name "SkipPaths" (one would expect listed paths to
// bypass authentication). Confirm against the configured paths before
// changing it; flipping the condition alters which routes are reachable.
func (m *AuthMiddleware) handleUnauthenticated(c *gin.Context, path string) {
	if m.shouldSkip(path) {
		c.AbortWithStatus(http.StatusUnauthorized)
		return
	}
	c.Next()
}

// newRegisteredClaims builds the registered claims shared by guest and user
// tokens, valid from now until now+ttl.
func newRegisteredClaims(now time.Time, ttl time.Duration, subject string) jwt.RegisteredClaims {
	return jwt.RegisteredClaims{
		ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
		IssuedAt:  jwt.NewNumericDate(now),
		NotBefore: jwt.NewNumericDate(now),
		Issuer:    "cls-system",
		Subject:   subject,
		Audience:  []string{"cls-api"},
	}
}

// generateGuestToken signs a guest token and returns the token string and the
// freshly generated guest ID. Only the first openid argument, if any, is used.
func (m *AuthMiddleware) generateGuestToken(openid ...string) (string, string, error) {
	guestID := generateGuestId()

	var oid string
	if len(openid) != 0 {
		oid = openid[0]
	}

	claims := &Claims{
		GuestId:          guestID,
		Openid:           oid,
		RegisteredClaims: newRegisteredClaims(time.Now(), guestTokenTTL, "guest-access"),
	}

	tokenString, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(m.secretKey)
	if err != nil {
		return "", "", fmt.Errorf("signing guest token: %w", err)
	}
	return tokenString, guestID, nil
}

// GenerateUserToken signs a user token for encryptedPhone and stores it in
// Redis under the key encryptedPhone with the same lifetime as the token.
// validateToken later compares presented tokens against that entry, so
// issuing a new token replaces the previously stored one.
//
// An empty encryptedPhone is rejected: Handle treats an empty Username as a
// guest, so such a token could never be recognised as a user token.
func (m *AuthMiddleware) GenerateUserToken(encryptedPhone string, openid string) (string, error) {
	if encryptedPhone == "" {
		return "", errors.New("encrypted phone must not be empty")
	}

	claims := &Claims{
		Username:         encryptedPhone,
		Openid:           openid,
		RegisteredClaims: newRegisteredClaims(time.Now(), userTokenTTL, "user-access"),
	}

	tokenStr, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(m.secretKey)
	if err != nil {
		return "", err
	}

	if err := m.redis.Set(context.Background(), encryptedPhone, tokenStr, userTokenTTL).Err(); err != nil {
		return "", err
	}
	return tokenStr, nil
}

// parseClaims parses tokenString with secretKey, accepting only HMAC signing
// methods. It returns the library's error unchanged when parsing fails and
// errInvalidClaims when the parsed token is not valid.
func parseClaims(tokenString string, secretKey []byte) (*Claims, error) {
	token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return secretKey, nil
	})
	if err != nil {
		return nil, err
	}

	claims, ok := token.Claims.(*Claims)
	if !ok || !token.Valid {
		return nil, errInvalidClaims
	}
	return claims, nil
}

// validateToken verifies the signature and validity of tokenString. For user
// tokens (non-empty Username) it additionally requires that the token equals
// the one stored in Redis under the Username key.
func (m *AuthMiddleware) validateToken(tokenString string) (*Claims, error) {
	claims, err := parseClaims(tokenString, m.secretKey)
	switch {
	case errors.Is(err, errInvalidClaims):
		return nil, errors.New("无效的token")
	case err != nil:
		return nil, fmt.Errorf("解析token失败: %w", err)
	}

	if claims.Username == "" {
		return claims, nil
	}

	stored, err := m.redis.Get(context.Background(), claims.Username).Result()
	switch {
	case errors.Is(err, redis.Nil):
		// No entry: the token expired in Redis or was never stored.
		return nil, errors.New("用户token失效")
	case err != nil:
		// Still rejected, but reported as a lookup failure rather than as
		// an expired token.
		return nil, fmt.Errorf("looking up user token: %w", err)
	case stored != tokenString:
		return nil, errors.New("用户token失效")
	}
	return claims, nil
}

// ExtractToken returns the bearer token from the request header named by
// config.TokenKey, or "" when the header is absent, lacks the "Bearer "
// prefix, or carries nothing after it.
func (m *AuthMiddleware) ExtractToken(c *gin.Context) string {
	auth := c.GetHeader(m.config.TokenKey)
	if !strings.HasPrefix(auth, bearerPrefix) {
		return ""
	}
	return strings.TrimPrefix(auth, bearerPrefix)
}

// shouldSkip reports whether path exactly matches an entry of
// config.SkipPaths.
func (m *AuthMiddleware) shouldSkip(path string) bool {
	for _, skipPath := range m.config.SkipPaths {
		if path == skipPath {
			return true
		}
	}
	return false
}

// generateGuestId returns a guest ID of the form "G" + yymmddHHMMSS + six
// random alphanumeric characters.
func generateGuestId() string {
	return "G" + time.Now().Format("060102150405") + randomString(6)
}

// randomString returns n characters drawn uniformly from [a-zA-Z0-9] using
// crypto/rand. It returns "" for n <= 0.
func randomString(n int) string {
	const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
	// Bytes at or above limit are discarded so that the modulo below does
	// not favour the first letters of the alphabet.
	const limit = 256 - 256%len(letters)

	if n <= 0 {
		return ""
	}

	out := make([]byte, 0, n)
	buf := make([]byte, n)
	for len(out) < n {
		if _, err := rand.Read(buf); err != nil {
			// The system random source is unavailable. Rather than fail,
			// fill the remainder from a clock-seeded generator; this is
			// not cryptographically strong.
			seed := uint64(time.Now().UnixNano())
			for len(out) < n {
				seed = seed*6364136223846793005 + 1442695040888963407
				out = append(out, letters[(seed>>33)%uint64(len(letters))])
			}
			break
		}
		for _, b := range buf {
			if int(b) >= limit {
				continue
			}
			out = append(out, letters[int(b)%len(letters)])
			if len(out) == n {
				break
			}
		}
	}
	return string(out)
}

// decodeUsername parses tokenString with secretKey and returns the Username
// claim (the encrypted phone number; empty for guest tokens).
func decodeUsername(tokenString string, secretKey []byte) (string, error) {
	claims, err := parseClaims(tokenString, secretKey)
	switch {
	case errors.Is(err, errInvalidClaims):
		return "", errors.New("无效的 token")
	case err != nil:
		return "", fmt.Errorf("解析 token 失败: %w", err)
	}
	return claims.Username, nil
}

// DecodeToken extracts the bearer token from the request and returns its
// Username claim (the encrypted phone number).
//
// Unlike validateToken, it does not compare the token against the copy
// stored in Redis; only signature and standard claim validity are checked.
func (m *AuthMiddleware) DecodeToken(c *gin.Context) (string, error) {
	tokenString := m.ExtractToken(c)
	if tokenString == "" {
		return "", errors.New("未找到有效的 token")
	}
	return decodeUsername(tokenString, m.secretKey)
}

// GetSecretKey returns the signing key. The returned slice is the
// middleware's own backing slice, not a copy.
func (m *AuthMiddleware) GetSecretKey() []byte {
	return m.secretKey
}

// DecodeTokenStatic parses tokenString with secretKey without needing an
// AuthMiddleware instance and returns the Username claim. A leading
// "Bearer " prefix, if present, is removed first.
func DecodeTokenStatic(tokenString string, secretKey []byte) (string, error) {
	tokenString = strings.TrimPrefix(tokenString, bearerPrefix)
	return decodeUsername(tokenString, secretKey)
}