You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

312 lines
8.2 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package auth
import (
	"crypto/rand"
	"fmt"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v4"
)
// JWTClaimsKey is the gin.Context key under which the auth middleware stores
// the request's identity claims (a map[string]interface{} holding either
// "username" or "guest_id").
const JWTClaimsKey = "JWT_CLAIMS"
// Claims is the JWT payload issued by GenerateUserToken / generateGuestToken
// and read back by validateToken, DecodeToken and DecodeTokenStatic.
type Claims struct {
	Username string `json:"username"` // encrypted phone number; empty means the bearer is a guest
	GuestId  string `json:"guest_id"` // guest ID identifying a visitor who is not logged in
	jwt.RegisteredClaims              // standard exp/iat/nbf/iss/sub/aud claims
}
// AuthMiddleware issues, parses and validates the HMAC-signed JWTs (HS256
// when issued here) that identify users and guests; Handle exposes it as a
// gin middleware.
type AuthMiddleware struct {
	config    *Config // supplies TokenKey (header name) and SkipPaths; NewAuthMiddleware substitutes DefaultConfig() for nil
	secretKey []byte  // HMAC key used both to sign and to verify tokens
}
// NewAuthMiddleware constructs an AuthMiddleware that signs and verifies
// tokens with secretKey. When config is nil, DefaultConfig() is used instead.
func NewAuthMiddleware(secretKey string, config *Config) *AuthMiddleware {
	mw := &AuthMiddleware{secretKey: []byte(secretKey)}
	if config != nil {
		mw.config = config
	} else {
		mw.config = DefaultConfig()
	}
	return mw
}
// Handle returns a gin middleware that attaches an identity to every request.
//
// The token is read from the m.config.TokenKey request header via ExtractToken:
//   - Valid token with a non-empty Username: a logged-in user; the context
//     gets "username" and "is_guest" = false.
//   - Valid token with an empty Username: a guest; the context gets the
//     token's "guest_id" and "is_guest" = true.
//   - Missing or invalid token: a new guest token is minted and sent back in
//     the m.config.TokenKey response header as "Bearer <token>"; the context
//     gets the new "guest_id" and "is_guest" = true.
//
// In every case JWTClaimsKey holds a map[string]interface{} containing either
// "username" or "guest_id". The request is aborted (HTTP 500) only when a
// guest token cannot be signed; missing or bad credentials never reject a
// request, so config.SkipPaths needs no special handling here.
//
// Token values and user identifiers are deliberately not logged: a bearer
// token written to a log is a replayable credential.
func (m *AuthMiddleware) Handle() gin.HandlerFunc {
	return func(c *gin.Context) {
		fmt.Printf("处理请求: %s\n", c.Request.URL.Path)

		token := m.ExtractToken(c)
		if token == "" {
			fmt.Println("没有token创建临时游客token")
			if m.attachNewGuest(c) {
				c.Next()
			}
			return
		}

		claims, err := m.validateToken(token)
		if err != nil {
			// Expired, tampered or foreign tokens fall back to a fresh guest identity.
			fmt.Printf("token验证失败: %v\n", err)
			if m.attachNewGuest(c) {
				c.Next()
			}
			return
		}

		m.attachTokenIdentity(c, claims)
		c.Next()
	}
}

// attachNewGuest mints a fresh guest token, returns it to the client in the
// m.config.TokenKey response header and records the guest identity in the
// context. If the token cannot be signed it aborts the request with HTTP 500
// and reports false, in which case the caller must stop processing.
func (m *AuthMiddleware) attachNewGuest(c *gin.Context) bool {
	newToken, guestID, err := m.generateGuestToken()
	if err != nil {
		fmt.Printf("生成临时token失败: %v\n", err)
		c.AbortWithStatusJSON(500, gin.H{"error": "生成临时token失败"})
		return false
	}
	fmt.Printf("生成的游客ID: %s\n", guestID)
	c.Header(m.config.TokenKey, "Bearer "+newToken)
	c.Set(JWTClaimsKey, map[string]interface{}{"guest_id": guestID})
	c.Set("is_guest", true)
	c.Set("guest_id", guestID)
	return true
}

