package auth import ( "cls-server/internal/application/crypto" "cls-server/internal/application/user" "cls-server/internal/domain/admin" "cls-server/pkg/captchafx" "cls-server/pkg/logger" "cls-server/pkg/sms" "context" "errors" "fmt" "github.com/redis/go-redis/v9" "math/rand" "sync" "time" ) // RateLimiter 简单的限流器 type RateLimiter struct { mu sync.Mutex records map[string]time.Time ttl time.Duration } // NewRateLimiter 创建一个新的限流器 func NewRateLimiter(ttl time.Duration) *RateLimiter { return &RateLimiter{ records: make(map[string]time.Time), ttl: ttl, } } // Allow 检查是否允许请求 func (r *RateLimiter) Allow(key string) (bool, time.Duration) { r.mu.Lock() defer r.mu.Unlock() // 清理过期记录 now := time.Now() for k, t := range r.records { if now.Sub(t) > r.ttl { delete(r.records, k) } } // 检查是否存在记录 if lastTime, exists := r.records[key]; exists { elapsed := now.Sub(lastTime) if elapsed < r.ttl { return false, r.ttl - elapsed } } // 更新记录 r.records[key] = now return true, 0 } type CaptchaService struct { rateLimiter *RateLimiter rdb redis.Cmdable captcha *captchafx.Captcha sms sms.IsmsService log logger.Logger userService *user.UserService adminRepo admin.AdminRepository phoneEncrypt *crypto.PhoneEncryptionService } type SlideVerifyResp struct { ThumbX int `json:"thumbX"` ThumbY int `json:"thumbY"` ThumbWidth int `json:"thumbWidth"` ThumbHeight int `json:"thumbHeight"` Image string `json:"image"` Thumb string `json:"thumb"` } // NewCaptchaService 创建验证码服务 func NewCaptchaService(rdb redis.Cmdable, captcha *captchafx.Captcha, log logger.New, sms sms.IsmsService, userService *user.UserService, adminRepo admin.AdminRepository, phoneEncrypt *crypto.PhoneEncryptionService) *CaptchaService { // 创建自定义限流器,设置60秒的限流时间 rateLimiter := NewRateLimiter(60 * time.Second) return &CaptchaService{ rdb: rdb, log: log("cls:application:service:auth"), sms: sms, captcha: captcha, userService: userService, rateLimiter: rateLimiter, adminRepo: adminRepo, phoneEncrypt: phoneEncrypt, } } func (c *CaptchaService) GenerateImageCaptcha(username string) (string, error) { if username == "" { return "", errors.New("") } _, b64, anwser, err := c.captcha.Math.Generate() if err != nil { c.log.Error(err) return "", nil } err = c.rdb.Set(context.Background(), fmt.Sprintf("%s-image-captcha", username), anwser, time.Second*60).Err() if err != nil { c.log.Error(err.Error()) return "", nil } return b64, nil } // abs 获取绝对值 func abs(n int) int { if n < 0 { return -n } return n } // GenerateSmsCaptcha 生成短信验证码 func (c *CaptchaService) GenerateSmsCaptcha(username string) (*CaptchaResp, error) { if username == "" { return nil, errors.New("用户标识不能为空") } d, err := c.adminRepo.FindAdminByUsername(username) if err != nil { c.log.Error(err.Error()) return nil, err } phone, err := c.phoneEncrypt.Decrypt(d.Phone) if err != nil { c.log.Error(err.Error()) return nil, err } d = nil t, err := c.checkLimit("sms-captcha:" + phone) if err != nil { c.log.Warnf("%s: 短信频率过高", phone) return &CaptchaResp{ ResetAfter: int(t), }, errors.New("发送过于频繁,请稍后再试") } // 3. 生成验证码 code := c.generateSmsCode() // 4. 存储验证码 key := c.getSmsVerifyKey(username) err = c.rdb.Set(context.Background(), key, code, 5*time.Minute).Err() if err != nil { c.log.Error("存储验证码失败:", err) return nil, errors.New("生成验证码失败") } fmt.Println("key:", key) // 5. 发送短信 if err = c.sms.Send(code, phone); err != nil { c.log.Error("发送短信失败:", err) // 发送失败时删除验证码 _ = c.rdb.Del(context.Background(), key).Err() return nil, errors.New("发送短信失败") } return &CaptchaResp{ ResetAfter: int(t), }, nil } // generateSmsCode 生成6位数字验证码 func (c *CaptchaService) generateSmsCode() string { return fmt.Sprintf("%06d", rand.Intn(1000000)) } // VerifySmsCaptcha 验证短信验证码 func (c *CaptchaService) VerifySmsCaptcha(username, captchaCode string) error { if username == "" || captchaCode == "" { return errors.New("参数不完整") } // 1. 获取存储的验证码 key := c.getSmsVerifyKey(username) fmt.Println("校验key", key) storedCode, err := c.rdb.Get(context.Background(), key).Result() if err == redis.Nil { return errors.New("验证码已过期或不存在") } if err != nil { c.log.Error("获取验证码失败:", err) return errors.New("验证失败") } // 2. 验证码比对 if storedCode != captchaCode { return errors.New("验证码错误") } // 3. 验证成功后删除验证码 err = c.rdb.Del(context.Background(), key).Err() if err != nil { c.log.Error("删除验证码失败:", err) } return nil } // getSmsVerifyKey 生成短信验证码的Redis键 func (c *CaptchaService) getSmsVerifyKey(username string) string { return fmt.Sprintf("%s-sms-captcha", username) } // checkLimit 检查频率限制 func (c *CaptchaService) checkLimit(key string) (time.Duration, error) { // 检查 rateLimiter 是否已初始化 if c.rateLimiter == nil { c.log.Warn("限流器未初始化,跳过频率限制检查") // 如果限流器未初始化,默认允许通过 return 0, nil } // 尝试获取限流结果 allowed, retryAfter := c.rateLimiter.Allow(key) if !allowed { return retryAfter, errors.New("操作过于频繁,请稍后再试") } return 0, nil } // VerifyPhoneValid 验证手机号是否合法(中国大陆手机号) func (c *CaptchaService) VerifyPhoneValid(phone string) error { if phone == "" { return errors.New("手机号不能为空") } // 1. 检查长度是否为11位 if len(phone) != 11 { return errors.New("手机号必须为11位") } // 2. 检查是否全为数字 for _, c := range phone { if c < '0' || c > '9' { return errors.New("手机号只能包含数字") } } // 3. 检查手机号前三位是否符合运营商规则 prefix := phone[:3] // 移动号段: // 134-139, 150-153, 157-159, 182-184, 187-188, 147, 178, 195, 198 // 165, 172, 185-186, 196, 197, 198 mobilePrefixes := map[string]bool{ "134": true, "135": true, "136": true, "137": true, "138": true, "139": true, "150": true, "151": true, "152": true, "153": true, "157": true, "158": true, "159": true, "182": true, "183": true, "184": true, "187": true, "188": true, "147": true, "178": true, "195": true, "198": true, "165": true, "172": true, "185": true, "186": true, "196": true, "197": true, } // 联通号段: // 130-132, 155-156, 185-186, 176, 145, 166, 175, 171 unicomPrefixes := map[string]bool{ "130": true, "131": true, "132": true, "155": true, "156": true, "185": true, "186": true, "176": true, "145": true, "166": true, "175": true, "171": true, } // 电信号段: // 133, 153, 180-181, 189, 177, 173, 149, 191, 193, 199 telecomPrefixes := map[string]bool{ "133": true, "153": true, "180": true, "181": true, "189": true, "177": true, "173": true, "149": true, "191": true, "193": true, "199": true, } // 广电号段: // 192 radioPrefixes := map[string]bool{ "192": true, } // 虚拟运营商号段: // 170, 171 virtualPrefixes := map[string]bool{ "170": true, "171": true, } // 检查是否属于任一运营商的号段 if !mobilePrefixes[prefix] && !unicomPrefixes[prefix] && !telecomPrefixes[prefix] && !radioPrefixes[prefix] && !virtualPrefixes[prefix] { return errors.New("非法的手机号段") } return nil }