You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
cls/internal/application/auth/captcha_service.go

400 lines
10 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package auth
import (
	"context"
	crand "crypto/rand"
	"encoding/json"
	"errors"
	"fmt"
	"math/rand"
	"sync"
	"time"

	"github.com/redis/go-redis/v9"
	images "github.com/wenlng/go-captcha-assets/resources/images_v2"
	"github.com/wenlng/go-captcha-assets/resources/tiles"
	"github.com/wenlng/go-captcha/v2/slide"

	"cls/internal/application/user"
	"cls/pkg/captchafx"
	"cls/pkg/logger"
	"cls/pkg/sms"
)
// RateLimiter is a minimal in-memory, per-key rate limiter: each key may pass
// Allow at most once per ttl window. All state is guarded by mu, so Allow may
// be called concurrently.
//
// State lives only in this process's memory; it is not shared between
// service instances and is lost on restart.
type RateLimiter struct {
	mu sync.Mutex
	// records maps a key to the time its last request was allowed.
	records map[string]time.Time
	// ttl is the minimum spacing between two allowed requests for one key.
	ttl time.Duration
	// lastSweep is when expired records were last purged (zero = never).
	lastSweep time.Time
}

// NewRateLimiter creates a RateLimiter that allows each key once per ttl.
func NewRateLimiter(ttl time.Duration) *RateLimiter {
	return &RateLimiter{
		records: make(map[string]time.Time),
		ttl:     ttl,
	}
}

// Allow reports whether a request for key may proceed now.
//
// If it may, the current time is recorded for key and (true, 0) is returned.
// Otherwise it returns false and the time remaining until key is allowed again.
func (r *RateLimiter) Allow(key string) (bool, time.Duration) {
	r.mu.Lock()
	defer r.mu.Unlock()

	now := time.Now()
	if r.records == nil {
		// Tolerate a zero-value RateLimiter instead of panicking on a nil-map write.
		r.records = make(map[string]time.Time)
	}

	// Purging is O(len(records)). Doing it on every call made each request pay
	// for every key seen recently; sweeping at most once per ttl amortizes that
	// cost while still dropping expired entries about one ttl after they expire.
	// Expired entries never affect the decision below, so results are unchanged.
	if now.Sub(r.lastSweep) >= r.ttl {
		for k, t := range r.records {
			if now.Sub(t) > r.ttl {
				delete(r.records, k)
			}
		}
		r.lastSweep = now
	}

	if last, ok := r.records[key]; ok {
		if elapsed := now.Sub(last); elapsed < r.ttl {
			return false, r.ttl - elapsed
		}
	}
	r.records[key] = now
	return true, 0
}
// CaptchaService issues and verifies slide captchas and SMS verification codes.
type CaptchaService struct {
	// rateLimiter throttles SMS sends per key (used by checkLimit); when nil,
	// frequency checks are skipped.
	rateLimiter *RateLimiter
	// rdb stores pending slide-captcha answers and SMS codes.
	rdb redis.Cmdable
	// captcha is injected by NewCaptchaService.
	// NOTE(review): not referenced by any method in the code visible here.
	captcha *captchafx.Captcha
	// sms delivers verification codes via Send(code, phone).
	sms sms.IsmsService
	log logger.Logger
	// userService is injected by NewCaptchaService.
	// NOTE(review): not referenced by any method in the code visible here.
	userService *user.UserService
}
// SlideVerifyResp is the payload returned to the client for a slide captcha.
type SlideVerifyResp struct {
	ThumbX      int    `json:"thumbX"`
	ThumbY      int    `json:"thumbY"`
	ThumbWidth  int    `json:"thumbWidth"`
	ThumbHeight int    `json:"thumbHeight"`
	Image       string `json:"image"` // master (background) image, from ToBase64
	Thumb       string `json:"thumb"` // tile image, from ToBase64
	// Token identifies this captcha instance in Redis; the client must send it
	// back to VerifySlide together with its answer.
	Token string `json:"token"`
}

// NewCaptchaService builds a CaptchaService.
//
// log is called once with the component name to obtain the service logger.
// The service gets an in-process RateLimiter with a 60-second window, which
// checkLimit consults before SMS codes are sent.
func NewCaptchaService(rdb redis.Cmdable, captcha *captchafx.Captcha, log logger.New, sms sms.IsmsService, userService *user.UserService) *CaptchaService {
	return &CaptchaService{
		rdb:         rdb,
		log:         log("cls:application:service:auth"),
		sms:         sms,
		captcha:     captcha,
		userService: userService,
		rateLimiter: NewRateLimiter(60 * time.Second),
	}
}

