diff --git a/README.md b/README.md index f281c18..f764f48 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,13 @@ # bolt +look into using retries and context's + +also add mentioned users to struct so the ban functions can use them + +also do we move towards message handler or focus on command handling, this will dictate the structure going forward with the bans, etc. +started on msg handler but it doesn't make sense, the Adding of the handlers won't work with messages since the trigger is always blank. Might need a catch all one, but then +that blacklists a command trigger, hmmmm + The nuts-and-bolts of Discord bots. Bolt is a wrapper for [discordgo](https://github.com/bwmarrin/discordgo) that provides quick and easy bootstrapping for simple Discord bots. ## Usage diff --git a/bolt.go b/bolt.go index b4ef3fe..4a9afdc 100644 --- a/bolt.go +++ b/bolt.go @@ -1,12 +1,14 @@ package bolt import ( + "context" "fmt" "log" "os" "os/signal" "slices" "strings" + "sync" "syscall" "time" @@ -14,14 +16,12 @@ import ( ) const ( - TOKEN_ENV_VAR = "DISCORD_TOKEN" //label for token environment variable + //Environment variable name for discord token, this is the only required variable + TOKEN_ENV_VAR = "DISCORD_TOKEN" - BOT_INTENTS = dg.IntentGuilds | - dg.IntentGuildMembers | - dg.IntentGuildPresences | - dg.IntentMessageContent | - dg.IntentsGuildMessages | - dg.IntentGuildMessageReactions + //bot defaults + DEFAULT_INDICATOR = "." + DEFAULT_MAX_GOROUTINES = 50 ) // basic bot structure containing discordgo connection as well as the command map @@ -30,6 +30,10 @@ type bolt struct { commands map[string]Command //maps trigger phrase to command struct for fast lookup indicator string //the indicator used to detect whether a message is a command logLvl LogLevel //determines how much the bot logs + wg sync.WaitGroup + pool chan struct{} + maxRoutines int + msgHandlerf Payload } type Bolt interface { @@ -37,15 +41,15 @@ type Bolt interface { AddCommands(cmd ...Command) //filtered methods stop() error - messageHandler(s *dg.Session, msg *dg.MessageCreate) - handleCommand(msgEvent *MessageCreateEvent, s *dg.Session, lg int) error + msgEventHandler(s *dg.Session, msg *dg.MessageCreate) + handleCommand(msgEvent *Message, lg int) error + handleMessage(event *Message) error createReply(content, message, channel, guild string) *dg.MessageSend - getRemainingTimeout(timeout time.Time) string + remainingTimeout(timeout time.Time) string roleCheck(guild string, roles []string, s *dg.Session, run Command) (bool, error) timeoutCheck(msgID, channelID, guildID string, s *dg.Session, run Command) (bool, error) } -// create a new bolt interface func New(opts ...Option) (Bolt, error) { _, check := os.LookupEnv(TOKEN_ENV_VAR) if !check { @@ -56,48 +60,65 @@ func New(opts ...Option) (Bolt, error) { if err != nil { return nil, fmt.Errorf("failed to create Discord session: %e", err) } - bot.Identify.Intents = BOT_INTENTS b := &bolt{ - Session: bot, - commands: make(map[string]Command, 0), - logLvl: LogLevelAll, + Session: bot, + commands: make(map[string]Command, 0), + logLvl: LogLevelAll, + indicator: DEFAULT_INDICATOR, + wg: sync.WaitGroup{}, + maxRoutines: DEFAULT_MAX_GOROUTINES, } - //set default command indicator - b.indicator = "." //apply options for _, opt := range opts { opt(b) } + //options can change pool size, create post-options + b.pool = make(chan struct{}, b.maxRoutines) + return b, nil } -// starts the bot, commands are added and the connection to Discord is opened, this is a BLOCKING call -// that handles safe shutdown of the bot func (b *bolt) Start() error { - //register commands and open connection - b.AddHandler(b.messageHandler) - + b.AddHandler(b.msgEventHandler) err := b.Open() if err != nil { return fmt.Errorf("failed to open websocket connection with Discord: %e", err) } - //safe shutdown handler + log.Println("bot started") + sigChannel := make(chan os.Signal, 1) signal.Notify(sigChannel, syscall.SIGINT) <-sigChannel + //move this to an option, maybe? + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + closeChan := make(chan struct{}, 0) + go func() { + b.wg.Wait() + close(closeChan) + }() + + select { + case <-ctx.Done(): + log.Println("shutdown timed out waiting for commands to finish, some may have been incomplete") + case <-closeChan: + log.Println("command routines cleaned up, exiting") + } + if err := b.stop(); err != nil { return err } + log.Println("bot stopped") + return nil } -// stops the bot func (b *bolt) stop() error { return b.Close() } @@ -109,9 +130,7 @@ func (b *bolt) AddCommands(cmd ...Command) { } } -// handler function that parses message data, handles logging the message based on logLevel, and executes -// the payload function in a goroutine -func (b *bolt) messageHandler(s *dg.Session, msg *dg.MessageCreate) { +func (b *bolt) msgEventHandler(s *dg.Session, msg *dg.MessageCreate) { //get server information server, err := s.Guild(msg.GuildID) if err != nil { @@ -124,15 +143,9 @@ func (b *bolt) messageHandler(s *dg.Session, msg *dg.MessageCreate) { return } - //if there is no content it is likely an image, gif, or sticker, updating message content for - //better logging and to avoid confusion - if len(msg.Content) == 0 { - msg.Content = "[Embedded Content]" - } - //the bot will ignore it's own messages to prevent command loops if msg.Author.ID == s.State.User.ID { - if b.logLvl == LogLevelCmd || b.logLvl == LogLevelAll { + if b.logLvl != LogLevelErr && b.logLvl != LogLevelNone { //log command responses log.Printf("< %s | %s | %s > %s\n", server.Name, channel.Name, msg.Author.Username, msg.Content) } @@ -144,86 +157,31 @@ func (b *bolt) messageHandler(s *dg.Session, msg *dg.MessageCreate) { log.Printf("< %s | %s | %s > %s\n", server.Name, channel.Name, msg.Author.Username, msg.Content) } - //does the message have the command indicator - lg := len(b.indicator) - if msg.Content[:lg] == b.indicator { - mCreate := &MessageCreateEvent{ - AuthorUsername: msg.Author.Username, - AuthorID: msg.Author.ID, - AuthorRoles: msg.Member.Roles, - MsgID: msg.ID, - Msg: msg.Content, - MsgChanID: msg.ChannelID, - MsgGuildID: msg.GuildID, - MsgAttachments: msg.Attachments, - } - - if b.logLvl == LogLevelCmd { - //log commands - log.Printf("< %s | %s | %s > %s\n", mCreate.MsgGuildName, mCreate.MsgChanName, mCreate.AuthorUsername, mCreate.Msg) - } - - //handled in its own goroutine to allow for async commands - go func() { - err := b.handleCommand(mCreate, s, lg) - if err != nil { - log.Println(err) - } - }() - } -} - -// parses command from message and handles timeout checks, role checks, and command execution. All command responses are sent back to Discord -func (b *bolt) handleCommand(msgEvent *MessageCreateEvent, s *dg.Session, lg int) error { - words := strings.Split(msgEvent.Msg, " ") - run, ok := b.commands[words[0][lg:]] - if !ok { - return nil //command doesn't exist, maybe log or respond to author + m := Message{ + Author: msg.Author.Username, + authorID: msg.Author.ID, + authorRoles: msg.Member.Roles, + ID: msg.ID, + Content: msg.Content, + Channel: channel.Name, + channelID: channel.ID, + Server: server.Name, + serverID: server.ID, + sesh: b, } - //has command met its timeout requirements - tc, err := b.timeoutCheck(msgEvent.MsgID, msgEvent.MsgChanID, msgEvent.MsgGuildID, s, run) - if err != nil { - return fmt.Errorf("failed to calculate timeout for %s\n%e", run.Trigger, err) - } - if !tc { - return nil + w := strings.Fields(msg.Content) + if len(w) > 0 { + m.Words = w } - //does user have correct permissions - if run.Roles != nil { - check, err := b.roleCheck(msgEvent.MsgGuildID, msgEvent.AuthorRoles, s, run) - if err != nil { - return fmt.Errorf("failed to perform permission checks for %s\n%e", run.Trigger, err) - } - if !check { - reply := b.createReply("you do not have permissions to run that command", msgEvent.MsgID, msgEvent.MsgChanID, msgEvent.MsgGuildID) - _, err := s.ChannelMessageSendComplex(msgEvent.MsgChanID, reply) - if err != nil { - return err - } - return nil - } + if len(msg.Mentions) > 0 { + m.Mentions = msg.Mentions } - //populate message struct exposed to client - plMsg := Message{ - Author: msgEvent.AuthorUsername, - ID: msgEvent.AuthorID, - msgID: msgEvent.MsgID, - Words: words, - Content: msgEvent.Msg, - Channel: msgEvent.MsgChanName, - channelID: msgEvent.MsgChanID, - Server: msgEvent.MsgGuildName, - serverID: msgEvent.MsgGuildID, - sesh: b, - } - - //check for file attachments - if len(msgEvent.MsgAttachments) > 0 { + if len(msg.Attachments) > 0 { var att []MessageAttachment - for _, a := range msgEvent.MsgAttachments { + for _, a := range msg.Attachments { att = append(att, MessageAttachment{ ID: a.ID, URL: a.URL, @@ -237,13 +195,85 @@ func (b *bolt) handleCommand(msgEvent *MessageCreateEvent, s *dg.Session, lg int }) } - plMsg.Attachments = att + m.Attachments = att } - //run command payload - err = run.Payload(plMsg) + //using a patter based on a stackoverflow comment I saw that mentioned the use of a buffered channel as a lock (semaphore) + //to limit the amount of goroutines used at once + + //could be an issue if the bot is used like a long-term calendar, not sure that is my concern we now have a timeout so it will only wait so long + + lg := len(b.indicator) + if msg.Content[:lg] == b.indicator { + if b.logLvl == LogLevelCmd { + //log commands + log.Printf("< %s | %s | %s > %s\n", m.Server, m.Channel, m.authorID, m.Content) + } + + b.pool <- struct{}{} //'aquire' a routine + + //handled in its own goroutine to allow for async commands + b.wg.Go(func() { + err := b.handleCommand(&m, lg) + if err != nil { + log.Println(err) + } + <-b.pool //release routine + }) + } else { + b.pool <- struct{}{} //'aquire' a routine + b.wg.Go(func() { + err := b.handleMessage(&m) + if err != nil { + log.Println(err) + } + <-b.pool //release routine + }) + } +} + +func (b *bolt) handleMessage(event *Message) error { + if b.msgHandlerf != nil { + return b.msgHandlerf(event) + } + + return nil +} + +func (b *bolt) handleCommand(msg *Message, lg int) error { + run, ok := b.commands[msg.Words[0][lg:]] + if !ok { + return nil //command doesn't exist, maybe log or respond to author + } + + //has command met its timeout requirements + tc, err := b.timeoutCheck(msg.ID, msg.channelID, msg.serverID, b.Session, run) if err != nil { - return fmt.Errorf("failed to execute payload function: %e", err) + return fmt.Errorf("failed to calculate timeout for %s\n%e", run.Trigger, err) + } + if !tc { + return nil + } + + //does user have correct permissions + if run.Roles != nil { + check, err := b.roleCheck(msg.serverID, msg.authorRoles, b.Session, run) + if err != nil { + return fmt.Errorf("failed to perform permission checks for %s\n%e", run.Trigger, err) + } + if !check { + reply := b.createReply("you do not have permissions to run that command", msg.ID, msg.channelID, msg.serverID) + _, err := b.Session.ChannelMessageSendComplex(msg.channelID, reply) + if err != nil { + return err + } + return nil + } + } + + err = run.Payload(msg) + if err != nil { + return fmt.Errorf("encountered an error while handling command (%s): %e", msg.Words[0], err) } //update run time @@ -267,7 +297,7 @@ func (b *bolt) createReply(content, message, channel, guild string) *dg.MessageS } // used to calculate the remaining time left in a timeout and returning it in a human-readable format -func (b *bolt) getRemainingTimeout(timeout time.Time) string { +func (b *bolt) remainingTimeout(timeout time.Time) string { r := time.Until(timeout) var ( timeLeft int @@ -314,11 +344,10 @@ func (b *bolt) roleCheck(guild string, roles []string, s *dg.Session, run Comman return true, nil } -// check if the command timeout has been met, responding with remaining time if timeout has not been met yet. func (b *bolt) timeoutCheck(msgID, channelID, guildID string, s *dg.Session, run Command) (bool, error) { wait := run.lastRun.Add(run.Timeout) if !time.Now().After(wait) { - reply := b.createReply(fmt.Sprintf("that command cannot be run for another %s", b.getRemainingTimeout(wait)), msgID, channelID, guildID) + reply := b.createReply(fmt.Sprintf("that command cannot be run for another %s", b.remainingTimeout(wait)), msgID, channelID, guildID) _, err := s.ChannelMessageSendComplex(channelID, reply) if err != nil { return false, fmt.Errorf("failed to send timeout response: %e", err) diff --git a/command.go b/command.go index a14a70d..0dbe9c9 100644 --- a/command.go +++ b/command.go @@ -14,4 +14,4 @@ type Command struct { } // command payload functions, any strings returned are sent as a response to the command -type Payload func(msg Message) error +type Payload func(msg *Message) error diff --git a/go.mod b/go.mod index 1d9a5fa..9f95850 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module code.jakeyoungdev.com/jake/bolt -go 1.24.0 +go 1.25.0 require github.com/bwmarrin/discordgo v0.29.0 diff --git a/message.go b/message.go index 59599a0..3e72902 100644 --- a/message.go +++ b/message.go @@ -2,44 +2,48 @@ package bolt import ( "fmt" + "time" dg "github.com/bwmarrin/discordgo" ) const ( - // the max discord allows for basic messages + // the max length allowed for basic messages MSG_MAX_LENGTH = 2000 ) -// the message struct is passed to the command payload providing basic -// message information and needed methods +// Message contains basic information about the messages received and provides a few methods +// for handling replies, bans, timeouts, reaction, and deletion. All Discord utilities will use +// a timeout to prevent hanging for too long, this timeout can be customized with the WithTimeout +// option. type Message struct { - Author string //username of message author - ID string //discord ID of message author - msgID string //id string of message - Words []string //words from message split on whitespace - Content string //entire message content - Channel string //message channel - channelID string //id of channel message was sent in - Server string //message guild - serverID string //id of guild message was sent in - Attachments []MessageAttachment + Author string //current username of the message author + authorID string //discord ID of message author + authorRoles []string + ID string //message ID + Words []string //message data split on whitespaces + Content string //entire message data string + Channel string //name of channel message was sent in + channelID string //ID of channel message was sent in + Server string //name of guild message was sent in + serverID string //ID of guild message was sent in + Attachments []MessageAttachment //any attachments bound to the message + Mentions []*dg.User sesh *bolt } -// applies reaction to message +// React applies reaction to the message func (m *Message) React(emoji Reaction) error { - return m.sesh.MessageReactionAdd(m.channelID, m.msgID, fmt.Sprint(emoji)) + return m.sesh.MessageReactionAdd(m.channelID, m.ID, fmt.Sprint(emoji)) } -// sends response to message, if the response length is greater than 2000 characters the -// messages are split and sent seperatly +// Respond sends a response to the message, handling chunking if the message exceeds max length func (m *Message) Respond(res string) error { if len(res) > MSG_MAX_LENGTH { for len(res) > 0 { //send full chunk size allowed by discord sc := res[:MSG_MAX_LENGTH] - rep := m.sesh.createReply(sc, m.msgID, m.channelID, m.serverID) + rep := m.sesh.createReply(sc, m.ID, m.channelID, m.serverID) _, err := m.sesh.ChannelMessageSendComplex(m.channelID, rep) if err != nil { return err @@ -48,7 +52,7 @@ func (m *Message) Respond(res string) error { //if we have left than a full chunk send the rest and break the loop if len(res) < MSG_MAX_LENGTH { - final := m.sesh.createReply(res, m.msgID, m.channelID, m.serverID) + final := m.sesh.createReply(res, m.ID, m.channelID, m.serverID) _, err := m.sesh.ChannelMessageSendComplex(m.channelID, final) if err != nil { return err @@ -61,30 +65,44 @@ func (m *Message) Respond(res string) error { return nil } - //short enough message to send in one message - rep := m.sesh.createReply(res, m.msgID, m.channelID, m.serverID) + //short enough message to send in one go + rep := m.sesh.createReply(res, m.ID, m.channelID, m.serverID) _, err := m.sesh.ChannelMessageSendComplex(m.channelID, rep) return err } -// deletes the message from the channel +// Delete removes the message from the current channel func (m *Message) Delete() error { - return m.sesh.ChannelMessageDelete(m.channelID, m.msgID, nil) + return m.sesh.ChannelMessageDelete(m.channelID, m.ID, nil) } -// this struct has all of the needed information from the messageCreate event so that -// commands can be run asynchronously. Passing the messageCreate to payloads can block routines -type MessageCreateEvent struct { - AuthorUsername string - AuthorID string - AuthorRoles []string - MsgID string - Msg string - MsgChanID string - MsgChanName string - MsgGuildID string - MsgGuildName string - MsgAttachments []*dg.MessageAttachment +// Timeout sets a timeout for the message author +func (m *Message) Timeout(duration time.Time) error { + return m.sesh.GuildMemberTimeout(m.serverID, m.authorID, &duration) +} + +// ClearTimeout removes all timeouts for the message author +func (m *Message) ClearTimeout() error { + return m.sesh.GuildMemberTimeout(m.serverID, m.authorID, nil) +} + +// Ban removes a user from the server, banning them and removing all messages within the range of +// the days parameter +func (m *Message) Ban(reason string, days int) error { + return m.sesh.GuildBanCreateWithReason(m.serverID, m.authorID, reason, days) +} + +// unban user + +// ClearBan deletes the ban on message Authors + +// lol this won't work, they're banned, same with all clear* +func (m *Message) ClearBan() error { + return m.sesh.GuildBanDelete(m.serverID, m.authorID) +} + +func (m *Message) Mute(username string) error { + return m.sesh.GuildMemberMute(m.serverID, m.authorID, true) } // message attachment details diff --git a/option.go b/option.go index 9cac768..b629baa 100644 --- a/option.go +++ b/option.go @@ -1,15 +1,52 @@ package bolt +import ( + dg "github.com/bwmarrin/discordgo" +) + type Option func(b *bolt) type LogLevel int +type Permission dg.Intent + +type HandlerLevel int + const ( - LogLevelAll LogLevel = iota //logs all messages, and errors - LogLevelCmd LogLevel = iota //log only commands and responses, and errors - LogLevelErr LogLevel = iota //logs only errors + LogLevelAll LogLevel = iota //log all messages, and errors + LogLevelCmd LogLevel = iota //log only commands and responses, and errors + LogLevelErr LogLevel = iota //log only errors + LogLevelNone LogLevel = iota //log nothing, let the handlers sort it out + + msgPerms dg.Intent = dg.IntentGuilds | + dg.IntentGuildMembers | + dg.IntentGuildPresences | + dg.IntentMessageContent | + dg.IntentsGuildMessages + + MessagePermissions Permission = Permission(msgPerms) + ReactionPermissions Permission = Permission(dg.IntentGuildMessageReactions) + //we also need a ModeratorPermissions for banning, kicking, etc. ) +func WithPermissions(perms ...Permission) Option { + return func(b *bolt) { + var fullPerms dg.Intent + for _, p := range perms { + fullPerms |= dg.Intent(p) + } + + //set intents + b.Identify.Intents = fullPerms + } +} + +func WithMaxGoroutines(max int) Option { + return func(b *bolt) { + b.maxRoutines = max + } +} + // sets the substring that must be present at the beginning of the message to indicate a command func WithIndicator(i string) Option { return func(b *bolt) {