diff --git a/src/Services/AccessControlService.cs b/src/Services/AccessControlService.cs index aeb16e4..cb235f9 100644 --- a/src/Services/AccessControlService.cs +++ b/src/Services/AccessControlService.cs @@ -2,8 +2,6 @@ using Octobot.Extensions; using Remora.Discord.API.Abstractions.Objects; using Remora.Discord.API.Abstractions.Rest; -using Remora.Discord.Commands.Conditions; -using Remora.Discord.Commands.Results; using Remora.Rest.Core; using Remora.Results; @@ -13,32 +11,30 @@ public sealed class AccessControlService { private readonly GuildDataService _data; private readonly IDiscordRestGuildAPI _guildApi; - private readonly RequireDiscordPermissionCondition _permission; private readonly IDiscordRestUserAPI _userApi; - public AccessControlService(GuildDataService data, IDiscordRestGuildAPI guildApi, IDiscordRestUserAPI userApi, - RequireDiscordPermissionCondition permission) + public AccessControlService(GuildDataService data, IDiscordRestGuildAPI guildApi, IDiscordRestUserAPI userApi) { _data = data; _guildApi = guildApi; _userApi = userApi; - _permission = permission; } - private async Task> CheckPermissionAsync(GuildData data, Snowflake memberId, IGuildMember member, - DiscordPermission permission, CancellationToken ct = default) + private static bool CheckPermission(IEnumerable roles, GuildData data, Snowflake memberId, + IGuildMember member, + DiscordPermission permission) { var moderatorRole = GuildSettings.ModeratorRole.Get(data.Settings); - var result = await _permission.CheckAsync(new RequireDiscordPermissionAttribute([permission]), member, ct); - - if (result.Error is not null and not PermissionDeniedError) + if (!moderatorRole.Empty() && data.GetOrCreateMemberData(memberId).Roles.Contains(moderatorRole.Value)) { - return Result.FromError(result); + return true; } - var hasPermission = result.IsSuccess; - return hasPermission || (!moderatorRole.Empty() && - data.GetOrCreateMemberData(memberId).Roles.Contains(moderatorRole.Value)); + return roles + .Where(r => member.Roles.Contains(r.ID)) + .Any(r => + r.Permissions.HasPermission(permission) + ); } /// @@ -67,30 +63,35 @@ public sealed class AccessControlService return Result.FromSuccess($"UserCannot{action}Themselves".Localized()); } - var botResult = await _userApi.GetCurrentUserAsync(ct); - if (!botResult.IsDefined(out var bot)) - { - return Result.FromError(botResult); - } - var guildResult = await _guildApi.GetGuildAsync(guildId, ct: ct); if (!guildResult.IsDefined(out var guild)) { return Result.FromError(guildResult); } - var targetMemberResult = await _guildApi.GetGuildMemberAsync(guildId, targetId, ct); - if (!targetMemberResult.IsDefined(out var targetMember)) + if (interacterId == guild.OwnerID) { return Result.FromSuccess(null); } + var botResult = await _userApi.GetCurrentUserAsync(ct); + if (!botResult.IsDefined(out var bot)) + { + return Result.FromError(botResult); + } + var botMemberResult = await _guildApi.GetGuildMemberAsync(guildId, bot.ID, ct); if (!botMemberResult.IsDefined(out var botMember)) { return Result.FromError(botMemberResult); } + var targetMemberResult = await _guildApi.GetGuildMemberAsync(guildId, targetId, ct); + if (!targetMemberResult.IsDefined(out var targetMember)) + { + return Result.FromSuccess(null); + } + var rolesResult = await _guildApi.GetGuildRolesAsync(guildId, ct); if (!rolesResult.IsDefined(out var roles)) { @@ -110,18 +111,14 @@ public sealed class AccessControlService var data = await _data.GetData(guildId, ct); - var permissionResult = await CheckPermissionAsync(data, interacterId.Value, interacter, + var hasPermission = CheckPermission(roles, data, interacterId.Value, interacter, action switch { "Ban" => DiscordPermission.BanMembers, "Kick" => DiscordPermission.KickMembers, "Mute" or "Unmute" => DiscordPermission.ModerateMembers, _ => throw new Exception() - }, ct); - if (!permissionResult.IsDefined(out var hasPermission)) - { - return Result.FromError(permissionResult); - } + }); return hasPermission ? CheckInteractions(action, guild, roles, targetMember, botMember, interacter) @@ -137,11 +134,6 @@ public sealed class AccessControlService return new ArgumentNullError(nameof(targetMember.User)); } - if (!interacter.User.IsDefined(out var interacterUser)) - { - return new ArgumentNullError(nameof(interacter.User)); - } - if (botMember.User == targetMember.User) { return Result.FromSuccess($"UserCannot{action}Bot".Localized()); @@ -161,11 +153,6 @@ public sealed class AccessControlService return Result.FromSuccess($"BotCannot{action}Target".Localized()); } - if (interacterUser.ID == guild.OwnerID) - { - return Result.FromSuccess(null); - } - var interacterRoles = roles.Where(r => interacter.Roles.Contains(r.ID)); var targetInteracterRoleDiff = targetRoles.MaxOrDefault(r => r.Position) - interacterRoles.MaxOrDefault(r => r.Position);