diff --git a/.gitignore b/.gitignore index 5470f7b..89debf0 100644 --- a/.gitignore +++ b/.gitignore @@ -32,3 +32,4 @@ go.work.sum # ignore the log directory configs/development/logs/ +.worktrees/ diff --git a/README.md b/README.md index 45c881f..43f637e 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,10 @@ Once you have built the project, you can run the `vault-audit-filter` executable audit_address: "127.0.0.1:1269" audit_description: "Vault Audit Filter Device" + async: + queue_size: 20 + timeout: 5s + rule_groups: - name: "normal_operations" rules: @@ -129,6 +133,10 @@ Once you have built the project, you can run the `vault-audit-filter` executable - `messaging.token`: The bot token for Slack (when using "slack" type). - `messaging.channel`: The channel ID for Slack messages (when using "slack" type). + - **Async Settings**: + - `async.queue_size`: Bounded queue length for async side effects (drop on full). + - `async.timeout`: Timeout for Slack API/webhook and forwarding operations. + ### Rule Syntax Rules are written using the `expr` language, a simple and safe expression language for Go. Rules can be based on the following properties of audit logs: diff --git a/pkg/auditserver/async.go b/pkg/auditserver/async.go new file mode 100644 index 0000000..4ad9592 --- /dev/null +++ b/pkg/auditserver/async.go @@ -0,0 +1,47 @@ +package auditserver + +import ( + "github.com/ncode/vault-audit-filter/pkg/forwarder" + "github.com/ncode/vault-audit-filter/pkg/messaging" +) + +var defaultSideWorkers = 2 + +type sideTask struct { + payload []byte + payloadStr string + messenger messaging.Messenger + forwarder forwarder.Forwarder +} + +func (as *AuditServer) enqueueSide(task sideTask) bool { + select { + case as.sideQueue <- task: + return true + default: + as.sideDrops.Add(1) + return false + } +} + +func (as *AuditServer) startSideWorkers(n int) { + if n <= 0 { + return + } + for i := 0; i < n; i++ { + go func() { + for task := range as.sideQueue { + if task.messenger != nil { + if err := task.messenger.Send(task.payloadStr); err != nil { + as.logger.Error("Failed to send notification", "error", err) + } + } + if task.forwarder != nil { + if err := task.forwarder.Forward(task.payload); err != nil { + as.logger.Error("Failed to forward message", "error", err) + } + } + } + }() + } +} diff --git a/pkg/auditserver/bench_test.go b/pkg/auditserver/bench_test.go new file mode 100644 index 0000000..39ea6f1 --- /dev/null +++ b/pkg/auditserver/bench_test.go @@ -0,0 +1,43 @@ +package auditserver + +import ( + "io" + "log/slog" + "testing" + + "github.com/expr-lang/expr" + "github.com/spf13/viper" +) + +func BenchmarkReact(b *testing.B) { + logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelInfo})) + viper.Reset() + viper.Set("rule_groups", []map[string]interface{}{ + { + "name": "rg", + "rules": []string{"true"}, + "log_file": map[string]interface{}{ + "file_path": "/tmp/test.log", + "max_size": 1, + }, + }, + }) + server, _ := New(logger) + frame := []byte(`{"type":"request","time":"2000-01-01T00:00:00Z","auth":{},"request":{},"response":{}}`) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + server.React(frame, nil) + } +} + +func BenchmarkShouldLog(b *testing.B) { + p, _ := expr.Compile("true", expr.Env(&AuditLog{})) + rg := &RuleGroup{CompiledRules: []CompiledRule{{Program: p}}} + al := &AuditLog{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = rg.shouldLog(al) + } +} diff --git a/pkg/auditserver/server.go b/pkg/auditserver/server.go index 4573572..1bc5401 100644 --- a/pkg/auditserver/server.go +++ b/pkg/auditserver/server.go @@ -8,6 +8,7 @@ import ( "log/slog" "os" "sync" + "sync/atomic" "time" "github.com/expr-lang/expr" @@ -129,8 +130,12 @@ type LogFileConfig struct { type AuditServer struct { *gnet.EventServer - logger *slog.Logger - ruleGroups []RuleGroup + logger *slog.Logger + ruleGroups []RuleGroup + sideQueue chan sideTask + sideDrops atomic.Uint64 + asyncQueueSize int + asyncTimeout time.Duration } func (as *AuditServer) React(frame []byte, c gnet.Conn) (out []byte, action gnet.Action) { @@ -145,9 +150,11 @@ func (as *AuditServer) React(frame []byte, c gnet.Conn) (out []byte, action gnet return nil, gnet.Close } - shouldClose := false matched := false - forwarded := false + var payload []byte + var payloadStr string + payloadReady := false + payloadStrReady := false // Check each rule group for _, rg := range as.ruleGroups { @@ -155,20 +162,21 @@ func (as *AuditServer) React(frame []byte, c gnet.Conn) (out []byte, action gnet matched = true as.logger.Debug("Matched rule group", "group", rg.Name) - // Send notification if messenger is configured - if rg.Messenger != nil { - if err := rg.Messenger.Send(string(frame)); err != nil { - as.logger.Error("Failed to send notification", "error", err) - shouldClose = true + if rg.Messenger != nil || rg.Forwarder != nil { + if !payloadReady { + payload = append([]byte(nil), frame...) + payloadReady = true } - } - - if rg.Forwarder != nil { - if err := rg.Forwarder.Forward(frame); err != nil { - as.logger.Error("Failed to forward message", "error", err) - shouldClose = true + if rg.Messenger != nil && !payloadStrReady { + payloadStr = string(payload) + payloadStrReady = true } - forwarded = true + _ = as.enqueueSide(sideTask{ + payload: payload, + payloadStr: payloadStr, + messenger: rg.Messenger, + forwarder: rg.Forwarder, + }) } // zero‑copy write to log when possible @@ -177,7 +185,11 @@ func (as *AuditServer) React(frame []byte, c gnet.Conn) (out []byte, action gnet as.logger.Error("Failed to write audit log", "group", rg.Name, "error", err) } } else { - rg.Logger.Print(string(frame)) + if payloadStrReady { + rg.Logger.Print(payloadStr) + } else { + rg.Logger.Print(string(frame)) + } } // TODO(JM):Add a flag to prevent logging to multiple groups // break @@ -186,11 +198,7 @@ func (as *AuditServer) React(frame []byte, c gnet.Conn) (out []byte, action gnet auditLogPool.Put(auditLog) - // Preserve test expectations: - // - Close on any messenger/forwarder error - // - Close when no rule matched - // - Close when a message was forwarded (original behaviour) - if shouldClose || !matched || forwarded { + if !matched { return nil, gnet.Close } return nil, gnet.None @@ -217,6 +225,20 @@ func New(logger *slog.Logger) (*AuditServer, error) { logger = slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo})) } + viper.SetDefault("async.queue_size", 20) + viper.SetDefault("async.timeout", "5s") + + queueSize := viper.GetInt("async.queue_size") + if queueSize <= 0 { + queueSize = 20 + } + rawTimeout := viper.GetString("async.timeout") + asyncTimeout, err := time.ParseDuration(rawTimeout) + if err != nil { + asyncTimeout = 5 * time.Second + logger.Warn("Invalid async.timeout; using default", "value", rawTimeout) + } + // Load rule groups from configuration var ruleGroupConfigs []RuleGroupConfig if err := viper.UnmarshalKey("rule_groups", &ruleGroupConfigs); err != nil { @@ -225,7 +247,15 @@ func New(logger *slog.Logger) (*AuditServer, error) { } var ruleGroups []RuleGroup - for _, rgConfig := range ruleGroupConfigs { + if len(ruleGroupConfigs) == 0 { + defaultLogger := log.New(os.Stdout, "", 0) + ruleGroups = append(ruleGroups, RuleGroup{ + Name: "default", + CompiledRules: nil, + Logger: defaultLogger, + }) + } else { + for _, rgConfig := range ruleGroupConfigs { // Compile rules var compiledRules []CompiledRule for _, ruleStr := range rgConfig.Rules { @@ -252,9 +282,9 @@ func New(logger *slog.Logger) (*AuditServer, error) { var messenger messaging.Messenger switch rgConfig.Messaging.Type { case "slack": - messenger = messaging.NewSlackMessenger(rgConfig.Messaging.URL, rgConfig.Messaging.Token, rgConfig.Messaging.Channel) + messenger = messaging.NewSlackMessenger(rgConfig.Messaging.URL, rgConfig.Messaging.Token, rgConfig.Messaging.Channel, asyncTimeout) case "slack_webhook": - messenger = messaging.NewSlackWebhookMessenger(rgConfig.Messaging.WebhookURL) + messenger = messaging.NewSlackWebhookMessenger(rgConfig.Messaging.WebhookURL, asyncTimeout) default: if rgConfig.Messaging.Type != "" { logger.Error("Invalid messenger type", "type", rgConfig.Messaging.Type) @@ -270,6 +300,9 @@ func New(logger *slog.Logger) (*AuditServer, error) { logger.Error("Failed to create UDP forwarder", "error", err) return nil, fmt.Errorf("failed to create UDP forwarder: %w", err) } + if udpFwd, ok := fwd.(*forwarder.UDPForwarder); ok { + udpFwd.SetTimeout(asyncTimeout) + } } ruleGroups = append(ruleGroups, RuleGroup{ @@ -281,9 +314,15 @@ func New(logger *slog.Logger) (*AuditServer, error) { Forwarder: fwd, }) } + } - return &AuditServer{ - logger: logger, - ruleGroups: ruleGroups, - }, nil + server := &AuditServer{ + logger: logger, + ruleGroups: ruleGroups, + sideQueue: make(chan sideTask, queueSize), + asyncQueueSize: queueSize, + asyncTimeout: asyncTimeout, + } + server.startSideWorkers(defaultSideWorkers) + return server, nil } diff --git a/pkg/auditserver/server_test.go b/pkg/auditserver/server_test.go index 2b8ed04..3b13e10 100644 --- a/pkg/auditserver/server_test.go +++ b/pkg/auditserver/server_test.go @@ -60,6 +60,23 @@ func (m *MockMessenger) Send(message string) error { return nil } +type lockedBuffer struct { + mu sync.Mutex + buf bytes.Buffer +} + +func (b *lockedBuffer) Write(p []byte) (int, error) { + b.mu.Lock() + defer b.mu.Unlock() + return b.buf.Write(p) +} + +func (b *lockedBuffer) String() string { + b.mu.Lock() + defer b.mu.Unlock() + return b.buf.String() +} + // mockConn is a mock implementation of gnet.Conn type mockConn struct{} @@ -160,7 +177,7 @@ func TestAuditServer_React(t *testing.T) { expectedLogs: map[string]bool{ tempDir + "/normal_operations.log": true, }, - expectAction: gnet.Close, + expectAction: gnet.None, messengerError: fmt.Errorf("failed to send message"), expectedLogMessages: []string{ "Failed to send notification", @@ -210,8 +227,8 @@ func TestAuditServer_React(t *testing.T) { } // Capture logs - var logBuffer bytes.Buffer - logger := slog.New(slog.NewJSONHandler(&logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug})) + logBuffer := &lockedBuffer{} + logger := slog.New(slog.NewJSONHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug})) // Create the AuditServer as, _ := New(logger) @@ -303,8 +320,8 @@ func TestNew(t *testing.T) { viper.Set("rule_groups", ruleGroupConfigs) // Capture logs - var logBuffer bytes.Buffer - logger := slog.New(slog.NewJSONHandler(&logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug})) + logBuffer := &lockedBuffer{} + logger := slog.New(slog.NewJSONHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug})) server, _ := New(logger) if len(server.ruleGroups) != len(ruleGroupConfigs) { @@ -334,6 +351,100 @@ func TestNew(t *testing.T) { } } +func TestNew_DefaultRuleGroup_WhenMissingOrEmpty(t *testing.T) { + viper.Reset() + + // Missing rule_groups + server, err := New(nil) + require.NoError(t, err) + require.NotNil(t, server) + require.Len(t, server.ruleGroups, 1) + assert.Len(t, server.ruleGroups[0].CompiledRules, 0) + assert.NotNil(t, server.ruleGroups[0].Logger) + + // Empty rule_groups + viper.Reset() + viper.Set("rule_groups", []map[string]interface{}{}) + server, err = New(nil) + require.NoError(t, err) + require.Len(t, server.ruleGroups, 1) +} + +func TestNew_AsyncDefaults(t *testing.T) { + viper.Reset() + server, err := New(nil) + require.NoError(t, err) + require.NotNil(t, server) + assert.Equal(t, 20, server.asyncQueueSize) + assert.Equal(t, 5*time.Second, server.asyncTimeout) +} + +func TestSideQueue_DropsWhenFull(t *testing.T) { + viper.Reset() + viper.Set("async.queue_size", 1) + oldWorkers := defaultSideWorkers + defaultSideWorkers = 0 + defer func() { defaultSideWorkers = oldWorkers }() + + viper.Set("rule_groups", []map[string]interface{}{ + { + "name": "rg", + "rules": []string{"true"}, + "log_file": map[string]interface{}{"file_path": "/tmp/test.log", "max_size": 1}, + "messaging": map[string]interface{}{ + "type": "slack_webhook", + "webhook_url": "http://example.com", + }, + }, + }) + + srv, err := New(nil) + require.NoError(t, err) + + frame := []byte(`{"type":"request","time":"2000-01-01T00:00:00Z","auth":{},"request":{},"response":{}}`) + _, _ = srv.React(frame, nil) + _, _ = srv.React(frame, nil) + + assert.Equal(t, uint64(1), srv.sideDrops.Load()) +} + +func TestReact_AsyncMessengerCalled(t *testing.T) { + viper.Reset() + viper.Set("async.queue_size", 10) + viper.Set("rule_groups", []map[string]interface{}{ + { + "name": "rg", + "rules": []string{"true"}, + "log_file": map[string]interface{}{"file_path": "/tmp/test.log", "max_size": 1}, + "messaging": map[string]interface{}{ + "type": "slack_webhook", + "webhook_url": "http://example.com", + }, + }, + }) + + srv, err := New(nil) + require.NoError(t, err) + + called := make(chan struct{}, 1) + for i := range srv.ruleGroups { + srv.ruleGroups[i].Messenger = &MockMessenger{SendFunc: func(string) error { + called <- struct{}{} + return nil + }} + } + + frame := []byte(`{"type":"request","time":"2000-01-01T00:00:00Z","auth":{},"request":{},"response":{}}`) + _, action := srv.React(frame, nil) + assert.Equal(t, gnet.None, action) + + select { + case <-called: + case <-time.After(500 * time.Millisecond): + t.Fatalf("messenger not called") + } +} + func TestNewWithoutLogger(t *testing.T) { // Redirect stdout to capture log output oldStdout := os.Stdout @@ -695,7 +806,7 @@ func TestAuditServer_React_WithForwarding(t *testing.T) { tempDir + "/normal_operations.log": true, tempDir + "/critical_events.log": false, }, - expectAction: gnet.Close, + expectAction: gnet.None, expectedForwarded: true, expectedForwarder: 0, }, @@ -724,7 +835,7 @@ func TestAuditServer_React_WithForwarding(t *testing.T) { tempDir + "/normal_operations.log": false, tempDir + "/critical_events.log": true, }, - expectAction: gnet.Close, + expectAction: gnet.None, expectedForwarded: true, expectedForwarder: 1, }, @@ -806,24 +917,42 @@ func TestAuditServer_React_WithForwarding(t *testing.T) { type dummyMessenger struct { sendErr error + mu sync.Mutex calls int } func (d *dummyMessenger) Send(_ string) error { + d.mu.Lock() d.calls++ + d.mu.Unlock() return d.sendErr } +func (d *dummyMessenger) Calls() int { + d.mu.Lock() + defer d.mu.Unlock() + return d.calls +} + type dummyForwarder struct { forwardErr error + mu sync.Mutex calls int } func (d *dummyForwarder) Forward(_ []byte) error { + d.mu.Lock() d.calls++ + d.mu.Unlock() return d.forwardErr } +func (d *dummyForwarder) Calls() int { + d.mu.Lock() + defer d.mu.Unlock() + return d.calls +} + // minimal JSON frame that parses into an AuditLog func auditFrame() []byte { return []byte(`{"type":"request","time":"2000-01-01T00:00:00Z","auth":{},"request":{},"response":{}}`) @@ -886,7 +1015,7 @@ func TestReact_Branches(t *testing.T) { Writer: new(bytes.Buffer), Forwarder: &dummyForwarder{}, }, - wantAction: gnet.Close, + wantAction: gnet.None, wantFwdCalls: 1, }, { @@ -897,7 +1026,7 @@ func TestReact_Branches(t *testing.T) { Writer: new(bytes.Buffer), Forwarder: &dummyForwarder{forwardErr: errors.New("boom")}, }, - wantAction: gnet.Close, + wantAction: gnet.None, wantFwdCalls: 1, }, { @@ -908,7 +1037,7 @@ func TestReact_Branches(t *testing.T) { Writer: new(bytes.Buffer), Messenger: &dummyMessenger{sendErr: errors.New("boom")}, }, - wantAction: gnet.Close, + wantAction: gnet.None, wantMsgCalls: 1, }, { @@ -928,18 +1057,28 @@ func TestReact_Branches(t *testing.T) { tc := tc // capture range variable t.Run(tc.name, func(t *testing.T) { srv := &AuditServer{ - logger: logger, + logger: logger, ruleGroups: []RuleGroup{tc.group}, + sideQueue: make(chan sideTask, 2), } + srv.startSideWorkers(1) _, act := srv.React(frame, nil) require.Equal(t, tc.wantAction, act) if dm, ok := tc.group.Messenger.(*dummyMessenger); ok { - require.Equal(t, tc.wantMsgCalls, dm.calls) + if tc.wantMsgCalls > 0 { + require.Eventually(t, func() bool { return dm.Calls() == tc.wantMsgCalls }, time.Second, 10*time.Millisecond) + } else { + require.Equal(t, tc.wantMsgCalls, dm.Calls()) + } } if df, ok := tc.group.Forwarder.(*dummyForwarder); ok { - require.Equal(t, tc.wantFwdCalls, df.calls) + if tc.wantFwdCalls > 0 { + require.Eventually(t, func() bool { return df.Calls() == tc.wantFwdCalls }, time.Second, 10*time.Millisecond) + } else { + require.Equal(t, tc.wantFwdCalls, df.Calls()) + } } }) } diff --git a/pkg/forwarder/forwarder.go b/pkg/forwarder/forwarder.go index 2f31e41..d09c27b 100644 --- a/pkg/forwarder/forwarder.go +++ b/pkg/forwarder/forwarder.go @@ -1,7 +1,10 @@ package forwarder import ( + "errors" "net" + "sync" + "time" ) // Forwarder is an interface for forwarding messages @@ -11,7 +14,9 @@ type Forwarder interface { // UDPForwarder implements the Forwarder interface for UDP type UDPForwarder struct { - conn *net.UDPConn + conn *net.UDPConn + timeout time.Duration + mu sync.Mutex } // NewUDPForwarder creates a new UDPForwarder @@ -27,8 +32,21 @@ func NewUDPForwarder(address string) (*UDPForwarder, error) { return &UDPForwarder{conn: conn}, nil } +// SetTimeout configures a per-write deadline for UDP forwarding. +func (f *UDPForwarder) SetTimeout(timeout time.Duration) { + f.timeout = timeout +} + // Forward sends the data to the UDP address func (f *UDPForwarder) Forward(data []byte) error { + if f.conn == nil { + return errors.New("udp connection is nil") + } + if f.timeout > 0 { + f.mu.Lock() + defer f.mu.Unlock() + _ = f.conn.SetWriteDeadline(time.Now().Add(f.timeout)) + } _, err := f.conn.Write(data) return err } diff --git a/pkg/forwarder/forwarder_test.go b/pkg/forwarder/forwarder_test.go index c6bbd6b..6f97925 100644 --- a/pkg/forwarder/forwarder_test.go +++ b/pkg/forwarder/forwarder_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestUDPForwarder(t *testing.T) { @@ -124,6 +125,21 @@ func TestUDPForwarder_ForwardToUnreachableAddress(t *testing.T) { assert.NoError(t, err) } +func TestUDPForwarder_SetTimeout(t *testing.T) { + addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") + require.NoError(t, err) + + conn, err := net.ListenUDP("udp", addr) + require.NoError(t, err) + defer conn.Close() + + forwarder, err := NewUDPForwarder(conn.LocalAddr().String()) + require.NoError(t, err) + + forwarder.SetTimeout(10 * time.Millisecond) + assert.NoError(t, forwarder.Forward([]byte("msg"))) +} + func TestUDPForwarder_ConcurrentForwarding(t *testing.T) { // Start a mock UDP server addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") diff --git a/pkg/messaging/messaging.go b/pkg/messaging/messaging.go index 8b55df0..747a5c1 100644 --- a/pkg/messaging/messaging.go +++ b/pkg/messaging/messaging.go @@ -2,6 +2,9 @@ package messaging import ( "fmt" + "net/http" + "time" + "github.com/slack-go/slack" ) @@ -22,11 +25,12 @@ type SlackMessenger struct { } // NewSlackMessenger creates a new SlackMessenger -func NewSlackMessenger(serverURL, token, channel string) *SlackMessenger { +func NewSlackMessenger(serverURL, token, channel string, timeout time.Duration) *SlackMessenger { opts := []slack.Option{} if serverURL != "" { opts = append(opts, slack.OptionAPIURL(serverURL)) } + opts = append(opts, slack.OptionHTTPClient(&http.Client{Timeout: timeout})) client := slack.New(token, opts...) return &SlackMessenger{client: client, channel: channel} } @@ -43,16 +47,20 @@ func (m *SlackMessenger) Send(message string) error { // SlackWebhookMessenger implements the Messenger interface for Slack webhooks type SlackWebhookMessenger struct { webhookURL string + httpClient *http.Client } // NewSlackWebhookMessenger creates a new SlackWebhookMessenger -func NewSlackWebhookMessenger(webhookURL string) *SlackWebhookMessenger { - return &SlackWebhookMessenger{webhookURL: webhookURL} +func NewSlackWebhookMessenger(webhookURL string, timeout time.Duration) *SlackWebhookMessenger { + return &SlackWebhookMessenger{ + webhookURL: webhookURL, + httpClient: &http.Client{Timeout: timeout}, + } } // Send sends a message to Slack using a webhook func (m *SlackWebhookMessenger) Send(message string) error { - err := slack.PostWebhook(m.webhookURL, &slack.WebhookMessage{Text: message}) + err := slack.PostWebhookCustomHTTP(m.webhookURL, m.httpClient, &slack.WebhookMessage{Text: message}) if err != nil { return fmt.Errorf("failed to send message: %w", err) } diff --git a/pkg/messaging/messaging_test.go b/pkg/messaging/messaging_test.go index 1a6fe3e..6d81ff6 100644 --- a/pkg/messaging/messaging_test.go +++ b/pkg/messaging/messaging_test.go @@ -5,6 +5,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/slack-go/slack" "github.com/stretchr/testify/assert" @@ -26,7 +27,7 @@ func TestNewSlackMessenger(t *testing.T) { token := "test-token" channel := "test-channel" - messenger := NewSlackMessenger(serverURL, token, channel) + messenger := NewSlackMessenger(serverURL, token, channel, 0) assert.NotNil(t, messenger, "NewSlackMessenger should return a non-nil messenger") assert.Equal(t, channel, messenger.channel, "Channel should be set correctly") @@ -34,7 +35,7 @@ func TestNewSlackMessenger(t *testing.T) { _, ok := messenger.client.(*slack.Client) assert.True(t, ok, "Client should be of type *slack.Client") - emptyMessenger := NewSlackMessenger("", "", "") + emptyMessenger := NewSlackMessenger("", "", "", 0) assert.NotNil(t, emptyMessenger, "NewSlackMessenger should return a non-nil messenger even with empty inputs") assert.Empty(t, emptyMessenger.channel, "Channel should be empty") } @@ -70,7 +71,7 @@ func TestSlackWebhookMessenger_Send(t *testing.T) { })) defer server.Close() - messenger := NewSlackWebhookMessenger(server.URL) + messenger := NewSlackWebhookMessenger(server.URL, 0) err := messenger.Send("Test message") assert.NoError(t, err) } @@ -81,31 +82,56 @@ func TestSlackWebhookMessenger_SendError(t *testing.T) { })) defer server.Close() - messenger := NewSlackWebhookMessenger(server.URL) + messenger := NewSlackWebhookMessenger(server.URL, 0) err := messenger.Send("Test message") assert.Error(t, err) } func TestNewSlackWebhookMessenger(t *testing.T) { webhookURL := "https://slack.example.com/hooks/abc123" - messenger := NewSlackWebhookMessenger(webhookURL) + messenger := NewSlackWebhookMessenger(webhookURL, 0) assert.NotNil(t, messenger, "NewSlackWebhookMessenger should return a non-nil messenger") assert.Equal(t, webhookURL, messenger.webhookURL, "Webhook URL should be set correctly") - emptyMessenger := NewSlackWebhookMessenger("") + emptyMessenger := NewSlackWebhookMessenger("", 0) assert.NotNil(t, emptyMessenger, "NewSlackWebhookMessenger should return a non-nil messenger even with empty webhook URL") assert.Empty(t, emptyMessenger.webhookURL, "Webhook URL should be empty") } func TestSlackWebhookMessenger_SendInvalidURL(t *testing.T) { - messenger := NewSlackWebhookMessenger("http://[::1]:NamedPort") + messenger := NewSlackWebhookMessenger("http://[::1]:NamedPort", 0) err := messenger.Send("Test message") assert.Error(t, err) } func TestSlackWebhookMessenger_SendEmptyURL(t *testing.T) { - messenger := NewSlackWebhookMessenger("") + messenger := NewSlackWebhookMessenger("", 0) err := messenger.Send("Test message") assert.Error(t, err) } + +func TestSlackWebhookMessenger_Timeout(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(50 * time.Millisecond) + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + m := NewSlackWebhookMessenger(srv.URL, 10*time.Millisecond) + err := m.Send("msg") + assert.Error(t, err) +} + +func TestSlackMessenger_Timeout(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(50 * time.Millisecond) + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"ok":true}`)) + })) + defer srv.Close() + + m := NewSlackMessenger(srv.URL, "token", "chan", 10*time.Millisecond) + err := m.Send("msg") + assert.Error(t, err) +}