// GenerateSlideVerify creates a new slide captcha for phone.
//
// It renders a background (master) image and a tile image, stores the expected
// answer (x, y) together with phone and a creation timestamp in Redis for one
// minute under a fresh random token, and returns the images plus that token.
// The client must pass the token back to VerifySlide.
func (c *CaptchaService) GenerateSlideVerify(phone string) (*SlideVerifyResp, error) {
	// Background images and tile shapes bundled with go-captcha-assets.
	imgs, err := images.GetImages()
	if err != nil {
		c.log.Error(err)
		return nil, err
	}
	graphs, err := tiles.GetTiles()
	if err != nil {
		c.log.Error(err)
		return nil, err
	}
	graphImages := make([]*slide.GraphImage, 0, len(graphs))
	for _, g := range graphs {
		graphImages = append(graphImages, &slide.GraphImage{
			OverlayImage: g.OverlayImage,
			MaskImage:    g.MaskImage,
			ShadowImage:  g.ShadowImage,
		})
	}

	builder := slide.NewBuilder()
	builder.SetResources(
		slide.WithGraphImages(graphImages),
		slide.WithBackgrounds(imgs),
	)
	captData, err := builder.Make().Generate()
	if err != nil {
		c.log.Error(err)
		return nil, err
	}
	dotData := captData.GetData()
	if dotData == nil {
		c.log.Error("生成滑动验证码数据失败")
		return nil, errors.New("生成失败")
	}

	mBase64, err := captData.GetMasterImage().ToBase64()
	if err != nil {
		c.log.Error(err)
		return nil, err
	}
	tBase64, err := captData.GetTileImage().ToBase64()
	if err != nil {
		c.log.Error(err)
		return nil, err
	}

	token := newSlideVerifyToken()
	verifyData, err := json.Marshal(map[string]interface{}{
		"phone":     phone,
		"x":         dotData.X,
		"y":         dotData.Y,
		"timestamp": time.Now().Unix(),
	})
	if err != nil {
		c.log.Error("序列化验证信息失败:", err)
		return nil, errors.New("生成验证码失败")
	}
	if err := c.rdb.Set(context.Background(), token, string(verifyData), 1*time.Minute).Err(); err != nil {
		c.log.Error("保存验证信息失败:", err)
		return nil, errors.New("生成验证码失败")
	}

	// NOTE(review): ThumbX/ThumbY carry dotData.X/Y, the same position stored
	// above as the expected answer. Sending it to the client may reveal the
	// solution; confirm whether the tile's display position (if the slide.Block
	// version in use exposes one) was intended here instead.
	return &SlideVerifyResp{
		ThumbX:      dotData.X,
		ThumbY:      dotData.Y,
		ThumbWidth:  dotData.Width,
		ThumbHeight: dotData.Height,
		Image:       mBase64,
		Thumb:       tBase64,
		Token:       token,
	}, nil
}

