Skip to content
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,4 @@ go.work.sum

# ignore the log directory
configs/development/logs/
.worktrees/
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ Once you have built the project, you can run the `vault-audit-filter` executable
audit_address: "127.0.0.1:1269"
audit_description: "Vault Audit Filter Device"

async:
queue_size: 20
timeout: 5s

rule_groups:
- name: "normal_operations"
rules:
Expand Down Expand Up @@ -129,6 +133,10 @@ Once you have built the project, you can run the `vault-audit-filter` executable
- `messaging.token`: The bot token for Slack (when using "slack" type).
- `messaging.channel`: The channel ID for Slack messages (when using "slack" type).

- **Async Settings**:
- `async.queue_size`: Bounded queue length for async side effects (drop on full).
- `async.timeout`: Timeout for Slack API/webhook and forwarding operations.

### Rule Syntax

Rules are written using the `expr` language, a simple and safe expression language for Go. Rules can be based on the following properties of audit logs:
Expand Down
47 changes: 47 additions & 0 deletions pkg/auditserver/async.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package auditserver

import (
"github.com/ncode/vault-audit-filter/pkg/forwarder"
"github.com/ncode/vault-audit-filter/pkg/messaging"
)

var defaultSideWorkers = 2

type sideTask struct {
payload []byte
payloadStr string
messenger messaging.Messenger
forwarder forwarder.Forwarder
}

func (as *AuditServer) enqueueSide(task sideTask) bool {
select {
case as.sideQueue <- task:
return true
default:
as.sideDrops.Add(1)
return false
}
}

func (as *AuditServer) startSideWorkers(n int) {
if n <= 0 {
return
}
for i := 0; i < n; i++ {
go func() {
for task := range as.sideQueue {
if task.messenger != nil {
if err := task.messenger.Send(task.payloadStr); err != nil {
as.logger.Error("Failed to send notification", "error", err)
}
}
if task.forwarder != nil {
if err := task.forwarder.Forward(task.payload); err != nil {
as.logger.Error("Failed to forward message", "error", err)
}
}
}
}()
}
}
43 changes: 43 additions & 0 deletions pkg/auditserver/bench_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package auditserver

import (
"io"
"log/slog"
"testing"

"github.com/expr-lang/expr"
"github.com/spf13/viper"
)

func BenchmarkReact(b *testing.B) {
logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelInfo}))
viper.Reset()
viper.Set("rule_groups", []map[string]interface{}{
{
"name": "rg",
"rules": []string{"true"},
"log_file": map[string]interface{}{
"file_path": "/tmp/test.log",
"max_size": 1,
},
},
})
server, _ := New(logger)
frame := []byte(`{"type":"request","time":"2000-01-01T00:00:00Z","auth":{},"request":{},"response":{}}`)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
server.React(frame, nil)
}
}

func BenchmarkShouldLog(b *testing.B) {
p, _ := expr.Compile("true", expr.Env(&AuditLog{}))
rg := &RuleGroup{CompiledRules: []CompiledRule{{Program: p}}}
al := &AuditLog{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = rg.shouldLog(al)
}
}
97 changes: 68 additions & 29 deletions pkg/auditserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"log/slog"
"os"
"sync"
"sync/atomic"
"time"

