You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

301 lines
7.9 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package auth
import (
"cls-server/internal/domain/admin"
"context"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v4"
"github.com/redis/go-redis/v9"
"net/http"
"time"
)
// JWTClaimsKey is the key name intended for storing parsed JWT claims in a
// request context.
const JWTClaimsKey = "JWT_CLAIMS"
// Claims is the JWT payload produced by GenerateUserToken and parsed by the
// token validators: the admin's identity and roles, plus the embedded
// registered claims (exp, iat, nbf, iss, sub, aud).
type Claims struct {
	Id       uint64   `json:"id"`       // admin ID
	Username string   `json:"username"` // login name; also used as the Redis key for the user's current token
	Roles    []string `json:"roles"`    // role names copied from the admin record
	Phone    string   `json:"phone"`    // phone number copied from the admin record
	jwt.RegisteredClaims
}
// AuthMiddleware authenticates HTTP requests with HMAC-signed JWTs and keeps
// each user's most recently issued token in Redis.
type AuthMiddleware struct {
	config    *Config       // skip paths and the name of the header carrying the token
	secretKey []byte        // HMAC key used to sign and verify tokens
	redis     redis.Cmdable // username -> most recently issued token string
}
// NewAuthMiddleware builds an AuthMiddleware that signs and verifies tokens
// with secretKey and uses the given Redis client to track issued tokens.
// A nil config is replaced by DefaultConfig().
func NewAuthMiddleware(secretKey string, config *Config, redis redis.Cmdable) *AuthMiddleware {
	cfg := config
	if cfg == nil {
		cfg = DefaultConfig()
	}
	mw := new(AuthMiddleware)
	mw.config = cfg
	mw.secretKey = []byte(secretKey)
	mw.redis = redis
	return mw
}
// Handle returns a gin middleware that enforces JWT authentication.
//
// Requests whose path exactly matches a configured skip path pass through
// untouched. Every other request must carry a bearer token (see ExtractToken)
// that passes validateToken; otherwise the request is aborted with
// 401 Unauthorized. On success the parsed *Claims are stored in the gin
// context under JWTClaimsKey, so downstream handlers can read them without
// re-parsing the token.
func (m *AuthMiddleware) Handle() gin.HandlerFunc {
	return func(c *gin.Context) {
		if m.shouldSkip(c.Request.URL.Path) {
			c.Next()
			return
		}

		token := m.ExtractToken(c)
		if token == "" {
			// No guest fallback: a missing or malformed header is rejected outright.
			c.AbortWithStatus(http.StatusUnauthorized)
			return
		}

		claims, err := m.validateToken(token)
		if err != nil {
			c.AbortWithStatus(http.StatusUnauthorized)
			return
		}

		c.Set(JWTClaimsKey, claims)
		c.Next()
	}
}
// GenerateUserToken issues an HS256-signed JWT for the given admin and stores
// it in Redis under the admin's username. validateToken later rejects any
// token for that username that differs from the stored one, so issuing a new
// token supersedes the previous one.
//
// Both the token and the Redis entry expire after 30 days. An error is
// returned if admin is nil, if signing fails, or if the Redis write fails.
func (m *AuthMiddleware) GenerateUserToken(admin *admin.Admin) (string, error) {
	const tokenTTL = 30 * 24 * time.Hour // 30-day lifetime

	if admin == nil {
		return "", errors.New("admin must not be nil")
	}

	// Read the clock once so iat, nbf, exp and the Redis TTL all agree.
	now := time.Now()
	claims := &Claims{
		Id:       admin.Id,
		Phone:    admin.Phone,
		Roles:    admin.Roles,
		Username: admin.Username,
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(now.Add(tokenTTL)),
			IssuedAt:  jwt.NewNumericDate(now),
			NotBefore: jwt.NewNumericDate(now),
			Issuer:    "cls-system",
			Subject:   "user-access",
			Audience:  []string{"cls-api"},
		},
	}

	tokenStr, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(m.secretKey)
	if err != nil {
		return "", fmt.Errorf("签名token失败: %w", err)
	}

	if err := m.redis.Set(context.Background(), admin.Username, tokenStr, tokenTTL).Err(); err != nil {
		return "", fmt.Errorf("保存token到redis失败: %w", err)
	}
	return tokenStr, nil
}
// validateToken parses and verifies tokenString and returns its claims.
//
// The token must use an HMAC signing method, carry a valid signature under
// m.secretKey, and pass the standard registered-claim checks performed by
// jwt.ParseWithClaims. When the Username claim is non-empty, the token must
// also equal the token stored in Redis under that username (the value written
// by GenerateUserToken). Otherwise it is treated as superseded or revoked.
//
// A Redis failure other than a missing key is reported as its own wrapped
// error rather than being mistaken for an invalid token.
func (m *AuthMiddleware) validateToken(tokenString string) (*Claims, error) {
	token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return m.secretKey, nil
	})
	if err != nil {
		return nil, fmt.Errorf("解析token失败: %w", err)
	}

	claims, ok := token.Claims.(*Claims)
	if !ok || !token.Valid {
		return nil, errors.New("无效的token")
	}
	if claims.Username == "" {
		return claims, nil
	}

	// Never log the username/token pair here: the token is a live credential.
	stored, err := m.redis.Get(context.Background(), claims.Username).Result()
	switch {
	case errors.Is(err, redis.Nil):
		// No token on record for this user (expired or never issued).
		return nil, errors.New("用户token失效")
	case err != nil:
		return nil, fmt.Errorf("查询用户token失败: %w", err)
	case stored != tokenString:
		// A newer token has been issued for this user.
		return nil, errors.New("用户token失效")
	}
	return claims, nil
}
// ExtractToken returns the bearer token carried in the request header named
// by config.TokenKey.
//
// It returns "" in three cases:
//   - the header is absent;
//   - it does not start with the exact, case-sensitive prefix "Bearer ";
//   - nothing follows the prefix.
func (m *AuthMiddleware) ExtractToken(c *gin.Context) string {
	const prefix = "Bearer "

	// Intentionally not logged: the header value is a live bearer credential.
	auth := c.GetHeader(m.config.TokenKey)
	if len(auth) <= len(prefix) || auth[:len(prefix)] != prefix {
		return ""
	}
	return auth[len(prefix):]
}
// shouldSkip reports whether path exactly equals one of the configured
// SkipPaths, in which case authentication is bypassed.
func (m *AuthMiddleware) shouldSkip(path string) bool {
	matched := false
	for _, candidate := range m.config.SkipPaths {
		if candidate == path {
			matched = true
			break
		}
	}
	return matched
}
// randomString returns a string of n characters drawn from [a-zA-Z0-9].
// It returns "" when n <= 0.
//
// Characters come from a splitmix64 generator seeded once from the wall clock.
// The output is NOT cryptographically secure and must not be used for
// secrets, tokens or passwords; use crypto/rand for those.
func randomString(n int) string {
	const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
	if n <= 0 {
		return ""
	}

	// Seed once and advance a generator. Re-reading the clock for every
	// character yields strongly correlated (often identical) characters.
	state := uint64(time.Now().UnixNano())
	b := make([]byte, n)
	for i := range b {
		// splitmix64 step
		state += 0x9E3779B97F4A7C15
		z := state
		z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9
		z = (z ^ (z >> 27)) * 0x94D049BB133111EB
		z ^= z >> 31
		b[i] = letters[z%uint64(len(letters))]
	}
	return string(b)
}
// DecodeToken extracts the bearer token from the request (see ExtractToken),
// verifies it, and returns its Username claim.
//
// Verification checks the HMAC signing method, the signature under
// m.secretKey and the standard registered-claim checks of jwt.ParseWithClaims.
// It does NOT compare against the token stored in Redis, so a token that has
// been superseded by a newer GenerateUserToken call is still accepted here.
//
// Parse failures are wrapped with %w so callers can inspect the underlying
// jwt error with errors.Is / errors.As.
func (m *AuthMiddleware) DecodeToken(c *gin.Context) (string, error) {
	tokenString := m.ExtractToken(c)
	if tokenString == "" {
		return "", errors.New("未找到有效的 token")
	}

	token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return m.secretKey, nil
	})
	if err != nil {
		return "", fmt.Errorf("解析 token 失败: %w", err)
	}

	claims, ok := token.Claims.(*Claims)
	if !ok || !token.Valid {
		return "", errors.New("无效的 token")
	}
	// This is the Username claim, not a phone number.
	return claims.Username, nil
}
// GetSecretKey returns a copy of the HMAC key used to sign and verify tokens.
// A copy is returned so that callers cannot mutate the key the middleware
// itself uses.
func (m *AuthMiddleware) GetSecretKey() []byte {
	key := make([]byte, len(m.secretKey))
	copy(key, m.secretKey)
	return key
}
// DecodeTokenStatic verifies tokenString with secretKey and returns its
// Username claim, without needing an AuthMiddleware instance.
//
// A leading "Bearer " prefix, if present, is stripped first, so a raw
// Authorization header value can be passed directly. The function checks the
// HMAC signing method, the signature and the standard registered-claim checks
// of jwt.ParseWithClaims. It does not consult Redis, so superseded tokens are
// still accepted.
//
// Parse failures are wrapped with %w so callers can inspect the underlying
// jwt error with errors.Is / errors.As.
func DecodeTokenStatic(tokenString string, secretKey []byte) (string, error) {
	const bearerPrefix = "Bearer "
	if len(tokenString) > len(bearerPrefix) && tokenString[:len(bearerPrefix)] == bearerPrefix {
		tokenString = tokenString[len(bearerPrefix):]
	}

	token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return secretKey, nil
	})
	if err != nil {
		return "", fmt.Errorf("解析 token 失败: %w", err)
	}

	claims, ok := token.Claims.(*Claims)
	if !ok || !token.Valid {
		return "", errors.New("无效的 token")
	}
	// This is the Username claim, not a phone number.
	return claims.Username, nil
}
// validPhonePrefixes is the set of three-digit mainland-China mobile prefixes
// accepted by VerifyPhoneValid. Only membership is checked; the carrier
// grouping below is approximate and exists for readability. New prefixes are
// allocated from time to time, so extend this list when that happens. The set
// is built once at package initialisation instead of on every call.
var validPhonePrefixes = map[string]struct{}{
	// China Mobile
	"134": {}, "135": {}, "136": {}, "137": {}, "138": {}, "139": {},
	"147": {}, "150": {}, "151": {}, "152": {}, "157": {}, "158": {},
	"159": {}, "172": {}, "178": {}, "182": {}, "183": {}, "184": {},
	"187": {}, "188": {}, "195": {}, "197": {}, "198": {},
	// China Unicom
	"130": {}, "131": {}, "132": {}, "145": {}, "155": {}, "156": {},
	"166": {}, "175": {}, "176": {}, "185": {}, "186": {}, "196": {},
	// China Telecom
	"133": {}, "149": {}, "153": {}, "173": {}, "177": {}, "180": {},
	"181": {}, "189": {}, "190": {}, "191": {}, "193": {}, "199": {},
	// China Broadnet
	"192": {},
	// Virtual operators (MVNO)
	"162": {}, "165": {}, "167": {}, "170": {}, "171": {},
}

// VerifyPhoneValid checks that phone looks like a mainland-China mobile
// number. It returns nil when phone is exactly 11 ASCII digits whose first
// three digits are in validPhonePrefixes.
//
// Otherwise it returns an error for the first failed check, in this order:
// empty input, length other than 11, a non-digit character, unknown prefix.
func (m *AuthMiddleware) VerifyPhoneValid(phone string) error {
	if phone == "" {
		return errors.New("手机号不能为空")
	}
	if len(phone) != 11 {
		return errors.New("手机号必须为11位")
	}
	for _, c := range phone {
		if c < '0' || c > '9' {
			return errors.New("手机号只能包含数字")
		}
	}
	if _, ok := validPhonePrefixes[phone[:3]]; !ok {
		return errors.New("非法的手机号段")
	}
	return nil
}