diff --git a/command.go b/command.go index 7750e55..d369729 100644 --- a/command.go +++ b/command.go @@ -3,6 +3,7 @@ package main import ( "fmt" "github.com/bwmarrin/discordgo" + "github.com/deckarep/golang-set" "strings" "time" "log" @@ -41,7 +42,7 @@ type Command struct { // for custom commands that go beyond prints and deletions Function func(*discordgo.Session, *discordgo.MessageCreate) - IsOnCooldown bool // don’t set this manually (it’s overwritten anyway) + UsersOnCooldown mapset.Set // don’t set this manually (it’s overwritten anyway) } @@ -54,6 +55,7 @@ func registerCommand(command Command) { if command.IgnoreCase { command.Trigger = strings.ToLower(command.Trigger) } + command.UsersOnCooldown = mapset.NewSet() commands = append(commands, command) } @@ -108,14 +110,14 @@ func evaluateMessage(s *discordgo.Session, m *discordgo.MessageCreate) { Sets command cooldowns if necessary and also clears them again. */ func executeCommand(session *discordgo.Session, message *discordgo.MessageCreate, command Command, commandIndex int) { - if (message.Author.ID == config.AdminID) || // no restrictions for admins - (!command.AdminOnly && !command.IsOnCooldown && - (!command.DMOnly || (getChannel(session.State, message.ChannelID).Type == discordgo.ChannelTypeDM))) { + if isAdmin(message.Author) || // no restrictions for admins + (!command.AdminOnly && (isDM(session, message) || !commands[commandIndex].UsersOnCooldown.Contains(message.Author.ID)) && + (!command.DMOnly || isDM(session, message))) { log.Printf("Executed command %s triggered by user %s", command.Trigger, userToString(message.Author)) - if command.Cooldown > 0 { - commands[commandIndex].IsOnCooldown = true - go removeCooldown(commandIndex) + if command.Cooldown > 0 && !isDM(session, message) && !isAdmin(message.Author) { + commands[commandIndex].UsersOnCooldown.Add(message.Author.ID) + go removeCooldown(commandIndex, message.Author.ID) } if command.Function == nil { // simple reply @@ -137,9 +139,11 @@ func executeCommand(session *discordgo.Session, message *discordgo.MessageCreate } } -func removeCooldown(commandIndex int) { +func removeCooldown(commandIndex int, uid string) { time.Sleep(time.Duration(commands[commandIndex].Cooldown) * time.Second) - commands[commandIndex].IsOnCooldown = false + if commands[commandIndex].UsersOnCooldown.Contains(uid) { + commands[commandIndex].UsersOnCooldown.Remove(uid) + } } func generateReply(message *discordgo.MessageCreate, command Command) string { diff --git a/helpers.go b/helpers.go index 2e1ede9..b83bb08 100644 --- a/helpers.go +++ b/helpers.go @@ -38,3 +38,10 @@ func getUser(s *discordgo.Session, uid string) *discordgo.User { return user } +func isDM(s *discordgo.Session, m *discordgo.MessageCreate) bool { + return (getChannel(s.State, m.ChannelID).Type == discordgo.ChannelTypeDM) +} + +func isAdmin(u *discordgo.User) bool { + return (u.ID == config.AdminID) +}