Started the facelift of the repo, adding bans and timeouts, which required some restructuring. Also implemented a goroutine limit for command handlers along with options to configure it, and moved the intents into options to allow stronger restrictions.
360 lines
8.9 KiB
Go
360 lines
8.9 KiB
Go
package bolt
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"os/signal"
|
|
"slices"
|
|
"strings"
|
|
"sync"
|
|
"syscall"
|
|
"time"
|
|
|
|
dg "github.com/bwmarrin/discordgo"
|
|
)
|
|
|
|
// TOKEN_ENV_VAR is the name of the environment variable that New reads the
// Discord bot token from. It is the only setting that must be provided.
const TOKEN_ENV_VAR = "DISCORD_TOKEN"

// Defaults used by New when building a bot; options applied afterwards may
// override them.
const (
	// DEFAULT_INDICATOR is the prefix that marks a message as a command.
	DEFAULT_INDICATOR = "."
	// DEFAULT_MAX_GOROUTINES is the default capacity of the handler pool,
	// i.e. how many message/command handlers may run at the same time.
	DEFAULT_MAX_GOROUTINES = 50
)
|
|
|
|
// basic bot structure containing discordgo connection as well as the command map
|
|
type bolt struct {
|
|
*dg.Session //holds discordgo internals
|
|
commands map[string]Command //maps trigger phrase to command struct for fast lookup
|
|
indicator string //the indicator used to detect whether a message is a command
|
|
logLvl LogLevel //determines how much the bot logs
|
|
wg sync.WaitGroup
|
|
pool chan struct{}
|
|
maxRoutines int
|
|
msgHandlerf Payload
|
|
}
|
|
|
|
type Bolt interface {
|
|
Start() error
|
|
AddCommands(cmd ...Command)
|
|
//filtered methods
|
|
stop() error
|
|
msgEventHandler(s *dg.Session, msg *dg.MessageCreate)
|
|
handleCommand(msgEvent *Message, lg int) error
|
|
handleMessage(event *Message) error
|
|
createReply(content, message, channel, guild string) *dg.MessageSend
|
|
remainingTimeout(timeout time.Time) string
|
|
roleCheck(guild string, roles []string, s *dg.Session, run Command) (bool, error)
|
|
timeoutCheck(msgID, channelID, guildID string, s *dg.Session, run Command) (bool, error)
|
|
}
|
|
|
|
func New(opts ...Option) (Bolt, error) {
|
|
_, check := os.LookupEnv(TOKEN_ENV_VAR)
|
|
if !check {
|
|
return nil, fmt.Errorf("environment variable %s must be set", TOKEN_ENV_VAR)
|
|
}
|
|
|
|
bot, err := dg.New(fmt.Sprintf("Bot %s", os.Getenv(TOKEN_ENV_VAR)))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create Discord session: %e", err)
|
|
}
|
|
|
|
b := &bolt{
|
|
Session: bot,
|
|
commands: make(map[string]Command, 0),
|
|
logLvl: LogLevelAll,
|
|
indicator: DEFAULT_INDICATOR,
|
|
wg: sync.WaitGroup{},
|
|
maxRoutines: DEFAULT_MAX_GOROUTINES,
|
|
}
|
|
|
|
//apply options
|
|
for _, opt := range opts {
|
|
opt(b)
|
|
}
|
|
|
|
//options can change pool size, create post-options
|
|
b.pool = make(chan struct{}, b.maxRoutines)
|
|
|
|
return b, nil
|
|
}
|
|
|
|
func (b *bolt) Start() error {
|
|
b.AddHandler(b.msgEventHandler)
|
|
err := b.Open()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to open websocket connection with Discord: %e", err)
|
|
}
|
|
|
|
log.Println("bot started")
|
|
|
|
sigChannel := make(chan os.Signal, 1)
|
|
signal.Notify(sigChannel, syscall.SIGINT)
|
|
<-sigChannel
|
|
|
|
//move this to an option, maybe?
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
|
defer cancel()
|
|
closeChan := make(chan struct{}, 0)
|
|
go func() {
|
|
b.wg.Wait()
|
|
close(closeChan)
|
|
}()
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
log.Println("shutdown timed out waiting for commands to finish, some may have been incomplete")
|
|
case <-closeChan:
|
|
log.Println("command routines cleaned up, exiting")
|
|
}
|
|
|
|
if err := b.stop(); err != nil {
|
|
return err
|
|
}
|
|
|
|
log.Println("bot stopped")
|
|
|
|
return nil
|
|
}
|
|
|
|
func (b *bolt) stop() error {
|
|
return b.Close()
|
|
}
|
|
|
|
// adds commands to bot command map for use
|
|
func (b *bolt) AddCommands(cmd ...Command) {
|
|
for _, c := range cmd {
|
|
b.commands[c.Trigger] = c
|
|
}
|
|
}
|
|
|
|
func (b *bolt) msgEventHandler(s *dg.Session, msg *dg.MessageCreate) {
|
|
//get server information
|
|
server, err := s.Guild(msg.GuildID)
|
|
if err != nil {
|
|
log.Printf("failed to get guild: %e\n", err)
|
|
return
|
|
}
|
|
channel, err := s.Channel(msg.ChannelID)
|
|
if err != nil {
|
|
log.Printf("failed to get channel from guild: %e\n", err)
|
|
return
|
|
}
|
|
|
|
//the bot will ignore it's own messages to prevent command loops
|
|
if msg.Author.ID == s.State.User.ID {
|
|
if b.logLvl != LogLevelErr && b.logLvl != LogLevelNone {
|
|
//log command responses
|
|
log.Printf("< %s | %s | %s > %s\n", server.Name, channel.Name, msg.Author.Username, msg.Content)
|
|
}
|
|
return
|
|
}
|
|
|
|
if b.logLvl == LogLevelAll {
|
|
//log message
|
|
log.Printf("< %s | %s | %s > %s\n", server.Name, channel.Name, msg.Author.Username, msg.Content)
|
|
}
|
|
|
|
m := Message{
|
|
Author: msg.Author.Username,
|
|
authorID: msg.Author.ID,
|
|
authorRoles: msg.Member.Roles,
|
|
ID: msg.ID,
|
|
Content: msg.Content,
|
|
Channel: channel.Name,
|
|
channelID: channel.ID,
|
|
Server: server.Name,
|
|
serverID: server.ID,
|
|
sesh: b,
|
|
}
|
|
|
|
w := strings.Fields(msg.Content)
|
|
if len(w) > 0 {
|
|
m.Words = w
|
|
}
|
|
|
|
if len(msg.Mentions) > 0 {
|
|
m.Mentions = msg.Mentions
|
|
}
|
|
|
|
if len(msg.Attachments) > 0 {
|
|
var att []MessageAttachment
|
|
for _, a := range msg.Attachments {
|
|
att = append(att, MessageAttachment{
|
|
ID: a.ID,
|
|
URL: a.URL,
|
|
ProxyURL: a.ProxyURL,
|
|
Filename: a.Filename,
|
|
ContentType: a.ContentType,
|
|
Width: a.Width,
|
|
Height: a.Height,
|
|
Size: a.Size,
|
|
DurationSecs: a.DurationSecs,
|
|
})
|
|
}
|
|
|
|
m.Attachments = att
|
|
}
|
|
|
|
//using a patter based on a stackoverflow comment I saw that mentioned the use of a buffered channel as a lock (semaphore)
|
|
//to limit the amount of goroutines used at once
|
|
|
|
//could be an issue if the bot is used like a long-term calendar, not sure that is my concern we now have a timeout so it will only wait so long
|
|
|
|
lg := len(b.indicator)
|
|
if msg.Content[:lg] == b.indicator {
|
|
if b.logLvl == LogLevelCmd {
|
|
//log commands
|
|
log.Printf("< %s | %s | %s > %s\n", m.Server, m.Channel, m.authorID, m.Content)
|
|
}
|
|
|
|
b.pool <- struct{}{} //'aquire' a routine
|
|
|
|
//handled in its own goroutine to allow for async commands
|
|
b.wg.Go(func() {
|
|
err := b.handleCommand(&m, lg)
|
|
if err != nil {
|
|
log.Println(err)
|
|
}
|
|
<-b.pool //release routine
|
|
})
|
|
} else {
|
|
b.pool <- struct{}{} //'aquire' a routine
|
|
b.wg.Go(func() {
|
|
err := b.handleMessage(&m)
|
|
if err != nil {
|
|
log.Println(err)
|
|
}
|
|
<-b.pool //release routine
|
|
})
|
|
}
|
|
}
|
|
|
|
func (b *bolt) handleMessage(event *Message) error {
|
|
if b.msgHandlerf != nil {
|
|
return b.msgHandlerf(event)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (b *bolt) handleCommand(msg *Message, lg int) error {
|
|
run, ok := b.commands[msg.Words[0][lg:]]
|
|
if !ok {
|
|
return nil //command doesn't exist, maybe log or respond to author
|
|
}
|
|
|
|
//has command met its timeout requirements
|
|
tc, err := b.timeoutCheck(msg.ID, msg.channelID, msg.serverID, b.Session, run)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to calculate timeout for %s\n%e", run.Trigger, err)
|
|
}
|
|
if !tc {
|
|
return nil
|
|
}
|
|
|
|
//does user have correct permissions
|
|
if run.Roles != nil {
|
|
check, err := b.roleCheck(msg.serverID, msg.authorRoles, b.Session, run)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to perform permission checks for %s\n%e", run.Trigger, err)
|
|
}
|
|
if !check {
|
|
reply := b.createReply("you do not have permissions to run that command", msg.ID, msg.channelID, msg.serverID)
|
|
_, err := b.Session.ChannelMessageSendComplex(msg.channelID, reply)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
err = run.Payload(msg)
|
|
if err != nil {
|
|
return fmt.Errorf("encountered an error while handling command (%s): %e", msg.Words[0], err)
|
|
}
|
|
|
|
//update run time
|
|
run.lastRun = time.Now()
|
|
b.commands[run.Trigger] = run
|
|
return nil
|
|
}
|
|
|
|
// basic wrapper function to create easy Discord responses
|
|
func (b *bolt) createReply(content, message, channel, guild string) *dg.MessageSend {
|
|
details := &dg.MessageReference{
|
|
MessageID: message,
|
|
ChannelID: channel,
|
|
GuildID: guild,
|
|
}
|
|
|
|
return &dg.MessageSend{
|
|
Content: content,
|
|
Reference: details,
|
|
}
|
|
}
|
|
|
|
// used to calculate the remaining time left in a timeout and returning it in a human-readable format
|
|
func (b *bolt) remainingTimeout(timeout time.Time) string {
|
|
r := time.Until(timeout)
|
|
var (
|
|
timeLeft int
|
|
metric string
|
|
)
|
|
timeLeft = int(r.Hours())
|
|
metric = "h"
|
|
if timeLeft < 1 {
|
|
timeLeft = int(r.Minutes())
|
|
metric = "m"
|
|
if timeLeft < 1 {
|
|
timeLeft = int(r.Seconds())
|
|
metric = "s"
|
|
}
|
|
}
|
|
|
|
return fmt.Sprintf("%d%s", timeLeft, metric)
|
|
}
|
|
|
|
// checks if the author of msg has the correct role to run the requested command
|
|
func (b *bolt) roleCheck(guild string, roles []string, s *dg.Session, run Command) (bool, error) {
|
|
var found bool
|
|
//loop thru author roles, there may be a better way to check for this UNION
|
|
//TODO: improve role search performance to support bigger lists
|
|
for _, r := range roles {
|
|
//get role name from ID
|
|
n, err := s.State.Role(guild, r)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to get role from ID %s\n%e", guild, err)
|
|
}
|
|
//does this role exist in command roles
|
|
check := slices.Contains(run.Roles, n.Name)
|
|
if check {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
|
|
//can't find role, don't run command
|
|
if !found {
|
|
return false, nil
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
func (b *bolt) timeoutCheck(msgID, channelID, guildID string, s *dg.Session, run Command) (bool, error) {
|
|
wait := run.lastRun.Add(run.Timeout)
|
|
if !time.Now().After(wait) {
|
|
reply := b.createReply(fmt.Sprintf("that command cannot be run for another %s", b.remainingTimeout(wait)), msgID, channelID, guildID)
|
|
_, err := s.ChannelMessageSendComplex(channelID, reply)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to send timeout response: %e", err)
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
return true, nil
|
|
}
|