started the facelift of the repo, adding in bans and timeouts. That comes with some restructure.
Also implementing a goroutine limit for command handlers and options around that, as well as moving
the intents to options to allow stronger restrictions
This commit is contained in:
2026-02-24 16:54:25 -05:00
parent c291f68005
commit 6816d7359b
6 changed files with 247 additions and 155 deletions

257
bolt.go
View File

@@ -1,12 +1,14 @@
package bolt
import (
"context"
"fmt"
"log"
"os"
"os/signal"
"slices"
"strings"
"sync"
"syscall"
"time"
@@ -14,14 +16,12 @@ import (
)
const (
TOKEN_ENV_VAR = "DISCORD_TOKEN" //label for token environment variable
//Environment variable name for discord token, this is the only required variable
TOKEN_ENV_VAR = "DISCORD_TOKEN"
BOT_INTENTS = dg.IntentGuilds |
dg.IntentGuildMembers |
dg.IntentGuildPresences |
dg.IntentMessageContent |
dg.IntentsGuildMessages |
dg.IntentGuildMessageReactions
//bot defaults
DEFAULT_INDICATOR = "."
DEFAULT_MAX_GOROUTINES = 50
)
// basic bot structure containing discordgo connection as well as the command map
@@ -30,6 +30,10 @@ type bolt struct {
commands map[string]Command //maps trigger phrase to command struct for fast lookup
indicator string //the indicator used to detect whether a message is a command
logLvl LogLevel //determines how much the bot logs
wg sync.WaitGroup
pool chan struct{}
maxRoutines int
msgHandlerf Payload
}
type Bolt interface {
@@ -37,15 +41,15 @@ type Bolt interface {
AddCommands(cmd ...Command)
//filtered methods
stop() error
messageHandler(s *dg.Session, msg *dg.MessageCreate)
handleCommand(msgEvent *MessageCreateEvent, s *dg.Session, lg int) error
msgEventHandler(s *dg.Session, msg *dg.MessageCreate)
handleCommand(msgEvent *Message, lg int) error
handleMessage(event *Message) error
createReply(content, message, channel, guild string) *dg.MessageSend
getRemainingTimeout(timeout time.Time) string
remainingTimeout(timeout time.Time) string
roleCheck(guild string, roles []string, s *dg.Session, run Command) (bool, error)
timeoutCheck(msgID, channelID, guildID string, s *dg.Session, run Command) (bool, error)
}
// create a new bolt interface
func New(opts ...Option) (Bolt, error) {
_, check := os.LookupEnv(TOKEN_ENV_VAR)
if !check {
@@ -56,48 +60,65 @@ func New(opts ...Option) (Bolt, error) {
if err != nil {
return nil, fmt.Errorf("failed to create Discord session: %e", err)
}
bot.Identify.Intents = BOT_INTENTS
b := &bolt{
Session: bot,
commands: make(map[string]Command, 0),
logLvl: LogLevelAll,
Session: bot,
commands: make(map[string]Command, 0),
logLvl: LogLevelAll,
indicator: DEFAULT_INDICATOR,
wg: sync.WaitGroup{},
maxRoutines: DEFAULT_MAX_GOROUTINES,
}
//set default command indicator
b.indicator = "."
//apply options
for _, opt := range opts {
opt(b)
}
//options can change pool size, create post-options
b.pool = make(chan struct{}, b.maxRoutines)
return b, nil
}
// starts the bot, commands are added and the connection to Discord is opened, this is a BLOCKING call
// that handles safe shutdown of the bot
func (b *bolt) Start() error {
//register commands and open connection
b.AddHandler(b.messageHandler)
b.AddHandler(b.msgEventHandler)
err := b.Open()
if err != nil {
return fmt.Errorf("failed to open websocket connection with Discord: %e", err)
}
//safe shutdown handler
log.Println("bot started")
sigChannel := make(chan os.Signal, 1)
signal.Notify(sigChannel, syscall.SIGINT)
<-sigChannel
//move this to an option, maybe?
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
closeChan := make(chan struct{}, 0)
go func() {
b.wg.Wait()
close(closeChan)
}()
select {
case <-ctx.Done():
log.Println("shutdown timed out waiting for commands to finish, some may have been incomplete")
case <-closeChan:
log.Println("command routines cleaned up, exiting")
}
if err := b.stop(); err != nil {
return err
}
log.Println("bot stopped")
return nil
}
// stops the bot
func (b *bolt) stop() error {
return b.Close()
}
@@ -109,9 +130,7 @@ func (b *bolt) AddCommands(cmd ...Command) {
}
}
// handler function that parses message data, handles logging the message based on logLevel, and executes
// the payload function in a goroutine
func (b *bolt) messageHandler(s *dg.Session, msg *dg.MessageCreate) {
func (b *bolt) msgEventHandler(s *dg.Session, msg *dg.MessageCreate) {
//get server information
server, err := s.Guild(msg.GuildID)
if err != nil {
@@ -124,15 +143,9 @@ func (b *bolt) messageHandler(s *dg.Session, msg *dg.MessageCreate) {
return
}
//if there is no content it is likely an image, gif, or sticker, updating message content for
//better logging and to avoid confusion
if len(msg.Content) == 0 {
msg.Content = "[Embedded Content]"
}
//the bot will ignore it's own messages to prevent command loops
if msg.Author.ID == s.State.User.ID {
if b.logLvl == LogLevelCmd || b.logLvl == LogLevelAll {
if b.logLvl != LogLevelErr && b.logLvl != LogLevelNone {
//log command responses
log.Printf("< %s | %s | %s > %s\n", server.Name, channel.Name, msg.Author.Username, msg.Content)
}
@@ -144,86 +157,31 @@ func (b *bolt) messageHandler(s *dg.Session, msg *dg.MessageCreate) {
log.Printf("< %s | %s | %s > %s\n", server.Name, channel.Name, msg.Author.Username, msg.Content)
}
//does the message have the command indicator
lg := len(b.indicator)
if msg.Content[:lg] == b.indicator {
mCreate := &MessageCreateEvent{
AuthorUsername: msg.Author.Username,
AuthorID: msg.Author.ID,
AuthorRoles: msg.Member.Roles,
MsgID: msg.ID,
Msg: msg.Content,
MsgChanID: msg.ChannelID,
MsgGuildID: msg.GuildID,
MsgAttachments: msg.Attachments,
}
if b.logLvl == LogLevelCmd {
//log commands
log.Printf("< %s | %s | %s > %s\n", mCreate.MsgGuildName, mCreate.MsgChanName, mCreate.AuthorUsername, mCreate.Msg)
}
//handled in its own goroutine to allow for async commands
go func() {
err := b.handleCommand(mCreate, s, lg)
if err != nil {
log.Println(err)
}
}()
}
}
// parses command from message and handles timeout checks, role checks, and command execution. All command responses are sent back to Discord
func (b *bolt) handleCommand(msgEvent *MessageCreateEvent, s *dg.Session, lg int) error {
words := strings.Split(msgEvent.Msg, " ")
run, ok := b.commands[words[0][lg:]]
if !ok {
return nil //command doesn't exist, maybe log or respond to author
m := Message{
Author: msg.Author.Username,
authorID: msg.Author.ID,
authorRoles: msg.Member.Roles,
ID: msg.ID,
Content: msg.Content,
Channel: channel.Name,
channelID: channel.ID,
Server: server.Name,
serverID: server.ID,
sesh: b,
}
//has command met its timeout requirements
tc, err := b.timeoutCheck(msgEvent.MsgID, msgEvent.MsgChanID, msgEvent.MsgGuildID, s, run)
if err != nil {
return fmt.Errorf("failed to calculate timeout for %s\n%e", run.Trigger, err)
}
if !tc {
return nil
w := strings.Fields(msg.Content)
if len(w) > 0 {
m.Words = w
}
//does user have correct permissions
if run.Roles != nil {
check, err := b.roleCheck(msgEvent.MsgGuildID, msgEvent.AuthorRoles, s, run)
if err != nil {
return fmt.Errorf("failed to perform permission checks for %s\n%e", run.Trigger, err)
}
if !check {
reply := b.createReply("you do not have permissions to run that command", msgEvent.MsgID, msgEvent.MsgChanID, msgEvent.MsgGuildID)
_, err := s.ChannelMessageSendComplex(msgEvent.MsgChanID, reply)
if err != nil {
return err
}
return nil
}
if len(msg.Mentions) > 0 {
m.Mentions = msg.Mentions
}
//populate message struct exposed to client
plMsg := Message{
Author: msgEvent.AuthorUsername,
ID: msgEvent.AuthorID,
msgID: msgEvent.MsgID,
Words: words,
Content: msgEvent.Msg,
Channel: msgEvent.MsgChanName,
channelID: msgEvent.MsgChanID,
Server: msgEvent.MsgGuildName,
serverID: msgEvent.MsgGuildID,
sesh: b,
}
//check for file attachments
if len(msgEvent.MsgAttachments) > 0 {
if len(msg.Attachments) > 0 {
var att []MessageAttachment
for _, a := range msgEvent.MsgAttachments {
for _, a := range msg.Attachments {
att = append(att, MessageAttachment{
ID: a.ID,
URL: a.URL,
@@ -237,13 +195,85 @@ func (b *bolt) handleCommand(msgEvent *MessageCreateEvent, s *dg.Session, lg int
})
}
plMsg.Attachments = att
m.Attachments = att
}
//run command payload
err = run.Payload(plMsg)
//using a patter based on a stackoverflow comment I saw that mentioned the use of a buffered channel as a lock (semaphore)
//to limit the amount of goroutines used at once
//could be an issue if the bot is used like a long-term calendar, not sure that is my concern we now have a timeout so it will only wait so long
lg := len(b.indicator)
if msg.Content[:lg] == b.indicator {
if b.logLvl == LogLevelCmd {
//log commands
log.Printf("< %s | %s | %s > %s\n", m.Server, m.Channel, m.authorID, m.Content)
}
b.pool <- struct{}{} //'aquire' a routine
//handled in its own goroutine to allow for async commands
b.wg.Go(func() {
err := b.handleCommand(&m, lg)
if err != nil {
log.Println(err)
}
<-b.pool //release routine
})
} else {
b.pool <- struct{}{} //'aquire' a routine
b.wg.Go(func() {
err := b.handleMessage(&m)
if err != nil {
log.Println(err)
}
<-b.pool //release routine
})
}
}
func (b *bolt) handleMessage(event *Message) error {
if b.msgHandlerf != nil {
return b.msgHandlerf(event)
}
return nil
}
func (b *bolt) handleCommand(msg *Message, lg int) error {
run, ok := b.commands[msg.Words[0][lg:]]
if !ok {
return nil //command doesn't exist, maybe log or respond to author
}
//has command met its timeout requirements
tc, err := b.timeoutCheck(msg.ID, msg.channelID, msg.serverID, b.Session, run)
if err != nil {
return fmt.Errorf("failed to execute payload function: %e", err)
return fmt.Errorf("failed to calculate timeout for %s\n%e", run.Trigger, err)
}
if !tc {
return nil
}
//does user have correct permissions
if run.Roles != nil {
check, err := b.roleCheck(msg.serverID, msg.authorRoles, b.Session, run)
if err != nil {
return fmt.Errorf("failed to perform permission checks for %s\n%e", run.Trigger, err)
}
if !check {
reply := b.createReply("you do not have permissions to run that command", msg.ID, msg.channelID, msg.serverID)
_, err := b.Session.ChannelMessageSendComplex(msg.channelID, reply)
if err != nil {
return err
}
return nil
}
}
err = run.Payload(msg)
if err != nil {
return fmt.Errorf("encountered an error while handling command (%s): %e", msg.Words[0], err)
}
//update run time
@@ -267,7 +297,7 @@ func (b *bolt) createReply(content, message, channel, guild string) *dg.MessageS
}
// used to calculate the remaining time left in a timeout and returning it in a human-readable format
func (b *bolt) getRemainingTimeout(timeout time.Time) string {
func (b *bolt) remainingTimeout(timeout time.Time) string {
r := time.Until(timeout)
var (
timeLeft int
@@ -314,11 +344,10 @@ func (b *bolt) roleCheck(guild string, roles []string, s *dg.Session, run Comman
return true, nil
}
// check if the command timeout has been met, responding with remaining time if timeout has not been met yet.
func (b *bolt) timeoutCheck(msgID, channelID, guildID string, s *dg.Session, run Command) (bool, error) {
wait := run.lastRun.Add(run.Timeout)
if !time.Now().After(wait) {
reply := b.createReply(fmt.Sprintf("that command cannot be run for another %s", b.getRemainingTimeout(wait)), msgID, channelID, guildID)
reply := b.createReply(fmt.Sprintf("that command cannot be run for another %s", b.remainingTimeout(wait)), msgID, channelID, guildID)
_, err := s.ChannelMessageSendComplex(channelID, reply)
if err != nil {
return false, fmt.Errorf("failed to send timeout response: %e", err)