"github.com/expr-lang/expr"
Expand Down Expand Up @@ -129,8 +130,12 @@ type LogFileConfig struct {

type AuditServer struct {
*gnet.EventServer
logger *slog.Logger
ruleGroups []RuleGroup
logger *slog.Logger
ruleGroups []RuleGroup
sideQueue chan sideTask
sideDrops atomic.Uint64
asyncQueueSize int
asyncTimeout time.Duration
}

func (as *AuditServer) React(frame []byte, c gnet.Conn) (out []byte, action gnet.Action) {
Expand All @@ -145,30 +150,33 @@ func (as *AuditServer) React(frame []byte, c gnet.Conn) (out []byte, action gnet
return nil, gnet.Close
}

shouldClose := false
matched := false
forwarded := false
var payload []byte
var payloadStr string
payloadReady := false
payloadStrReady := false

// Check each rule group
for _, rg := range as.ruleGroups {
if rg.shouldLog(auditLog) {
matched = true
as.logger.Debug("Matched rule group", "group", rg.Name)

// Send notification if messenger is configured
if rg.Messenger != nil {
if err := rg.Messenger.Send(string(frame)); err != nil {
as.logger.Error("Failed to send notification", "error", err)
shouldClose = true
if rg.Messenger != nil || rg.Forwarder != nil {
if !payloadReady {
payload = append([]byte(nil), frame...)
payloadReady = true
}
}

if rg.Forwarder != nil {
if err := rg.Forwarder.Forward(frame); err != nil {
as.logger.Error("Failed to forward message", "error", err)
shouldClose = true
if rg.Messenger != nil && !payloadStrReady {
payloadStr = string(payload)
payloadStrReady = true
}
forwarded = true
_ = as.enqueueSide(sideTask{
payload: payload,
payloadStr: payloadStr,
messenger: rg.Messenger,
forwarder: rg.Forwarder,
})
}

// zero‑copy write to log when possible
Expand All @@ -177,7 +185,11 @@ func (as *AuditServer) React(frame []byte, c gnet.Conn) (out []byte, action gnet
as.logger.Error("Failed to write audit log", "group", rg.Name, "error", err)
}
} else {
rg.Logger.Print(string(frame))
if payloadStrReady {
rg.Logger.Print(payloadStr)
} else {
rg.Logger.Print(string(frame))
}
}
// TODO(JM):Add a flag to prevent logging to multiple groups
// break
Expand All @@ -186,11 +198,7 @@ func (as *AuditServer) React(frame []byte, c gnet.Conn) (out []byte, action gnet

auditLogPool.Put(auditLog)

// Preserve test expectations:
// - Close on any messenger/forwarder error
// - Close when no rule matched
// - Close when a message was forwarded (original behaviour)
if shouldClose || !matched || forwarded {
if !matched {
return nil, gnet.Close
}
return nil, gnet.None
Expand All @@ -217,6 +225,20 @@ func New(logger *slog.Logger) (*AuditServer, error) {
logger = slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo}))
}

viper.SetDefault("async.queue_size", 20)
viper.SetDefault("async.timeout", "5s")

queueSize := viper.GetInt("async.queue_size")
if queueSize <= 0 {
queueSize = 20
}
rawTimeout := viper.GetString("async.timeout")
asyncTimeout, err := time.ParseDuration(rawTimeout)
if err != nil {
asyncTimeout = 5 * time.Second
logger.Warn("Invalid async.timeout; using default", "value", rawTimeout)
}

// Load rule groups from configuration
var ruleGroupConfigs []RuleGroupConfig
if err := viper.UnmarshalKey("rule_groups", &ruleGroupConfigs); err != nil {
Expand All @@ -225,7 +247,15 @@ func New(logger *slog.Logger) (*AuditServer, error) {
}

var ruleGroups []RuleGroup
for _, rgConfig := range ruleGroupConfigs {
if len(ruleGroupConfigs) == 0 {
defaultLogger := log.New(os.Stdout, "", 0)
ruleGroups = append(ruleGroups, RuleGroup{
Name: "default",
CompiledRules: nil,
Logger: defaultLogger,
})
} else {
for _, rgConfig := range ruleGroupConfigs {
// Compile rules
var compiledRules []CompiledRule
for _, ruleStr := range rgConfig.Rules {
Expand All @@ -252,9 +282,9 @@ func New(logger *slog.Logger) (*AuditServer, error) {
var messenger messaging.Messenger
switch rgConfig.Messaging.Type {
case "slack":
messenger = messaging.NewSlackMessenger(rgConfig.Messaging.URL, rgConfig.Messaging.Token, rgConfig.Messaging.Channel)
messenger = messaging.NewSlackMessenger(rgConfig.Messaging.URL, rgConfig.Messaging.Token, rgConfig.Messaging.Channel, asyncTimeout)
case "slack_webhook":
messenger = messaging.NewSlackWebhookMessenger(rgConfig.Messaging.WebhookURL)
messenger = messaging.NewSlackWebhookMessenger(rgConfig.Messaging.WebhookURL, asyncTimeout)
default:
if rgConfig.Messaging.Type != "" {
logger.Error("Invalid messenger type", "type", rgConfig.Messaging.Type)
Expand All @@ -270,6 +300,9 @@ func New(logger *slog.Logger) (*AuditServer, error) {
logger.Error("Failed to create UDP forwarder", "error", err)
return nil, fmt.Errorf("failed to create UDP forwarder: %w", err)
}
if udpFwd, ok := fwd.(*forwarder.UDPForwarder); ok {
udpFwd.SetTimeout(asyncTimeout)
}
}

ruleGroups = append(ruleGroups, RuleGroup{
Expand All @@ -281,9 +314,15 @@ func New(logger *slog.Logger) (*AuditServer, error) {
Forwarder: fwd,
})
}
}

return &AuditServer{
logger: logger,
ruleGroups: ruleGroups,
}, nil
server := &AuditServer{
logger: logger,
ruleGroups: ruleGroups,
sideQueue: make(chan sideTask, queueSize),
asyncQueueSize: queueSize,
asyncTimeout: asyncTimeout,
}
server.startSideWorkers(defaultSideWorkers)
return server, nil
}
Loading