package auth import ( "cls/internal/application/user" "cls/pkg/captchafx" "cls/pkg/logger" "cls/pkg/sms" "context" "encoding/json" "errors" "fmt" "github.com/redis/go-redis/v9" images "github.com/wenlng/go-captcha-assets/resources/images_v2" "github.com/wenlng/go-captcha-assets/resources/tiles" "github.com/wenlng/go-captcha/v2/slide" "math/rand" "sync" "time" ) // RateLimiter 简单的限流器 type RateLimiter struct { mu sync.Mutex records map[string]time.Time ttl time.Duration } // NewRateLimiter 创建一个新的限流器 func NewRateLimiter(ttl time.Duration) *RateLimiter { return &RateLimiter{ records: make(map[string]time.Time), ttl: ttl, } } // Allow 检查是否允许请求 func (r *RateLimiter) Allow(key string) (bool, time.Duration) { r.mu.Lock() defer r.mu.Unlock() // 清理过期记录 now := time.Now() for k, t := range r.records { if now.Sub(t) > r.ttl { delete(r.records, k) } } // 检查是否存在记录 if lastTime, exists := r.records[key]; exists { elapsed := now.Sub(lastTime) if elapsed < r.ttl { return false, r.ttl - elapsed } } // 更新记录 r.records[key] = now return true, 0 } type CaptchaService struct { rateLimiter *RateLimiter rdb redis.Cmdable captcha *captchafx.Captcha sms sms.IsmsService log logger.Logger userService *user.UserService } type SlideVerifyResp struct { ThumbX int `json:"thumbX"` ThumbY int `json:"thumbY"` ThumbWidth int `json:"thumbWidth"` ThumbHeight int `json:"thumbHeight"` Image string `json:"image"` Thumb string `json:"thumb"` } // NewCaptchaService 创建验证码服务 func NewCaptchaService(rdb redis.Cmdable, captcha *captchafx.Captcha, log logger.New, sms sms.IsmsService, userService *user.UserService) *CaptchaService { // 创建自定义限流器,设置60秒的限流时间 rateLimiter := NewRateLimiter(60 * time.Second) return &CaptchaService{ rdb: rdb, log: log("cls:application:service:auth"), sms: sms, captcha: captcha, userService: userService, rateLimiter: rateLimiter, } } // GenerateSlideVerify 生成滑动验证码 func (c *CaptchaService) GenerateSlideVerify(phone string) (*SlideVerifyResp, error) { builder := slide.NewBuilder() // background images imgs, err := images.GetImages() if err != nil { c.log.Error(err) return nil, err } graphs, err := tiles.GetTiles() if err != nil { c.log.Error(err) return nil, err } var newGraphs = make([]*slide.GraphImage, 0, len(graphs)) for i := 0; i < len(graphs); i++ { graph := graphs[i] newGraphs = append(newGraphs, &slide.GraphImage{ OverlayImage: graph.OverlayImage, MaskImage: graph.MaskImage, ShadowImage: graph.ShadowImage, }) } // set resources builder.SetResources( slide.WithGraphImages(newGraphs), slide.WithBackgrounds(imgs), ) slideCapt := builder.Make() captData, err := slideCapt.Generate() if err != nil { c.log.Error(err) return nil, err } dotData := captData.GetData() if dotData == nil { c.log.Error("生成滑动验证码数据失败") return nil, errors.New("生成失败") } var mBase64, tBase64 string mBase64, err = captData.GetMasterImage().ToBase64() if err != nil { c.log.Error(err) return nil, err } tBase64, err = captData.GetTileImage().ToBase64() if err != nil { c.log.Error(err) return nil, err } // 生成验证token token := fmt.Sprintf("slide_%d", rand.Int63()) // 保存验证信息到Redis verifyInfo := map[string]interface{}{ "phone": phone, "x": dotData.X, "y": dotData.Y, "timestamp": time.Now().Unix(), } verifyData, _ := json.Marshal(verifyInfo) err = c.rdb.Set(context.Background(), token, string(verifyData), 1*time.Minute).Err() if err != nil { c.log.Error("保存验证信息失败:", err) return nil, errors.New("生成验证码失败") } return &SlideVerifyResp{ ThumbX: dotData.X, ThumbY: dotData.Y, ThumbWidth: dotData.Width, ThumbHeight: dotData.Height, Image: mBase64, Thumb: tBase64, }, nil } // VerifySlide 验证滑动位置 func (c *CaptchaService) VerifySlide(token string, x int) (bool, error) { if token == "" { return false, errors.New("验证token不能为空") } // 从Redis获取验证信息 verifyData, err := c.rdb.Get(context.Background(), token).Result() if err != nil { return false, errors.New("验证信息已过期") } // 解析验证信息 var verifyInfo map[string]interface{} err = json.Unmarshal([]byte(verifyData), &verifyInfo) if err != nil { return false, errors.New("验证信息无效") } // 验证滑动位置 correctX := int(verifyInfo["x"].(float64)) timestamp := int64(verifyInfo["timestamp"].(float64)) // 检查是否已过期(5分钟) if time.Now().Unix()-timestamp > 300 { c.rdb.Del(context.Background(), token) return false, errors.New("验证已过期") } // 验证位置是否正确(允许误差范围) if abs(x-correctX) <= 5 { // 验证成功后删除验证信息 c.rdb.Del(context.Background(), token) return true, nil } return false, nil } // abs 获取绝对值 func abs(n int) int { if n < 0 { return -n } return n } // GenerateSmsCaptcha 生成短信验证码 func (c *CaptchaService) GenerateSmsCaptcha(username string, phone string) (*SmsCaptchaResp, error) { if username == "" { return nil, errors.New("用户标识不能为空") } if phone == "" { return nil, errors.New("手机号不能为空") } // 1. 验证手机号格式 if err := c.VerifyPhoneValid(phone); err != nil { return nil, err } // 2. 检查频率限制 t, err := c.checkLimit("sms-captcha:" + phone) if err != nil { c.log.Warnf("%s: 短信频率过高", phone) return &SmsCaptchaResp{ ResetAfter: int(t), }, errors.New("发送过于频繁,请稍后再试") } // 3. 生成验证码 code := c.generateSmsCode() // 4. 存储验证码 key := c.getSmsVerifyKey(phone) err = c.rdb.Set(context.Background(), key, code, 5*time.Minute).Err() if err != nil { c.log.Error("存储验证码失败:", err) return nil, errors.New("生成验证码失败") } fmt.Println("key:", key) // 5. 发送短信 if err = c.sms.Send(code, phone); err != nil { c.log.Error("发送短信失败:", err) // 发送失败时删除验证码 _ = c.rdb.Del(context.Background(), key).Err() return nil, errors.New("发送短信失败") } return &SmsCaptchaResp{ ResetAfter: int(t), }, nil } // generateSmsCode 生成6位数字验证码 func (c *CaptchaService) generateSmsCode() string { return fmt.Sprintf("%06d", rand.Intn(1000000)) } // VerifySmsCaptcha 验证短信验证码 func (c *CaptchaService) VerifySmsCaptcha(username, phone, captchaCode string) error { if username == "" || phone == "" || captchaCode == "" { return errors.New("参数不完整") } // 1. 获取存储的验证码 key := c.getSmsVerifyKey(phone) fmt.Println("校验key", key) storedCode, err := c.rdb.Get(context.Background(), key).Result() if err == redis.Nil { return errors.New("验证码已过期或不存在") } if err != nil { c.log.Error("获取验证码失败:", err) return errors.New("验证失败") } // 2. 验证码比对 if storedCode != captchaCode { return errors.New("验证码错误") } // 3. 验证成功后删除验证码 err = c.rdb.Del(context.Background(), key).Err() if err != nil { c.log.Error("删除验证码失败:", err) } return nil } // getSmsVerifyKey 生成短信验证码的Redis键 func (c *CaptchaService) getSmsVerifyKey(phone string) string { return fmt.Sprintf("sms-captcha:%s", phone) } // checkLimit 检查频率限制 func (c *CaptchaService) checkLimit(key string) (time.Duration, error) { // 检查 rateLimiter 是否已初始化 if c.rateLimiter == nil { c.log.Warn("限流器未初始化,跳过频率限制检查") // 如果限流器未初始化,默认允许通过 return 0, nil } // 尝试获取限流结果 allowed, retryAfter := c.rateLimiter.Allow(key) if !allowed { return retryAfter, errors.New("操作过于频繁,请稍后再试") } return 0, nil } // VerifyPhoneValid 验证手机号是否合法(中国大陆手机号) func (c *CaptchaService) VerifyPhoneValid(phone string) error { if phone == "" { return errors.New("手机号不能为空") } // 1. 检查长度是否为11位 if len(phone) != 11 { return errors.New("手机号必须为11位") } // 2. 检查是否全为数字 for _, c := range phone { if c < '0' || c > '9' { return errors.New("手机号只能包含数字") } } // 3. 检查手机号前三位是否符合运营商规则 prefix := phone[:3] // 移动号段: // 134-139, 150-153, 157-159, 182-184, 187-188, 147, 178, 195, 198 // 165, 172, 185-186, 196, 197, 198 mobilePrefixes := map[string]bool{ "134": true, "135": true, "136": true, "137": true, "138": true, "139": true, "150": true, "151": true, "152": true, "153": true, "157": true, "158": true, "159": true, "182": true, "183": true, "184": true, "187": true, "188": true, "147": true, "178": true, "195": true, "198": true, "165": true, "172": true, "185": true, "186": true, "196": true, "197": true, } // 联通号段: // 130-132, 155-156, 185-186, 176, 145, 166, 175, 171 unicomPrefixes := map[string]bool{ "130": true, "131": true, "132": true, "155": true, "156": true, "185": true, "186": true, "176": true, "145": true, "166": true, "175": true, "171": true, } // 电信号段: // 133, 153, 180-181, 189, 177, 173, 149, 191, 193, 199 telecomPrefixes := map[string]bool{ "133": true, "153": true, "180": true, "181": true, "189": true, "177": true, "173": true, "149": true, "191": true, "193": true, "199": true, } // 广电号段: // 192 radioPrefixes := map[string]bool{ "192": true, } // 虚拟运营商号段: // 170, 171 virtualPrefixes := map[string]bool{ "170": true, "171": true, } // 检查是否属于任一运营商的号段 if !mobilePrefixes[prefix] && !unicomPrefixes[prefix] && !telecomPrefixes[prefix] && !radioPrefixes[prefix] && !virtualPrefixes[prefix] { return errors.New("非法的手机号段") } return nil }