|
|
|
|
package auth
|
|
|
|
|
|
|
|
|
|
import (
	"context"
	"crypto/rand"
	"errors"
	"fmt"
	"net/http"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v4"
	"github.com/redis/go-redis/v9"

	"cls/pkg/logger"
)
|
|
|
|
|
|
|
|
|
|
// JWTClaimsKey is the gin context key under which Handle stores a
// map[string]interface{} describing the authenticated caller.
const JWTClaimsKey = "JWT_CLAIMS"
|
|
|
|
|
|
|
|
|
|
// Claims is the JWT payload issued and accepted by AuthMiddleware.
//
// Guest tokens (generateGuestToken) carry GuestId and leave Username empty.
// User tokens (GenerateUserToken) carry Username and leave GuestId empty.
// Handle treats an empty Username as a guest.
//
// Because jwt.RegisteredClaims is embedded without a JSON tag, its standard
// claims (exp, iat, nbf, iss, sub, aud) appear at the top level of the
// encoded payload alongside the fields below.
type Claims struct {
	Username string `json:"username"` // encrypted phone number; empty means the holder is a guest
	GuestId  string `json:"guest_id"` // guest ID identifying a user who has not logged in
	Openid   string `json:"openid"`   // openid supplied when the token was issued (presumably a WeChat openid — TODO confirm)
	jwt.RegisteredClaims
}
|
|
|
|
|
|
|
|
|
|
// AuthMiddleware 认证中间件
|
|
|
|
|
type AuthMiddleware struct {
|
|
|
|
|
config *Config
|
|
|
|
|
secretKey []byte
|
|
|
|
|
log logger.Logger
|
|
|
|
|
redis redis.Cmdable
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// NewAuthMiddleware 创建认证中间件
|
|
|
|
|
func NewAuthMiddleware(secretKey string, config *Config, redis redis.Cmdable, log logger.New) *AuthMiddleware {
|
|
|
|
|
if config == nil {
|
|
|
|
|
config = DefaultConfig()
|
|
|
|
|
}
|
|
|
|
|
return &AuthMiddleware{
|
|
|
|
|
config: config,
|
|
|
|
|
secretKey: []byte(secretKey),
|
|
|
|
|
log: log("cls:infrastructure:middleware"),
|
|
|
|
|
redis: redis,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (m *AuthMiddleware) HandleCallback(openid string) string {
|
|
|
|
|
t, _, err := m.generateGuestToken(openid)
|
|
|
|
|
if err != nil {
|
|
|
|
|
m.log.Error(err)
|
|
|
|
|
return ""
|
|
|
|
|
}
|
|
|
|
|
return t
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Handle 处理认证逻辑
|
|
|
|
|
func (m *AuthMiddleware) Handle() gin.HandlerFunc {
|
|
|
|
|
return func(c *gin.Context) {
|
|
|
|
|
path := c.Request.URL.Path
|
|
|
|
|
|
|
|
|
|
// 获取token
|
|
|
|
|
token := m.ExtractToken(c)
|
|
|
|
|
// 如果没有token,创建临时游客token
|
|
|
|
|
if token == "" {
|
|
|
|
|
if !m.shouldSkip(path) {
|
|
|
|
|
c.Next()
|
|
|
|
|
return
|
|
|
|
|
} else {
|
|
|
|
|
c.AbortWithStatus(http.StatusUnauthorized)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 验证已有token
|
|
|
|
|
claims, err := m.validateToken(token)
|
|
|
|
|
if err != nil {
|
|
|
|
|
if !m.shouldSkip(path) {
|
|
|
|
|
c.Next()
|
|
|
|
|
return
|
|
|
|
|
} else {
|
|
|
|
|
c.AbortWithStatus(http.StatusUnauthorized)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// token有效,根据phone判断是否为游客
|
|
|
|
|
isGuest := claims.Username == ""
|
|
|
|
|
|
|
|
|
|
// 设置 claims 到上下文
|
|
|
|
|
mapClaims := map[string]interface{}{}
|
|
|
|
|
if isGuest {
|
|
|
|
|
mapClaims["guest_id"] = claims.GuestId
|
|
|
|
|
c.Set("guest_id", claims.GuestId)
|
|
|
|
|
} else {
|
|
|
|
|
mapClaims["username"] = claims.Username
|
|
|
|
|
c.Set("username", claims.Username)
|
|
|
|
|
}
|
|
|
|
|
mapClaims["openid"] = claims.Openid
|
|
|
|
|
c.Set("openid", claims.Openid)
|
|
|
|
|
|
|
|
|
|
c.Set(JWTClaimsKey, mapClaims)
|
|
|
|
|
c.Set("is_guest", isGuest)
|
|
|
|
|
|
|
|
|
|
// 如果是跳过认证的路径,直接继续
|
|
|
|
|
if m.shouldSkip(path) {
|
|
|
|
|
c.Next()
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
c.Next()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// generateGuestToken signs a 24-hour HS256 token for a brand-new guest.
//
// openid is optional; only its first value is used, and it is embedded in
// the token's "openid" claim ("" when omitted). It returns the signed token
// string and the freshly generated guest ID.
func (m *AuthMiddleware) generateGuestToken(openid ...string) (string, string, error) {
	guestId := generateGuestId()

	boundOpenid := ""
	if len(openid) > 0 {
		boundOpenid = openid[0]
	}

	// One timestamp keeps iat, nbf and exp mutually consistent.
	now := time.Now()
	claims := &Claims{
		Username: "", // empty Username marks the token as a guest token
		GuestId:  guestId,
		Openid:   boundOpenid,
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(now.Add(24 * time.Hour)),
			IssuedAt:  jwt.NewNumericDate(now),
			NotBefore: jwt.NewNumericDate(now),
			Issuer:    "cls-system",
			Subject:   "guest-access",
			Audience:  []string{"cls-api"},
		},
	}

	tokenString, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(m.secretKey)
	if err != nil {
		return "", "", fmt.Errorf("sign guest token: %w", err)
	}
	return tokenString, guestId, nil
}
|
|
|
|
|
|
|
|
|
|
// GenerateUserToken 生成用户token
|
|
|
|
|
func (m *AuthMiddleware) GenerateUserToken(encryptedPhone string, openid string) (string, error) {
|
|
|
|
|
claims := &Claims{
|
|
|
|
|
Username: encryptedPhone,
|
|
|
|
|
GuestId: "", // 登录用户不需要游客ID
|
|
|
|
|
Openid: openid,
|
|
|
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
|
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(30 * 24 * time.Hour)), // 30天过期
|
|
|
|
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
|
|
|
|
NotBefore: jwt.NewNumericDate(time.Now()),
|
|
|
|
|
Issuer: "cls-system",
|
|
|
|
|
Subject: "user-access",
|
|
|
|
|
Audience: []string{"cls-api"},
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
|
|
|
|
|
|
|
|
tokenStr, err := token.SignedString(m.secretKey)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return "", err
|
|
|
|
|
}
|
|
|
|
|
err = m.redis.Set(context.Background(), encryptedPhone, tokenStr, time.Now().Add(30*24*time.Hour).Sub(time.Now())).Err()
|
|
|
|
|
if err != nil {
|
|
|
|
|
return "", err
|
|
|
|
|
}
|
|
|
|
|
return tokenStr, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// validateToken parses and verifies tokenString and returns its claims.
//
// The token must use an HMAC signing method keyed with m.secretKey and pass
// the claim validation performed by jwt.ParseWithClaims. For user tokens
// (non-empty Username), it must also equal the token currently stored in
// Redis under that username, so a token replaced by a newer login is
// rejected. Guest tokens are not checked against Redis.
//
// Parse failures wrap the underlying jwt error (inspectable with errors.Is /
// errors.As). A Redis failure other than a missing key is reported as an
// invalid token (fail closed), with the Redis error wrapped.
func (m *AuthMiddleware) validateToken(tokenString string) (*Claims, error) {
	token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(t *jwt.Token) (interface{}, error) {
		if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
		}
		return m.secretKey, nil
	})
	if err != nil {
		return nil, fmt.Errorf("解析token失败: %w", err)
	}

	claims, ok := token.Claims.(*Claims)
	if !ok || !token.Valid {
		return nil, fmt.Errorf("无效的token")
	}

	if claims.Username == "" {
		return claims, nil
	}

	stored, err := m.redis.Get(context.Background(), claims.Username).Result()
	switch {
	case errors.Is(err, redis.Nil):
		// No stored token: expired or never recorded.
		return nil, errors.New("用户token失效")
	case err != nil:
		// Still reject, but keep the cause so an outage is distinguishable from revocation.
		return nil, fmt.Errorf("用户token失效: %w", err)
	case stored != tokenString:
		// A newer login replaced this token.
		return nil, errors.New("用户token失效")
	}
	return claims, nil
}
|
|
|
|
|
|
|
|
|
|
// extractToken 从请求头中提取token
|
|
|
|
|
func (m *AuthMiddleware) ExtractToken(c *gin.Context) string {
|
|
|
|
|
auth := c.GetHeader(m.config.TokenKey)
|
|
|
|
|
|
|
|
|
|
if auth == "" {
|
|
|
|
|
return ""
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const prefix = "Bearer "
|
|
|
|
|
if len(auth) <= len(prefix) || auth[:len(prefix)] != prefix {
|
|
|
|
|
return ""
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return auth[len(prefix):]
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// shouldSkip 检查是否需要跳过认证
|
|
|
|
|
func (m *AuthMiddleware) shouldSkip(path string) bool {
|
|
|
|
|
for _, skipPath := range m.config.SkipPaths {
|
|
|
|
|
if path == skipPath {
|
|
|
|
|
return true
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return false
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// generateGuestId builds a guest identifier: "G", the current local time as
// yyMMddHHmmss, then six characters from randomString.
func generateGuestId() string {
	stamp := time.Now().Format("060102150405")
	suffix := randomString(6)
	return "G" + stamp + suffix
}
|
|
|
|
|
|
|
|
|
|
// randomString returns n characters drawn uniformly from [a-zA-Z0-9] using
// crypto/rand. It returns "" when n <= 0.
//
// A cryptographic source is used so that identifiers built from this string
// (such as guest IDs) are unpredictable and do not repeat when several are
// generated in quick succession.
func randomString(n int) string {
	const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
	if n <= 0 {
		return ""
	}

	// 256 is not a multiple of len(letters); discarding bytes >= maxByte
	// avoids modulo bias, so every letter is equally likely.
	const maxByte = 256 - 256%len(letters)

	out := make([]byte, 0, n)
	buf := make([]byte, n)
	for len(out) < n {
		if _, err := rand.Read(buf); err != nil {
			// crypto/rand only fails when the OS entropy source is unusable
			// (Go 1.24+ crashes the process itself in that case); there is no
			// safe non-random fallback for identifiers.
			panic("auth: crypto/rand unavailable: " + err.Error())
		}
		for _, b := range buf {
			if int(b) >= maxByte {
				continue
			}
			out = append(out, letters[int(b)%len(letters)])
			if len(out) == n {
				break
			}
		}
	}
	return string(out)
}
|
|
|
|
|
|
|
|
|
|
// DecodeToken extracts the bearer token from the request via ExtractToken,
// verifies it with m.secretKey (HMAC signing methods only), and returns its
// Username claim: the encrypted phone number, or "" for a guest token.
//
// Parse failures wrap the underlying jwt error, so callers can detect cases
// such as expiry with errors.Is / errors.As.
//
// NOTE(review): this does not consult Redis, so a user token that has been
// replaced by a newer login still decodes successfully here. Confirm that
// callers do not rely on this for revocation.
func (m *AuthMiddleware) DecodeToken(c *gin.Context) (string, error) {
	tokenString := m.ExtractToken(c)
	if tokenString == "" {
		return "", fmt.Errorf("未找到有效的 token")
	}

	token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(t *jwt.Token) (interface{}, error) {
		if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
		}
		return m.secretKey, nil
	})
	if err != nil {
		return "", fmt.Errorf("解析 token 失败: %w", err)
	}

	claims, ok := token.Claims.(*Claims)
	if !ok || !token.Valid {
		return "", fmt.Errorf("无效的 token")
	}
	return claims.Username, nil
}
|
|
|
|
|
|
|
|
|
|
// GetSecretKey 获取密钥
|
|
|
|
|
func (m *AuthMiddleware) GetSecretKey() []byte {
|
|
|
|
|
return m.secretKey
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// DecodeTokenStatic 静态方法,不需要 AuthMiddleware 实例,直接解析 token 字符串
|
|
|
|
|
func DecodeTokenStatic(tokenString string, secretKey []byte) (string, error) {
|
|
|
|
|
// 移除 Bearer 前缀(如果有)
|
|
|
|
|
if len(tokenString) > 7 && tokenString[:7] == "Bearer " {
|
|
|
|
|
tokenString = tokenString[7:]
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 解析 token
|
|
|
|
|
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
|
|
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
|
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
|
|
|
}
|
|
|
|
|
return secretKey, nil
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
return "", fmt.Errorf("解析 token 失败: %v", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 验证 token 并获取 claims
|
|
|
|
|
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
|
|
|
|
|
// 返回原始的加密手机号
|
|
|
|
|
return claims.Username, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return "", fmt.Errorf("无效的 token")
|
|
|
|
|
}
|