You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

303 lines
7.5 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package auth
import (
	"context"
	"crypto/rand"
	"errors"
	"fmt"
	"net/http"
	"time"

	"cls/pkg/logger"

	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v4"
	"github.com/redis/go-redis/v9"
)
// JWTClaimsKey is the gin.Context key under which Handle stores a
// map[string]interface{} copy of the claims of a validated token.
const JWTClaimsKey = "JWT_CLAIMS"
// Claims is the JWT claim set issued and accepted by AuthMiddleware.
//
// A token whose Username is empty is treated as a guest token. The field
// order determines the order of keys in the encoded JSON payload.
type Claims struct {
	// Username is the encrypted phone number; empty means the bearer is a guest.
	Username string `json:"username"`
	// GuestId identifies a user who has not logged in (set on guest tokens only).
	GuestId string `json:"guest_id"`
	// Openid is the openid supplied when the token was issued; guest tokens
	// issued without one carry "".
	Openid string `json:"openid"`
	jwt.RegisteredClaims
}
// AuthMiddleware authenticates gin requests with HMAC-signed JWTs.
//
// User tokens are additionally compared with the copy stored in redis under
// the user's encrypted phone number, so validateToken accepts only the most
// recently issued token for each user.
type AuthMiddleware struct {
	// secretKey is the HMAC key used to sign and verify tokens.
	secretKey []byte
	// config supplies the token header name (TokenKey) and SkipPaths.
	config *Config
	// redis holds the latest issued user token, keyed by encrypted phone number.
	redis redis.Cmdable
	// log is the component logger.
	log logger.Logger
}
// NewAuthMiddleware builds an AuthMiddleware.
//
// secretKey is the HMAC key used to sign and verify tokens. A nil config is
// replaced by DefaultConfig(). redis is where issued user tokens are stored.
// log is a logger factory that is invoked once with the component name
// "cls:infrastructure:middleware".
func NewAuthMiddleware(secretKey string, config *Config, redis redis.Cmdable, log logger.New) *AuthMiddleware {
	cfg := config
	if cfg == nil {
		cfg = DefaultConfig()
	}
	componentLog := log("cls:infrastructure:middleware")
	return &AuthMiddleware{
		secretKey: []byte(secretKey),
		config:    cfg,
		redis:     redis,
		log:       componentLog,
	}
}
// HandleCallback issues a guest token bound to openid and returns the signed
// token string. If signing fails, the error is logged and "" is returned.
func (m *AuthMiddleware) HandleCallback(openid string) string {
	tokenString, _, err := m.generateGuestToken(openid)
	if err == nil {
		return tokenString
	}
	m.log.Error(err)
	return ""
}
// Handle returns the gin middleware that authenticates each request.
//
// The bearer token is read with ExtractToken. When it is missing or fails
// validateToken, the request is treated as unauthenticated:
//   - paths NOT listed in config.SkipPaths continue down the chain with no
//     claims set on the context;
//   - paths listed in config.SkipPaths are aborted with 401 Unauthorized.
//
// No guest token is created here. When the token is valid, these context
// keys are set before the chain continues: "guest_id" (guest tokens, i.e.
// empty Username) or "username" (user tokens), "openid", "is_guest", and
// JWTClaimsKey holding a map[string]interface{} with the same claim values.
//
// NOTE(review): SkipPaths behaves as the list of paths that REQUIRE a valid
// token, the opposite of what its name suggests. The behavior is preserved
// because configuration/routing may depend on it; confirm against
// DefaultConfig and the route setup before changing it.
func (m *AuthMiddleware) Handle() gin.HandlerFunc {
	return func(c *gin.Context) {
		path := c.Request.URL.Path

		// unauthenticated handles a request that has no usable token.
		unauthenticated := func() {
			if m.shouldSkip(path) {
				c.AbortWithStatus(http.StatusUnauthorized)
				return
			}
			c.Next()
		}

		token := m.ExtractToken(c)
		if token == "" {
			unauthenticated()
			return
		}
		claims, err := m.validateToken(token)
		if err != nil {
			unauthenticated()
			return
		}

		// A valid token with no Username is a guest token.
		isGuest := claims.Username == ""
		mapClaims := map[string]interface{}{}
		if isGuest {
			mapClaims["guest_id"] = claims.GuestId
			c.Set("guest_id", claims.GuestId)
		} else {
			mapClaims["username"] = claims.Username
			c.Set("username", claims.Username)
		}
		mapClaims["openid"] = claims.Openid
		c.Set("openid", claims.Openid)
		c.Set(JWTClaimsKey, mapClaims)
		c.Set("is_guest", isGuest)

		c.Next()
	}
}
// generateGuestToken signs an HS256 guest token and returns the token string
// together with the freshly generated guest ID.
//
// openid is optional: only its first value is used, and "" is stored when
// none is given. Guest tokens carry an empty Username, expire 24 hours after
// issuance, and use issuer "cls-system", subject "guest-access" and audience
// "cls-api".
func (m *AuthMiddleware) generateGuestToken(openid ...string) (string, string, error) {
	guestID := generateGuestId()

	var oid string
	if len(openid) > 0 {
		oid = openid[0]
	}

	// Read the clock once so IssuedAt, NotBefore and ExpiresAt are consistent.
	now := time.Now()
	claims := &Claims{
		Username: "",
		GuestId:  guestID,
		Openid:   oid,
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(now.Add(24 * time.Hour)),
			IssuedAt:  jwt.NewNumericDate(now),
			NotBefore: jwt.NewNumericDate(now),
			Issuer:    "cls-system",
			Subject:   "guest-access",
			Audience:  []string{"cls-api"},
		},
	}
	tokenString, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(m.secretKey)
	if err != nil {
		return "", "", err
	}
	return tokenString, guestID, nil
}
// GenerateUserToken issues an HS256 token for a logged-in user and records it
// in redis.
//
// encryptedPhone becomes the Username claim. It is also the redis key under
// which the signed token is stored, with the same 30-day lifetime as the
// token's expiry. validateToken compares user tokens against that stored
// value, so issuing a new token invalidates any earlier one for the same user.
//
// It returns an error if encryptedPhone is empty (such a token would be
// indistinguishable from a guest token), if signing fails, or if the redis
// write fails.
func (m *AuthMiddleware) GenerateUserToken(encryptedPhone string, openid string) (string, error) {
	const ttl = 30 * 24 * time.Hour

	if encryptedPhone == "" {
		return "", errors.New("加密手机号不能为空")
	}

	// Read the clock once so the token's timestamps and the redis TTL agree.
	now := time.Now()
	claims := &Claims{
		Username: encryptedPhone,
		GuestId:  "", // logged-in users carry no guest ID
		Openid:   openid,
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
			IssuedAt:  jwt.NewNumericDate(now),
			NotBefore: jwt.NewNumericDate(now),
			Issuer:    "cls-system",
			Subject:   "user-access",
			Audience:  []string{"cls-api"},
		},
	}
	tokenStr, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(m.secretKey)
	if err != nil {
		return "", err
	}
	if err := m.redis.Set(context.Background(), encryptedPhone, tokenStr, ttl).Err(); err != nil {
		return "", fmt.Errorf("保存用户token失败: %w", err)
	}
	return tokenStr, nil
}
// validateToken parses tokenString, verifies its HMAC signature with
// m.secretKey, and returns its claims.
//
// Guest tokens (empty Username) are accepted once they parse successfully.
// User tokens must also equal the token stored in redis under their Username,
// which GenerateUserToken writes. A missing or different stored token yields
// "用户token失效". A redis failure is returned wrapped (%w), so a token is
// never accepted when redis cannot be consulted.
func (m *AuthMiddleware) validateToken(tokenString string) (*Claims, error) {
	token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
		// Reject tokens signed with anything other than HMAC (e.g. "none" or RSA).
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return m.secretKey, nil
	})
	if err != nil {
		return nil, fmt.Errorf("解析token失败: %w", err)
	}

	claims, ok := token.Claims.(*Claims)
	if !ok || !token.Valid {
		return nil, errors.New("无效的token")
	}
	if claims.Username == "" {
		// Guest tokens are not tracked in redis.
		return claims, nil
	}

	stored, err := m.redis.Get(context.Background(), claims.Username).Result()
	switch {
	case errors.Is(err, redis.Nil):
		return nil, errors.New("用户token失效")
	case err != nil:
		return nil, fmt.Errorf("查询用户token失败: %w", err)
	case stored != tokenString:
		// A newer token has been issued for this user.
		return nil, errors.New("用户token失效")
	}
	return claims, nil
}
// ExtractToken returns the bearer token carried in the request header named
// by config.TokenKey.
//
// The header must begin with the exact, case-sensitive prefix "Bearer " and
// have at least one character after it; otherwise "" is returned.
func (m *AuthMiddleware) ExtractToken(c *gin.Context) string {
	const bearer = "Bearer "
	header := c.GetHeader(m.config.TokenKey)
	if len(header) > len(bearer) && header[:len(bearer)] == bearer {
		return header[len(bearer):]
	}
	return ""
}
// shouldSkip reports whether path exactly equals one of config.SkipPaths.
// No prefix or pattern matching is performed.
func (m *AuthMiddleware) shouldSkip(path string) bool {
	matched := false
	for _, candidate := range m.config.SkipPaths {
		if candidate == path {
			matched = true
			break
		}
	}
	return matched
}
// generateGuestId builds a guest identifier from three parts: the letter "G",
// the current local time formatted as yyMMddHHmmss, and 6 characters
// produced by randomString.
func generateGuestId() string {
	stamp := time.Now().Format("060102150405")
	return "G" + stamp + randomString(6)
}
// randomString returns a string of n characters drawn uniformly from
// [a-zA-Z0-9].
//
// Characters come from crypto/rand. The previous implementation indexed with
// time.Now().UnixNano() inside a tight loop, which produced predictable,
// highly correlated characters and let guest IDs collide. Random bytes at or
// above the largest multiple of len(letters) are discarded (rejection
// sampling) so every character is equally likely.
//
// It returns "" when n <= 0. It panics only if the operating system's
// randomness source fails, which leaves no safe way to produce an ID.
func randomString(n int) string {
	const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
	// limit is the largest multiple of len(letters) that fits in a byte (248).
	// Accepting bytes >= limit would bias the result toward the first letters.
	const limit = 256 - 256%len(letters)

	if n <= 0 {
		return ""
	}
	out := make([]byte, 0, n)
	buf := make([]byte, n)
	for len(out) < n {
		if _, err := rand.Read(buf); err != nil {
			panic(fmt.Sprintf("auth: reading random bytes: %v", err))
		}
		for _, b := range buf {
			if int(b) >= limit {
				continue
			}
			out = append(out, letters[int(b)%len(letters)])
			if len(out) == n {
				break
			}
		}
	}
	return string(out)
}
// DecodeToken extracts the bearer token from c (via ExtractToken), verifies
// its HMAC signature with m.secretKey, and returns its Username claim (the
// encrypted phone number).
//
// A valid guest token yields "" with a nil error. Parse and verification
// failures are wrapped with %w, so callers can inspect them with
// errors.Is / errors.As.
//
// NOTE(review): unlike validateToken, this does not compare user tokens with
// the copy stored in redis, so a superseded user token still decodes. Confirm
// that this is intended for every caller.
func (m *AuthMiddleware) DecodeToken(c *gin.Context) (string, error) {
	tokenString := m.ExtractToken(c)
	if tokenString == "" {
		return "", errors.New("未找到有效的 token")
	}

	token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
		// Reject tokens signed with anything other than HMAC.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return m.secretKey, nil
	})
	if err != nil {
		return "", fmt.Errorf("解析 token 失败: %w", err)
	}

	claims, ok := token.Claims.(*Claims)
	if !ok || !token.Valid {
		return "", errors.New("无效的 token")
	}
	return claims.Username, nil
}
// GetSecretKey returns a copy of the HMAC key used to sign and verify tokens.
//
// A copy is returned rather than the internal slice, so callers cannot mutate
// the key held by the middleware; doing so would silently break signing and
// verification of every token.
func (m *AuthMiddleware) GetSecretKey() []byte {
	key := make([]byte, len(m.secretKey))
	copy(key, m.secretKey)
	return key
}
// DecodeTokenStatic parses tokenString without an AuthMiddleware instance and
// returns its Username claim (the encrypted phone number).
//
// A leading "Bearer " prefix is stripped, but only when at least one character
// follows it. The signature must be HMAC and verify with secretKey. A valid
// guest token yields "" with a nil error. Parse and verification failures are
// wrapped with %w, so callers can inspect them with errors.Is / errors.As.
// Like DecodeToken, this performs no redis check.
func DecodeTokenStatic(tokenString string, secretKey []byte) (string, error) {
	const bearer = "Bearer "
	if len(tokenString) > len(bearer) && tokenString[:len(bearer)] == bearer {
		tokenString = tokenString[len(bearer):]
	}

	token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
		// Reject tokens signed with anything other than HMAC.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return secretKey, nil
	})
	if err != nil {
		return "", fmt.Errorf("解析 token 失败: %w", err)
	}

	claims, ok := token.Claims.(*Claims)
	if !ok || !token.Valid {
		return "", errors.New("无效的 token")
	}
	return claims.Username, nil
}