using AspNetCoreRateLimit;
using Infrastructure;
using Infrastructure.Attribute;
using Mapster;
using Microsoft.Extensions.Options;
using ZR.Service;
using ZR.ServiceCore.Model;
using ZR.ServiceCore.Services.IService;
using IpRateLimitPolicy = ZR.ServiceCore.Model.IpRateLimitPolicy;
using RateLimitRule = ZR.ServiceCore.Model.RateLimitRule;

namespace ZR.ServiceCore.Services;

/// <summary>
/// Enables/disables a per-IP catch-all rate-limit rule in the database and mirrors the
/// enabled policies into the AspNetCoreRateLimit <see cref="IIpPolicyStore"/>.
/// </summary>
[AppService(ServiceType = typeof(ISysLoginLimitService), ServiceLifetime = LifeTime.Transient)]
public class SysLoginLimitService : BaseService<SysLoginLimit>, ISysLoginLimitService
{
    /// <summary>Flag value the queries below treat as "enabled".</summary>
    private const char FlagEnabled = '1';

    /// <summary>Flag value written when a rule / login limit is switched off.</summary>
    private const char FlagDisabled = '0';

    private readonly IpRateLimitOptions _options;
    private readonly IIpPolicyStore _ipPolicyStore;
    private readonly IRateLimitRuleService _rateLimitRuleService;
    private readonly IIpRateLimitPolicyService _ipRateLimitPolicyService;

    public SysLoginLimitService(
        IIpPolicyStore ipPolicyStore,
        IOptions<IpRateLimitOptions> optionsAccessor,
        IIpRateLimitPolicyService ipRateLimitPolicyService,
        IRateLimitRuleService rateLimitRuleService)
    {
        _ipPolicyStore = ipPolicyStore;
        _ipRateLimitPolicyService = ipRateLimitPolicyService;
        _rateLimitRuleService = rateLimitRuleService;
        _options = optionsAccessor.Value;
    }

    /// <summary>
    /// Enables the catch-all rule (endpoint "*", limit 0 per "1s") for <c>sysLoginLimit.Ip</c>.
    /// Reuses the IP's existing policy, or inserts a new one when none exists. Refreshes the IP
    /// policy store and sets the login-limit record's Flag to '1'.
    /// </summary>
    /// <param name="sysLoginLimit">Record supplying the IP and the id of the row to flag.</param>
    /// <returns><c>true</c> once the transaction commits.</returns>
    /// <remarks>On any failure the transaction is rolled back, the store is rebuilt from the
    /// database, and the exception is rethrown.</remarks>
    public async Task<bool> AddSysLoginLimitAsync(SysLoginLimit sysLoginLimit)
    {
        await Context.Ado.BeginTranAsync();
        try
        {
            var blockRule = CreateBlockAllRule(FlagEnabled);

            var existingPolicy = await _ipRateLimitPolicyService.Queryable()
                .Where(it => it.Ip == sysLoginLimit.Ip)
                .FirstAsync();

            if (existingPolicy is not null)
            {
                blockRule.IpRateLimitPolicyId = existingPolicy.Id;
                await UpsertRateLimitRuleAsync(blockRule);

                // Re-enable the policy itself if it had been switched off.
                if (existingPolicy.Flag != FlagEnabled)
                {
                    await _ipRateLimitPolicyService
                        .Updateable(new IpRateLimitPolicy { Id = existingPolicy.Id, Flag = FlagEnabled })
                        .UpdateColumns(it => it.Flag)
                        .ExecuteCommandAsync();
                }
            }
            else
            {
                var newPolicy = new IpRateLimitPolicy
                {
                    Ip = sysLoginLimit.Ip,
                    Flag = FlagEnabled,
                    Rules = new List<RateLimitRule> { blockRule }
                };
                await _ipRateLimitPolicyService.InsertNav(newPolicy)
                    .Include(it => it.Rules)
                    .ExecuteCommandAsync();
            }

            // Refreshed before commit so the store reflects this change; if anything after this
            // fails, the catch block rebuilds the store from the rolled-back state.
            await RefreshIpPolicyStoreAsync();

            await Updateable(new SysLoginLimit { Id = sysLoginLimit.Id, Flag = FlagEnabled })
                .UpdateColumns(it => it.Flag)
                .ExecuteCommandAsync();

            await Context.Ado.CommitTranAsync();
            return true;
        }
        catch (Exception)
        {
            await Context.Ado.RollbackTranAsync();
            await RefreshIpPolicyStoreAsync();
            throw;
        }
    }