// newSlideVerifyToken returns an unguessable token of the form "slide_<32 hex digits>".
//
// crypto/rand is used because math/rand output is predictable, and before
// Go 1.20 its global source started from the same seed in every process, so
// separate instances would mint colliding tokens and overwrite each other's
// Redis entries.
func newSlideVerifyToken() string {
	var b [16]byte
	if _, err := crand.Read(b[:]); err != nil {
		// Practically unreachable; fall back to the old scheme rather than fail.
		return fmt.Sprintf("slide_%d", rand.Int63())
	}
	return fmt.Sprintf("slide_%x", b[:])
}
// VerifySlide checks a slide-captcha answer x against the record that
// GenerateSlideVerify stored in Redis under token.
//
// Tokens are single-use: the stored record is deleted on the first attempt,
// whether or not the answer is correct. Otherwise a client could retry offsets
// until one lands within the ±5px tolerance.
//
// It returns (true, nil) when x is within 5 of the stored x, (false, nil) for
// a wrong answer, and (false, err) when the token is empty, missing, malformed
// or older than five minutes.
func (c *CaptchaService) VerifySlide(token string, x int) (bool, error) {
	if token == "" {
		return false, errors.New("验证token不能为空")
	}
	ctx := context.Background()

	verifyData, err := c.rdb.Get(ctx, token).Result()
	if err != nil {
		if !errors.Is(err, redis.Nil) {
			// A real Redis failure rather than a missing key; record it.
			c.log.Error("获取验证信息失败:", err)
		}
		return false, errors.New("验证信息已过期")
	}

	// Consume the token before judging the answer (one attempt per captcha).
	if err := c.rdb.Del(ctx, token).Err(); err != nil {
		c.log.Error("删除验证信息失败:", err)
	}

	// Pointers distinguish "field missing" from a legitimate zero value, so a
	// malformed record is rejected instead of treating x as 0.
	var info struct {
		X         *int   `json:"x"`
		Timestamp *int64 `json:"timestamp"`
	}
	if err := json.Unmarshal([]byte(verifyData), &info); err != nil || info.X == nil || info.Timestamp == nil {
		return false, errors.New("验证信息无效")
	}

	// Defense in depth: GenerateSlideVerify stores the record with a one-minute
	// TTL, so this five-minute cap normally never triggers.
	if time.Now().Unix()-*info.Timestamp > 300 {
		return false, errors.New("验证已过期")
	}

	// Allow a tolerance of 5 units either way.
	return abs(x-*info.X) <= 5, nil
}
// abs returns the absolute value of n.
// As with any two's-complement negation, abs of the minimum int overflows and
// yields that same minimum value.
func abs(n int) int {
	if n >= 0 {
		return n
	}
	return -n
}
// GenerateSmsCaptcha sends a 6-digit SMS verification code to phone.
//
// Steps:
//  1. validate the phone number (VerifyPhoneValid);
//  2. enforce the per-phone send interval (checkLimit);
//  3. store the code in Redis for five minutes under getSmsVerifyKey(phone);
//  4. send it via the SMS service, removing the stored code if sending fails.
//
// username must be non-empty but is not otherwise used here.
//
// When rate limited it returns BOTH a non-nil *SmsCaptchaResp (carrying
// ResetAfter) and an error, so callers can report the remaining wait.
func (c *CaptchaService) GenerateSmsCaptcha(username string, phone string) (*SmsCaptchaResp, error) {
	if username == "" {
		return nil, errors.New("用户标识不能为空")
	}
	if phone == "" {
		return nil, errors.New("手机号不能为空")
	}
	if err := c.VerifyPhoneValid(phone); err != nil {
		return nil, err
	}

	// NOTE(review): ResetAfter is int(t) with t a time.Duration, i.e. a value in
	// nanoseconds. Confirm whether clients expect seconds.
	t, err := c.checkLimit("sms-captcha:" + phone)
	if err != nil {
		c.log.Warnf("%s: 短信频率过高", phone)
		return &SmsCaptchaResp{
			ResetAfter: int(t),
		}, errors.New("发送过于频繁,请稍后再试")
	}

	ctx := context.Background()
	code := c.generateSmsCode()
	key := c.getSmsVerifyKey(phone)
	if err := c.rdb.Set(ctx, key, code, 5*time.Minute).Err(); err != nil {
		c.log.Error("存储验证码失败:", err)
		return nil, errors.New("生成验证码失败")
	}

	if err := c.sms.Send(code, phone); err != nil {
		c.log.Error("发送短信失败:", err)
		// Best-effort cleanup so a code that was never delivered cannot be verified.
		_ = c.rdb.Del(ctx, key).Err()
		return nil, errors.New("发送短信失败")
	}

	return &SmsCaptchaResp{
		ResetAfter: int(t),
	}, nil
}
// generateSmsCode returns a uniformly random 6-digit numeric code, zero-padded
// (e.g. "004217").
//
// The code is a one-time password, so it is drawn from crypto/rand. math/rand
// is predictable, and before Go 1.20 its global source produced the same
// sequence on every process start.
func (c *CaptchaService) generateSmsCode() string {
	const space = 1000000
	// Largest multiple of space that fits in a uint32. Draws at or above it are
	// rejected so that n%space is uniform (no modulo bias).
	const limit = (1 << 32) / space * space

	var b [4]byte
	for {
		if _, err := crand.Read(b[:]); err != nil {
			// Practically unreachable; degrade rather than fail code generation.
			c.log.Error("生成安全随机数失败:", err)
			return fmt.Sprintf("%06d", rand.Intn(space))
		}
		n := uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
		if n < limit {
			return fmt.Sprintf("%06d", n%space)
		}
	}
}
// VerifySmsCaptcha checks captchaCode against the code that GenerateSmsCaptcha
// stored for phone.
//
// On a match the stored code is deleted so it cannot be reused, and nil is
// returned. A mismatch returns an error and leaves the stored code in place.
// username must be non-empty but is not otherwise used here.
//
// NOTE(review): wrong guesses are not counted, so a 6-digit code can be retried
// freely until its Redis entry expires. Consider adding an attempt limit.
func (c *CaptchaService) VerifySmsCaptcha(username, phone, captchaCode string) error {
	if username == "" || phone == "" || captchaCode == "" {
		return errors.New("参数不完整")
	}

	ctx := context.Background()
	key := c.getSmsVerifyKey(phone)
	storedCode, err := c.rdb.Get(ctx, key).Result()
	switch {
	case errors.Is(err, redis.Nil):
		return errors.New("验证码已过期或不存在")
	case err != nil:
		c.log.Error("获取验证码失败:", err)
		return errors.New("验证失败")
	}

	if storedCode != captchaCode {
		return errors.New("验证码错误")
	}

	// The code matched. A failed delete is logged but does not fail verification.
	if err := c.rdb.Del(ctx, key).Err(); err != nil {
		c.log.Error("删除验证码失败:", err)
	}
	return nil
}
// getSmsVerifyKey returns the Redis key that holds the pending SMS code for
// phone, e.g. "sms-captcha:13800138000".
func (c *CaptchaService) getSmsVerifyKey(phone string) string {
	const keyPrefix = "sms-captcha:"
	return keyPrefix + phone
}
// checkLimit asks the service's RateLimiter whether key may proceed.
//
// It returns (0, nil) when allowed. When throttled, it returns the remaining
// wait and an error. With no limiter configured, a warning is logged and the
// request is allowed.
func (c *CaptchaService) checkLimit(key string) (time.Duration, error) {
	limiter := c.rateLimiter
	if limiter == nil {
		c.log.Warn("限流器未初始化,跳过频率限制检查")
		return 0, nil
	}

	if ok, wait := limiter.Allow(key); !ok {
		return wait, errors.New("操作过于频繁,请稍后再试")
	}
	return 0, nil
}
// cnMobilePrefixes holds the 3-digit prefixes accepted by VerifyPhoneValid.
// Entries are grouped by carrier following public MIIT number allocations
// (verify against the current numbering plan when updating). The set is built
// once at package level rather than on every call.
var cnMobilePrefixes = map[string]struct{}{
	// China Mobile
	"134": {}, "135": {}, "136": {}, "137": {}, "138": {}, "139": {},
	"147": {}, "150": {}, "151": {}, "152": {}, "157": {}, "158": {},
	"159": {}, "172": {}, "178": {}, "182": {}, "183": {}, "184": {},
	"187": {}, "188": {}, "195": {}, "197": {}, "198": {},
	// China Unicom
	"130": {}, "131": {}, "132": {}, "145": {}, "155": {}, "156": {},
	"166": {}, "175": {}, "176": {}, "185": {}, "186": {}, "196": {},
	// China Telecom
	"133": {}, "149": {}, "153": {}, "173": {}, "177": {}, "180": {},
	"181": {}, "189": {}, "190": {}, "191": {}, "193": {}, "199": {},
	// China Broadnet
	"192": {},
	// Virtual network operators (MVNO)
	"162": {}, "165": {}, "167": {}, "170": {}, "171": {},
}

// VerifyPhoneValid checks that phone looks like a mainland-China mobile number:
// exactly 11 ASCII digits whose first three digits are a known carrier prefix
// (see cnMobilePrefixes). It returns nil when valid; otherwise it returns an
// error describing the first check that failed.
func (c *CaptchaService) VerifyPhoneValid(phone string) error {
	if phone == "" {
		return errors.New("手机号不能为空")
	}
	// len counts bytes; any non-ASCII byte is rejected by the digit check below.
	if len(phone) != 11 {
		return errors.New("手机号必须为11位")
	}
	for i := 0; i < len(phone); i++ {
		if ch := phone[i]; ch < '0' || ch > '9' {
			return errors.New("手机号只能包含数字")
		}
	}
	if _, known := cnMobilePrefixes[phone[:3]]; !known {
		return errors.New("非法的手机号段")
	}
	return nil
}