You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

301 lines
7.9 KiB
Go

3 weeks ago
package auth
import (
"cls-server/internal/domain/admin"
"context"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v4"
"github.com/redis/go-redis/v9"
"net/http"
"time"
)
// JWTClaimsKey names the request-context key for a request's JWT claims.
// NOTE(review): confirm which handlers Set/Get this key; it is meant for
// use with gin.Context's key/value store.
const JWTClaimsKey = "JWT_CLAIMS"
// Claims is the JWT payload issued by GenerateUserToken and parsed back by
// validateToken / DecodeToken / DecodeTokenStatic.
//
// NOTE(review): tokens are only signed (HS256), not encrypted, so every field
// below - including Phone - is readable by anyone who holds the token.
type Claims struct {
	// Id is copied from admin.Id. (Go naming would be ID; it is kept as-is
	// because renaming the exported field would break callers.)
	Id       uint64   `json:"id"`
	// Username is copied from admin.Username and is also the Redis key under
	// which the user's current token is stored.
	Username string   `json:"username"`
	// Roles is copied from admin.Roles.
	Roles    []string `json:"roles"`
	// Phone is copied from admin.Phone; whether that value is already
	// encrypted upstream is not visible here - TODO confirm.
	Phone    string   `json:"phone"`
	// RegisteredClaims carries the standard exp/iat/nbf/iss/sub/aud claims.
	jwt.RegisteredClaims
}
// AuthMiddleware authenticates requests with HS256-signed JWTs and uses Redis
// so that only the most recently issued token per username is accepted.
type AuthMiddleware struct {
	config    *Config       // supplies SkipPaths (auth-exempt paths) and TokenKey (header name)
	secretKey []byte        // HMAC key used both to sign and to verify tokens
	redis     redis.Cmdable // username -> most recently issued token string
}
// NewAuthMiddleware builds an AuthMiddleware that signs and verifies tokens
// with secretKey and records issued tokens in redis. When config is nil,
// DefaultConfig() is used instead.
func NewAuthMiddleware(secretKey string, config *Config, redis redis.Cmdable) *AuthMiddleware {
	m := &AuthMiddleware{
		secretKey: []byte(secretKey),
		redis:     redis,
		config:    config,
	}
	if m.config == nil {
		m.config = DefaultConfig()
	}
	return m
}
// Handle returns a gin middleware that enforces JWT authentication.
//
// Requests whose URL path exactly equals one of config.SkipPaths pass through
// untouched. Every other request must carry "Bearer <token>" in the header
// named by config.TokenKey, and the token must pass validateToken (HMAC
// signature, standard claim validation and, for tokens with a username, a
// match against the token stored in Redis). Otherwise the request is aborted
// with 401 Unauthorized.
//
// On success the verified *Claims are stored in the gin context under
// JWTClaimsKey, so downstream handlers can read them with c.Get instead of
// re-parsing the token.
func (m *AuthMiddleware) Handle() gin.HandlerFunc {
	return func(c *gin.Context) {
		if m.shouldSkip(c.Request.URL.Path) {
			c.Next()
			return
		}

		token := m.ExtractToken(c)
		if token == "" {
			// Missing or malformed Authorization header: reject outright.
			c.AbortWithStatus(http.StatusUnauthorized)
			return
		}

		claims, err := m.validateToken(token)
		if err != nil {
			c.AbortWithStatus(http.StatusUnauthorized)
			return
		}

		// Expose the verified identity to downstream handlers.
		c.Set(JWTClaimsKey, claims)
		c.Next()
	}
}
// GenerateUserToken 生成用户token
func (m *AuthMiddleware) GenerateUserToken(admin *admin.Admin) (string, error) {
claims := &Claims{
Id: admin.Id,
Phone: admin.Phone,
Roles: admin.Roles,
Username: admin.Username,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(30 * 24 * time.Hour)), // 30天过期
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
Issuer: "cls-system",
Subject: "user-access",
Audience: []string{"cls-api"},
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenStr, err := token.SignedString(m.secretKey)
if err != nil {
return "", err
}
err = m.redis.Set(context.Background(), admin.Username, tokenStr, time.Now().Add(30*24*time.Hour).Sub(time.Now())).Err()
if err != nil {
return "", err
}
return tokenStr, nil
}
// validateToken verifies tokenString and returns its claims.
//
// The token must be signed with an HMAC method using m.secretKey (other
// algorithms are rejected to block algorithm-substitution attacks) and must
// pass jwt's standard claim validation. If the claims carry a non-empty
// Username, the token must also equal the value stored in Redis under that
// username, i.e. be the most recently issued token for that user.
//
// A missing Redis entry and a Redis failure are both treated as invalid
// (fail closed); the latter keeps the underlying cause via %w.
//
// NOTE(review): tokens with an empty Username skip the Redis check entirely.
func (m *AuthMiddleware) validateToken(tokenString string) (*Claims, error) {
	token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(t *jwt.Token) (interface{}, error) {
		if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
		}
		return m.secretKey, nil
	})
	if err != nil {
		return nil, fmt.Errorf("解析token失败: %w", err)
	}

	claims, ok := token.Claims.(*Claims)
	if !ok || !token.Valid {
		return nil, fmt.Errorf("无效的token")
	}
	if claims.Username == "" {
		return claims, nil
	}

	// Never log the username/token pair: both the presented and the stored
	// value are live credentials.
	stored, err := m.redis.Get(context.Background(), claims.Username).Result()
	switch {
	case errors.Is(err, redis.Nil):
		// No token on record (TTL expired or never issued): treat as revoked.
		return nil, errors.New("用户token失效")
	case err != nil:
		// Redis unavailable or similar: still reject, but surface the cause.
		return nil, fmt.Errorf("查询token失败: %w", err)
	case stored != tokenString:
		// A newer token has been issued for this user.
		return nil, errors.New("用户token失效")
	}
	return claims, nil
}
// ExtractToken returns the bearer token from the request header named by
// config.TokenKey.
//
// The header value must be "Bearer <token>" with a non-empty token; the prefix
// match is case-sensitive. An empty string is returned when the header is
// missing or malformed.
func (m *AuthMiddleware) ExtractToken(c *gin.Context) string {
	const prefix = "Bearer "

	// The header is deliberately not logged: it carries a live credential.
	auth := c.GetHeader(m.config.TokenKey)
	if len(auth) <= len(prefix) || auth[:len(prefix)] != prefix {
		return ""
	}
	return auth[len(prefix):]
}
// shouldSkip reports whether path is exempt from authentication, i.e. whether
// it is exactly equal (case-sensitive, no prefix matching) to one of the
// configured SkipPaths.
func (m *AuthMiddleware) shouldSkip(path string) bool {
	exempt := m.config.SkipPaths
	for i := 0; i < len(exempt); i++ {
		if exempt[i] == path {
			return true
		}
	}
	return false
}
// randomString returns a string of n characters drawn from [a-zA-Z0-9], or ""
// when n <= 0.
//
// Characters come from a SplitMix64 sequence seeded once from the wall clock.
// Seeding once (rather than reading the clock for every character) avoids runs
// of identical or near-sequential characters when the clock does not advance
// between loop iterations.
//
// NOTE: this is NOT cryptographically secure. Do not use it for secrets,
// tokens, or any value an attacker must not predict; use crypto/rand for that.
func randomString(n int) string {
	const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
	if n <= 0 {
		return ""
	}

	state := uint64(time.Now().UnixNano())
	b := make([]byte, n)
	for i := range b {
		// SplitMix64 step: advance by the golden-ratio increment, then mix.
		state += 0x9E3779B97F4A7C15
		z := state
		z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9
		z = (z ^ (z >> 27)) * 0x94D049BB133111EB
		z ^= z >> 31
		b[i] = letters[z%uint64(len(letters))]
	}
	return string(b)
}
// DecodeToken extracts the bearer token from c's request headers, verifies it
// with this middleware's secret key, and returns its Username claim.
//
// Verification is delegated to DecodeTokenStatic so the parsing and
// signing-method checks live in one place instead of being duplicated.
//
// NOTE(review): unlike Handle (via validateToken), this performs no Redis
// lookup, so a token superseded by a newer login still decodes successfully.
func (m *AuthMiddleware) DecodeToken(c *gin.Context) (string, error) {
	tokenString := m.ExtractToken(c)
	if tokenString == "" {
		return "", errors.New("未找到有效的 token")
	}
	return DecodeTokenStatic(tokenString, m.secretKey)
}
// GetSecretKey returns a copy of the HMAC key used to sign and verify tokens.
//
// A copy is returned so callers cannot mutate (or zero) the key the middleware
// itself uses; a nil key yields nil.
func (m *AuthMiddleware) GetSecretKey() []byte {
	return append([]byte(nil), m.secretKey...)
}
// DecodeTokenStatic 静态方法,不需要 AuthMiddleware 实例,直接解析 token 字符串
func DecodeTokenStatic(tokenString string, secretKey []byte) (string, error) {
// 移除 Bearer 前缀(如果有)
if len(tokenString) > 7 && tokenString[:7] == "Bearer " {
tokenString = tokenString[7:]
}
// 解析 token
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return secretKey, nil
})
if err != nil {
return "", fmt.Errorf("解析 token 失败: %v", err)
}
// 验证 token 并获取 claims
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
// 返回原始的加密手机号
return claims.Username, nil
}
return "", fmt.Errorf("无效的 token")
}
// VerifyPhoneValid checks that phone looks like a mainland-China mobile
// number: exactly 11 ASCII digits whose first three digits are a known
// carrier prefix.
//
// It returns nil when phone is valid, otherwise an error for the first failed
// check (empty, wrong length, non-digit, unknown prefix). It reads no
// AuthMiddleware state.
//
// NOTE(review): the prefix list reflects allocations known at time of writing;
// revisit it when carriers open new segments.
func (m *AuthMiddleware) VerifyPhoneValid(phone string) error {
	if phone == "" {
		return errors.New("手机号不能为空")
	}
	// len counts bytes; any multi-byte rune is then caught by the digit check.
	if len(phone) != 11 {
		return errors.New("手机号必须为11位")
	}
	for _, c := range phone {
		if c < '0' || c > '9' {
			return errors.New("手机号只能包含数字")
		}
	}

	// A switch over constant strings avoids rebuilding lookup maps on every
	// call. Each prefix must appear exactly once: duplicate string cases are a
	// compile error.
	switch phone[:3] {
	// China Mobile
	case "134", "135", "136", "137", "138", "139", "147", "150", "151", "152",
		"157", "158", "159", "172", "178", "182", "183", "184", "187", "188",
		"195", "197", "198":
		return nil
	// China Unicom
	case "130", "131", "132", "145", "155", "156", "166", "175", "176", "185",
		"186", "196":
		return nil
	// China Telecom
	case "133", "149", "153", "173", "177", "180", "181", "189", "190", "191",
		"193", "199":
		return nil
	// China Broadnet
	case "192":
		return nil
	// Virtual operators (MVNO segments)
	case "162", "165", "167", "170", "171":
		return nil
	}
	return errors.New("非法的手机号段")
}