    /// <summary>
    /// Disables the catch-all rule (endpoint "*", limit 0 per "1s") for <c>sysLoginLimit.Ip</c>.
    /// Refreshes the IP policy store, sets the login-limit record's Flag to '0' and resets its
    /// ErrorCount to 0.
    /// </summary>
    /// <param name="sysLoginLimit">Record supplying the IP and the id of the row to reset.</param>
    /// <returns><c>true</c> once the transaction commits.</returns>
    /// <exception cref="CustomException">No policy exists for the IP.</exception>
    /// <remarks>On any failure (including the one above) the transaction is rolled back, the
    /// store is rebuilt from the database, and the exception is rethrown.</remarks>
    public async Task<bool> RemoveSysLoginLimitAsync(SysLoginLimit sysLoginLimit)
    {
        await Context.Ado.BeginTranAsync();
        try
        {
            // Check whether a policy exists for this IP.
            var existingPolicy = await _ipRateLimitPolicyService.Queryable()
                .Where(it => it.Ip == sysLoginLimit.Ip)
                .FirstAsync();

            if (existingPolicy is null)
            {
                throw new CustomException("不存在此IP的规则");
            }

            var blockRule = CreateBlockAllRule(FlagDisabled);
            blockRule.IpRateLimitPolicyId = existingPolicy.Id;
            await UpsertRateLimitRuleAsync(blockRule);

            await RefreshIpPolicyStoreAsync();

            await Updateable(new SysLoginLimit { Id = sysLoginLimit.Id, Flag = FlagDisabled, ErrorCount = 0 })
                .UpdateColumns(it => new { it.Flag, it.ErrorCount })
                .ExecuteCommandAsync();

            await Context.Ado.CommitTranAsync();
            return true;
        }
        catch (Exception)
        {
            await Context.Ado.RollbackTranAsync();
            await RefreshIpPolicyStoreAsync();
            throw;
        }
    }

    /// <summary>
    /// Builds the catch-all rule (endpoint "*", period "1s", limit 0) with the given flag.
    /// The caller sets <c>IpRateLimitPolicyId</c> when attaching it to an existing policy.
    /// </summary>
    private static RateLimitRule CreateBlockAllRule(char flag) => new RateLimitRule
    {
        Endpoint = "*",
        Period = "1s",
        Limit = 0,
        Flag = flag
    };

    /// <summary>
    /// Inserts <paramref name="rule"/>, or updates the existing row matched on
    /// (IpRateLimitPolicyId, Endpoint, Period, Limit). The update keeps that row's Id.
    /// </summary>
    private async Task UpsertRateLimitRuleAsync(RateLimitRule rule)
    {
        var storage = await _rateLimitRuleService
            .Storageable(rule)
            .WhereColumns(it => new { it.IpRateLimitPolicyId, it.Endpoint, it.Period, it.Limit })
            .ToStorageAsync();

        await storage.AsInsertable.ExecuteReturnSnowflakeIdAsync();
        await storage.AsUpdateable
            .IgnoreColumns(it => it.Id)
            .ExecuteCommandAsync();
    }

    /// <summary>
    /// Rebuilds the IP policy store entry under <c>_options.IpPolicyPrefix</c>. The new entry
    /// contains whatever <see cref="IIpPolicyStore.SeedAsync"/> provides, followed by every
    /// enabled database policy together with its enabled rules.
    /// </summary>
    private async Task RefreshIpPolicyStoreAsync()
    {
        var enabledPolicies = await _ipRateLimitPolicyService.Queryable()
            .Includes(it => it.Rules.Where(r => r.Flag == FlagEnabled).ToList())
            .Where(it => it.Flag == FlagEnabled)
            .ToListAsync();

        await _ipPolicyStore.RemoveAsync(_options.IpPolicyPrefix);
        await _ipPolicyStore.SeedAsync();
        var seeded = await _ipPolicyStore.GetAsync(_options.IpPolicyPrefix);

        // Build a fresh container instead of appending to `seeded` in place. SeedAsync may store
        // nothing (seeded == null). With a store that caches object references, `seeded` may be
        // the configured policies instance itself; mutating it would then accumulate database
        // rules across refreshes, so disabled rules would never drop out.
        var rules = new List<AspNetCoreRateLimit.IpRateLimitPolicy>(
            seeded?.IpRules ?? new List<AspNetCoreRateLimit.IpRateLimitPolicy>());
        rules.AddRange(enabledPolicies.Adapt<List<AspNetCoreRateLimit.IpRateLimitPolicy>>());

        await _ipPolicyStore.SetAsync(
            _options.IpPolicyPrefix,
            new AspNetCoreRateLimit.IpRateLimitPolicies { IpRules = rules });
    }
}