// attachTokenIdentity records the identity carried by an already validated
// token in the context; an empty Username means a guest identified by GuestId.
func (m *AuthMiddleware) attachTokenIdentity(c *gin.Context, claims *Claims) {
	isGuest := claims.Username == ""
	identity := map[string]interface{}{}
	if isGuest {
		identity["guest_id"] = claims.GuestId
		c.Set("guest_id", claims.GuestId)
	} else {
		identity["username"] = claims.Username
		c.Set("username", claims.Username)
	}
	c.Set(JWTClaimsKey, identity)
	c.Set("is_guest", isGuest)
}
// generateGuestToken mints a guest JWT valid for 24 hours.
//
// It returns the signed token string and the newly generated guest ID that
// the token carries in its guest_id claim. Username is left empty, which is
// how guests are recognized. The token is signed with HS256 using
// m.secretKey, with issuer "cls-system", subject "guest-access" and audience
// "cls-api". A signing error is returned unchanged.
func (m *AuthMiddleware) generateGuestToken() (string, string, error) {
	guestID := generateGuestId()
	// Read the clock once so exp, iat and nbf all derive from the same instant.
	now := time.Now()
	claims := &Claims{
		Username: "",
		GuestId:  guestID,
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(now.Add(24 * time.Hour)),
			IssuedAt:  jwt.NewNumericDate(now),
			NotBefore: jwt.NewNumericDate(now),
			Issuer:    "cls-system",
			Subject:   "guest-access",
			Audience:  []string{"cls-api"},
		},
	}
	tokenString, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(m.secretKey)
	if err != nil {
		return "", "", err
	}
	return tokenString, guestID, nil
}
// GenerateUserToken mints a JWT for a logged-in user, valid for 30 days.
//
// encryptedPhone is stored verbatim in the username claim and GuestId is left
// empty. The token is signed with HS256 using m.secretKey, with issuer
// "cls-system", subject "user-access" and audience "cls-api".
//
// An empty encryptedPhone is rejected: a token with an empty username is
// indistinguishable from a guest token and would silently downgrade the user.
func (m *AuthMiddleware) GenerateUserToken(encryptedPhone string) (string, error) {
	// The identifier is intentionally not logged; it is the user's identity claim.
	if encryptedPhone == "" {
		return "", fmt.Errorf("encryptedPhone must not be empty: an empty username marks a guest token")
	}
	// Read the clock once so exp, iat and nbf all derive from the same instant.
	now := time.Now()
	claims := &Claims{
		Username: encryptedPhone,
		GuestId:  "", // logged-in users carry no guest ID
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(now.Add(30 * 24 * time.Hour)), // 30 days
			IssuedAt:  jwt.NewNumericDate(now),
			NotBefore: jwt.NewNumericDate(now),
			Issuer:    "cls-system",
			Subject:   "user-access",
			Audience:  []string{"cls-api"},
		},
	}
	return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(m.secretKey)
}
// validateToken parses tokenString as an HMAC-signed JWT carrying Claims and
// verifies its signature against m.secretKey. Registered-claim validation
// (exp/nbf/iat in golang-jwt v4) is performed by jwt.ParseWithClaims.
//
// The parse error is wrapped with %w so callers can inspect the underlying
// jwt error with errors.Is / errors.As.
//
// NOTE(review): iss ("cls-system") and aud ("cls-api") are not verified, so
// any HMAC token signed with the same key is accepted; confirm this is intended.
func (m *AuthMiddleware) validateToken(tokenString string) (*Claims, error) {
	token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
		// Only HMAC is accepted; this rejects "none" and asymmetric-algorithm tokens.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return m.secretKey, nil
	})
	if err != nil {
		return nil, fmt.Errorf("解析token失败: %w", err)
	}
	claims, ok := token.Claims.(*Claims)
	if !ok || !token.Valid {
		return nil, fmt.Errorf("无效的token")
	}
	return claims, nil
}
// ExtractToken returns the raw JWT from the request header named by
// m.config.TokenKey, which must have the form "Bearer <token>".
//
// It returns "" when the header is absent, lacks the "Bearer " prefix, or
// carries nothing after it. The header value is never logged because it is a
// bearer credential that anyone reading the logs could replay.
func (m *AuthMiddleware) ExtractToken(c *gin.Context) string {
	auth := c.GetHeader(m.config.TokenKey)
	if auth == "" {
		return ""
	}
	const prefix = "Bearer "
	if len(auth) <= len(prefix) || auth[:len(prefix)] != prefix {
		fmt.Println("token格式不正确缺少Bearer前缀")
		return ""
	}
	return auth[len(prefix):]
}
// shouldSkip reports whether path exactly equals one of the entries in
// m.config.SkipPaths.
func (m *AuthMiddleware) shouldSkip(path string) bool {
	matched := false
	for _, candidate := range m.config.SkipPaths {
		if candidate == path {
			matched = true
			break
		}
	}
	return matched
}
// generateGuestId returns a new guest identifier: "G", then the current local
// time formatted as yyMMddHHmmss, then a 6-character suffix from randomString.
func generateGuestId() string {
	const prefix = "G"
	const suffixLen = 6
	timestamp := time.Now().Format("060102150405")
	return prefix + timestamp + randomString(suffixLen)
}
// randomString returns n characters drawn uniformly at random from
// [a-zA-Z0-9], or "" when n <= 0.
//
// crypto/rand is the entropy source. Deriving characters from the clock in a
// tight loop yields near-identical, predictable values, and the guest IDs built
// from this suffix must not collide between concurrent requests.
func randomString(n int) string {
	const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
	// Largest multiple of len(letters) that fits in a byte (248). Bytes at or
	// above it are rejected so every letter is equally likely (no modulo bias).
	const limit = 256 - 256%len(letters)
	if n <= 0 {
		return ""
	}
	out := make([]byte, 0, n)
	buf := make([]byte, n)
	for len(out) < n {
		if _, err := rand.Read(buf); err != nil {
			// A failing OS entropy source leaves no safe way to build unique IDs.
			panic("randomString: reading crypto/rand: " + err.Error())
		}
		for _, v := range buf {
			if int(v) >= limit {
				continue
			}
			out = append(out, letters[int(v)%len(letters)])
			if len(out) == n {
				break
			}
		}
	}
	return string(out)
}
// DecodeToken extracts the bearer token from the request (via ExtractToken),
// verifies it with m.secretKey and returns its Username claim, which is the
// encrypted phone number. A valid guest token yields "" with a nil error.
//
// Errors:
//   - no usable token in the request header;
//   - parse/verification failure, wrapped with %w so callers can use
//     errors.Is / errors.As on the underlying jwt error;
//   - claims of an unexpected type or a token reported as invalid.
func (m *AuthMiddleware) DecodeToken(c *gin.Context) (string, error) {
	tokenString := m.ExtractToken(c)
	if tokenString == "" {
		return "", fmt.Errorf("未找到有效的 token")
	}
	token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
		// Only HMAC is accepted; this rejects "none" and asymmetric-algorithm tokens.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return m.secretKey, nil
	})
	if err != nil {
		return "", fmt.Errorf("解析 token 失败: %w", err)
	}
	claims, ok := token.Claims.(*Claims)
	if !ok || !token.Valid {
		return "", fmt.Errorf("无效的 token")
	}
	return claims.Username, nil
}
// GetSecretKey returns a copy of the HMAC key used to sign and verify tokens.
// A copy is returned so callers cannot mutate the key the middleware itself
// signs and verifies with.
func (m *AuthMiddleware) GetSecretKey() []byte {
	key := make([]byte, len(m.secretKey))
	copy(key, m.secretKey)
	return key
}
// DecodeTokenStatic verifies tokenString with secretKey without needing an
// AuthMiddleware instance and returns its Username claim, which is the
// encrypted phone number. A valid guest token yields "" with a nil error.
//
// A leading "Bearer " prefix is stripped when something follows it. Parse and
// verification failures are wrapped with %w so callers can use errors.Is /
// errors.As on the underlying jwt error.
func DecodeTokenStatic(tokenString string, secretKey []byte) (string, error) {
	const bearer = "Bearer "
	if len(tokenString) > len(bearer) && tokenString[:len(bearer)] == bearer {
		tokenString = tokenString[len(bearer):]
	}
	token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
		// Only HMAC is accepted; this rejects "none" and asymmetric-algorithm tokens.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return secretKey, nil
	})
	if err != nil {
		return "", fmt.Errorf("解析 token 失败: %w", err)
	}
	claims, ok := token.Claims.(*Claims)
	if !ok || !token.Valid {
		return "", fmt.Errorf("无效的 token")
	}
	return claims.Username, nil
}