From 38444d78d40bce9c0dbb0f3bcd77b5d71996c05f Mon Sep 17 00:00:00 2001 From: rabbitstack Date: Tue, 12 May 2026 18:33:10 +0200 Subject: [PATCH] perf(filter): Introduce valuer cache We're doing redundant field extraction. If 20 rules match CreateProcess event type, we extract the same fields 20 times per event. The valuer cache introduces a per-event extraction cache that is populated once and reused across all rules, significantly improving performance and reducing allocation rate. --- internal/etw/consumer.go | 4 +- pkg/cap/reader_windows.go | 2 +- pkg/filament/filament_test.go | 11 +- pkg/filter/fields/id.go | 75 +++++++++++ pkg/filter/filter.go | 93 +++++++------ pkg/filter/filter_test.go | 40 +++--- pkg/filter/valuer.go | 73 ++++++++++ pkg/filter/valuer_test.go | 241 ++++++++++++++++++++++++++++++++++ pkg/rules/engine.go | 19 ++- pkg/rules/sequence.go | 13 +- pkg/rules/sequence_test.go | 100 ++++++++------ 11 files changed, 549 insertions(+), 122 deletions(-) create mode 100644 pkg/filter/fields/id.go create mode 100644 pkg/filter/valuer.go create mode 100644 pkg/filter/valuer_test.go diff --git a/internal/etw/consumer.go b/internal/etw/consumer.go index 03787376a..a54c5269d 100644 --- a/internal/etw/consumer.go +++ b/internal/etw/consumer.go @@ -131,9 +131,11 @@ func (c *Consumer) ProcessEvent(ev *etw.EventRecord) error { eventsExcluded.Add(1) return nil } - if c.filter != nil && !evt.IsStackWalk() && !c.filter.Run(evt) { + + if c.filter != nil && !evt.IsStackWalk() && !c.filter.Eval(evt) { return nil } + // Increment sequence if !evt.IsState() { c.sequencer.Increment() diff --git a/pkg/cap/reader_windows.go b/pkg/cap/reader_windows.go index d7982e10f..8d992bd0d 100644 --- a/pkg/cap/reader_windows.go +++ b/pkg/cap/reader_windows.go @@ -169,7 +169,7 @@ func (r *reader) read(evt *event.Event, eventsc chan *event.Event) { if evt.Type.OnlyState() { return } - if r.filter != nil && !r.filter.Run(evt) { + if r.filter != nil && !r.filter.Eval(evt) {
capDroppedByFilter.Add(1) return } diff --git a/pkg/filament/filament_test.go b/pkg/filament/filament_test.go index 0cb394fc8..0e457e361 100644 --- a/pkg/filament/filament_test.go +++ b/pkg/filament/filament_test.go @@ -24,15 +24,16 @@ package filament import ( "bufio" "bytes" + "net" + "strings" + "testing" + "time" + "github.com/rabbitstack/fibratus/pkg/config" "github.com/rabbitstack/fibratus/pkg/event" "github.com/rabbitstack/fibratus/pkg/event/params" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "net" - "strings" - "testing" - "time" ) func init() { @@ -117,5 +118,5 @@ func TestFilamentFilter(t *testing.T) { Name: "CreateProcess", } - require.True(t, filament.Filter().Run(evt)) + require.True(t, filament.Filter().Eval(evt)) } diff --git a/pkg/filter/fields/id.go b/pkg/filter/fields/id.go new file mode 100644 index 000000000..68e086dd8 --- /dev/null +++ b/pkg/filter/fields/id.go @@ -0,0 +1,75 @@ +/* + * Copyright 2016-present by Nedim Sabic Sabic + * https://www.fibratus.io + * All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fields + +const MaxFieldID int16 = 23 + +// ID returns the filter index used by the valuer cache. +// Note that not all fields need to provide the identifier. 
+func (f Field) ID() int16 { + switch f { + case EvtName: + return 0 + case EvtPID: + return 1 + case RegistryPath: + return 2 + case FilePath: + return 3 + case FileExtension: + return 4 + case PsExe: + return 5 + case PsName: + return 6 + case PsPid: + return 7 + case FileOperation: + return 8 + case FileStatus: + return 9 + case RegistryStatus: + return 10 + case ModuleName: + return 11 + case ModulePath: + return 12 + case ModuleSignatureExists, DllSignatureExists: + return 13 + case ModuleSignatureTrusted, DllSignatureTrusted: + return 14 + case ThreadCallstackSummary: + return 15 + case ThreadCallstackModules: + return 16 + case ThreadCallstackSymbols: + return 17 + case PsParentName: + return 18 + case PsSID: + return 19 + case PsCmdline: + return 20 + case PsTokenIntegrityLevel: + return 21 + case PsAccessMaskNames: + return MaxFieldID - 1 + } + return -1 +} diff --git a/pkg/filter/filter.go b/pkg/filter/filter.go index eb6eb2e5e..3a05dcec5 100644 --- a/pkg/filter/filter.go +++ b/pkg/filter/filter.go @@ -26,7 +26,6 @@ import ( "regexp" "strconv" "strings" - "sync" errs "github.com/rabbitstack/fibratus/pkg/errors" "github.com/rabbitstack/fibratus/pkg/event" @@ -47,13 +46,21 @@ var ( type Filter interface { // Compile compiles the filter by parsing the sequence/expression. Compile() error - // Run runs a filter with a single expression. The return value decides - // if the incoming event has successfully matched the filter expression. - Run(evt *event.Event) bool - // RunSequence runs a filter with sequence expressions. Sequence rules depend - // on the state machine transitions and partial matches to decide whether the - // rule is fired. - RunSequence(evt *event.Event, seqID int, partials map[int][]*event.Event, rawMatch bool) bool + // Eval evaluates the event against the filter expression. Returns true if the filter + // has matched against the event, or false otherwise. Creates the valuer cache within + // the lifetime of the method call.
+ Eval(evt *event.Event) bool + // EvalWithValuer evaluates the event against the filter. Returns true if the filter + // has matched against the event, or false otherwise. + // The valuer cache is acquired before the evaluation stage and provides fast + // access to extracted field values. + EvalWithValuer(evt *event.Event, valuer *ValuerCache) bool + // EvalSequence evaluates the event against sequence expressions. Sequence rules + // depend on the state machine transitions and partial matches to decide whether + // the rule is fired. + // The valuer cache is acquired before the evaluation stage and provides fast + // access to extracted field values. + EvalSequence(evt *event.Event, valuer *ValuerCache, seqID int, partials map[int][]*event.Event, rawMatch bool) bool // GetStringFields returns field names mapped to their string values. GetStringFields() map[fields.Field][]string // GetFields returns all fields used in the filter expression. @@ -243,11 +250,17 @@ func (f *filter) Compile() error { return f.checkBoundRefs() } -func (f *filter) Run(e *event.Event) bool { +func (f *filter) Eval(e *event.Event) bool { + valuer := AcquireValuerCache() + defer valuer.Release() + return f.EvalWithValuer(e, valuer) +} + +func (f *filter) EvalWithValuer(e *event.Event, cache *ValuerCache) bool { if f.expr == nil { return false } - return ql.Eval(f.expr, f.mapValuer(e), f.hasFunctions) + return ql.Eval(f.expr, f.mapValuer(e, cache), f.hasFunctions) } // evalBoundSequence evaluates the sequence with bound fields @@ -380,7 +393,7 @@ func (f *filter) evalSequence( return match } -func (f *filter) RunSequence(e *event.Event, seqID int, partials map[int][]*event.Event, rawMatch bool) bool { +func (f *filter) EvalSequence(e *event.Event, valuerCache *ValuerCache, seqID int, partials map[int][]*event.Event, rawMatch bool) bool { if f.seq == nil { return false } @@ -388,8 +401,7 @@ func (f *filter) RunSequence(e *event.Event, seqID int, partials map[int][]*even if seqID > nseqs-1 {
return false } - valuer := f.mapValuer(e) - defer valuerPool.Put(valuer) + valuer := f.mapValuer(e, valuerCache) expr := f.seq.Expressions[seqID] if rawMatch { @@ -482,39 +494,38 @@ func InterpolateFields(s string, evts []*event.Event) string { return r } -var valuerPool = sync.Pool{ - New: func() any { - return make(map[string]any) - }, -} - // mapValuer for each field present in the AST, we run the -// accessors and extract the field values that are -// supplied to the valuer. The valuer feeds the -// expression with correct values. -func (f *filter) mapValuer(evt *event.Event) map[string]any { - valuer := valuerPool.Get().(map[string]any) +// accessors and extract the field values that are supplied +// to the valuer. The valuer feeds the expression with correct +// values. If the field value is present in the valuer cache then +// we directly populate the valuer for the field. +func (f *filter) mapValuer(evt *event.Event, valuerCache *ValuerCache) map[string]any { for _, field := range f.fields { - for _, accessor := range f.accessors { - if !accessor.IsFieldAccessible(evt) { - continue - } - v, err := accessor.Get(field, evt) - if v == nil || err != nil { - if v == nil { - valuer[field.String()] = defaultAccessorValue(field) - } - if err != nil && !errs.IsParamNotFound(err) { - valuer[field.String()] = defaultAccessorValue(field) - accessorErrors.Add(err.Error(), 1) - } - continue + valuerCache.populateValuer(field, func() any { + return f.extractField(field, evt) + }) + } + return valuerCache.valuer +} + +// extractField extracts the field value from the accessor. 
+func (f *filter) extractField(field Field, evt *event.Event) any { + for _, accessor := range f.accessors { + if !accessor.IsFieldAccessible(evt) { + continue + } + v, err := accessor.Get(field, evt) + if err != nil { + if !errs.IsParamNotFound(err) { + accessorErrors.Add(err.Error(), 1) } - valuer[field.String()] = v - break + return defaultAccessorValue(field) + } + if v != nil { + return v } } - return valuer + return defaultAccessorValue(field) } // addField appends a new field to the filter fields list. diff --git a/pkg/filter/filter_test.go b/pkg/filter/filter_test.go index 7c7fdaace..8b107cda0 100644 --- a/pkg/filter/filter_test.go +++ b/pkg/filter/filter_test.go @@ -360,7 +360,7 @@ func TestProcFilter(t *testing.T) { if err != nil { t.Fatal(err) } - matches := f.Run(evt) + matches := f.Eval(evt) if matches != tt.matches { t.Errorf("%d. %q ps filter mismatch: exp=%t got=%t", i, tt.filter, tt.matches, matches) } @@ -381,7 +381,7 @@ func TestProcFilter(t *testing.T) { if err != nil { t.Fatal(err) } - matches := f.Run(evt1) + matches := f.Eval(evt1) if matches != tt.matches { t.Errorf("%d. %q ps filter mismatch: exp=%t got=%t", i, tt.filter, tt.matches, matches) } @@ -401,7 +401,7 @@ func TestProcFilter(t *testing.T) { if err != nil { t.Fatal(err) } - matches := f.Run(evt2) + matches := f.Eval(evt2) if matches != tt.matches { t.Errorf("%d. %q ps filter mismatch: exp=%t got=%t", i, tt.filter, tt.matches, matches) } @@ -533,7 +533,7 @@ func TestThreadFilter(t *testing.T) { if err != nil { t.Fatal(err) } - matches := f.Run(evt) + matches := f.Eval(evt) if matches != tt.matches { t.Errorf("%d. %q thread filter mismatch: exp=%t got=%t", i, tt.filter, tt.matches, matches) } @@ -605,7 +605,7 @@ func TestThreadFilter(t *testing.T) { if err != nil { t.Fatal(err) } - matches := f.Run(evt) + matches := f.Eval(evt) if matches != tt.matches { t.Errorf("%d. 
%q thread filter mismatch: exp=%t got=%t", i, tt.filter, tt.matches, matches) } @@ -686,7 +686,7 @@ func TestFileFilter(t *testing.T) { if err != nil { t.Fatal(err) } - matches := f.Run(evt) + matches := f.Eval(evt) if matches != tt.matches { t.Errorf("%d. %q file filter mismatch: exp=%t got=%t", i, tt.filter, tt.matches, matches) } @@ -772,7 +772,7 @@ func TestFileInfoFilter(t *testing.T) { if err != nil { t.Fatal(err) } - assert.Equal(t, tt.matches, f.Run(tt.e)) + assert.Equal(t, tt.matches, f.Eval(tt.e)) }) } } @@ -848,7 +848,7 @@ func TestEventFilter(t *testing.T) { if err != nil { t.Fatal(err) } - matches := f.Run(evt) + matches := f.Eval(evt) if matches != tt.matches { t.Errorf("%d. %q evt filter mismatch: exp=%t got=%t", i, tt.filter, tt.matches, matches) } @@ -901,7 +901,7 @@ func TestNetFilter(t *testing.T) { if err != nil { t.Fatal(err) } - matches := f.Run(evt) + matches := f.Eval(evt) if matches != tt.matches { t.Errorf("%d. %q net filter mismatch: exp=%t got=%t", i, tt.filter, tt.matches, matches) } @@ -938,7 +938,7 @@ func TestNetFilter(t *testing.T) { if err != nil { t.Fatal(err) } - matches := f.Run(evt1) + matches := f.Eval(evt1) if matches != tt.matches { t.Errorf("%d. %q net filter mismatch: exp=%t got=%t", i, tt.filter, tt.matches, matches) } @@ -980,7 +980,7 @@ func TestRegistryFilter(t *testing.T) { if err != nil { t.Fatal(err) } - matches := f.Run(evt) + matches := f.Eval(evt) if matches != tt.matches { t.Errorf("%d. %q registry filter mismatch: exp=%t got=%t", i, tt.filter, tt.matches, matches) } @@ -1041,7 +1041,7 @@ func TestModuleFilter(t *testing.T) { if err != nil { t.Fatal(err) } - matches := f.Run(e1) + matches := f.Eval(e1) if matches != tt.matches { t.Errorf("%d. %q module filter mismatch: exp=%t got=%t", i, tt.filter, tt.matches, matches) } @@ -1101,7 +1101,7 @@ func TestModuleFilter(t *testing.T) { if err != nil { t.Fatal(err) } - matches := f.Run(e2) + matches := f.Eval(e2) if matches != tt.matches { t.Errorf("%d. 
%q filter mismatch: exp=%t got=%t", i, tt.filter, tt.matches, matches) } @@ -1142,7 +1142,7 @@ func TestModuleFilter(t *testing.T) { if err != nil { t.Fatal(err) } - matches := f.Run(e3) + matches := f.Eval(e3) if matches != tt.matches { t.Errorf("%d. %q module filter mismatch: exp=%t got=%t", i, tt.filter, tt.matches, matches) } @@ -1194,7 +1194,7 @@ func TestPEFilter(t *testing.T) { if err != nil { t.Fatal(err) } - matches := f.Run(evt) + matches := f.Eval(evt) if matches != tt.matches { t.Errorf("%d. %q pe filter mismatch: exp=%t got=%t", i, tt.filter, tt.matches, matches) } @@ -1242,7 +1242,7 @@ func TestLazyPEFilter(t *testing.T) { t.Fatal(err) } require.Nil(t, evt.PS.PE) - matches := f.Run(evt) + matches := f.Eval(evt) if matches != tt.matches { t.Errorf("%d. %q pe lazy filter mismatch: exp=%t got=%t", i, tt.filter, tt.matches, matches) } @@ -1293,7 +1293,7 @@ func TestMemFilter(t *testing.T) { if err != nil { t.Fatal(err) } - matches := f.Run(evt) + matches := f.Eval(evt) if matches != tt.matches { t.Errorf("%d. %q mem filter mismatch: exp=%t got=%t", i, tt.filter, tt.matches, matches) } @@ -1336,7 +1336,7 @@ func TestDNSFilter(t *testing.T) { if err != nil { t.Fatal(err) } - matches := f.Run(evt) + matches := f.Eval(evt) if matches != tt.matches { t.Errorf("%d. %q dns filter mismatch: exp=%t got=%t", i, tt.filter, tt.matches, matches) } @@ -1401,7 +1401,7 @@ func TestThreadpoolFilter(t *testing.T) { if err != nil { t.Fatal(err) } - matches := f.Run(e) + matches := f.Eval(e) if matches != tt.matches { t.Errorf("%d. 
%q threadpool filter mismatch: exp=%t got=%t", i, tt.filter, tt.matches, matches) } @@ -1562,7 +1562,7 @@ func BenchmarkFilterRun(b *testing.B) { } for i := 0; i < b.N; i++ { - f.Run(evt) + f.Eval(evt) } } diff --git a/pkg/filter/valuer.go b/pkg/filter/valuer.go new file mode 100644 index 000000000..2f5d57f8f --- /dev/null +++ b/pkg/filter/valuer.go @@ -0,0 +1,73 @@ +/* + * Copyright 2021-present by Nedim Sabic Sabic + * https://www.fibratus.io + * All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package filter + +import ( + "sync" + + "github.com/rabbitstack/fibratus/pkg/filter/fields" + "github.com/rabbitstack/fibratus/pkg/filter/ql" +) + +// ValuerCache caches extracted field values for a single event's lifetime. 
+type ValuerCache struct { + slots [fields.MaxFieldID]any + valuer ql.MapValuer +} + +var valuerCachePool = sync.Pool{ + New: func() any { + return &ValuerCache{ + valuer: make(ql.MapValuer, 8), // pre-allocate buckets + } + }, +} + +func AcquireValuerCache() *ValuerCache { + return valuerCachePool.Get().(*ValuerCache) +} + +func (c *ValuerCache) Release() { + c.slots = [fields.MaxFieldID]any{} + clear(c.valuer) + valuerCachePool.Put(c) +} + +func (c *ValuerCache) populateValuer(f Field, extract func() any) { + id := f.Name.ID() + if id == -1 { + // if the field doesn't allow fast id lookup + // extract the value and cache inside valuer + n := f.String() + if _, ok := c.valuer[n]; !ok { + c.valuer[n] = extract() + } + return + } + + // field value is already cached, skip + v := c.slots[id] + if v != nil { + return + } + + // extract the value and cache + v = extract() + c.slots[id], c.valuer[f.String()] = v, v +} diff --git a/pkg/filter/valuer_test.go b/pkg/filter/valuer_test.go new file mode 100644 index 000000000..c857eaf38 --- /dev/null +++ b/pkg/filter/valuer_test.go @@ -0,0 +1,241 @@ +/* + * Copyright 2021-present by Nedim Sabic Sabic + * https://www.fibratus.io + * All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package filter + +import ( + "testing" + + "github.com/rabbitstack/fibratus/pkg/filter/fields" + "github.com/stretchr/testify/assert" +) + +var dontCallValuerFunc = func() any { return "should-not-be-called" } +var extractValueFunc = func() any { return "explorer.exe" } + +func TestValuerCacheHit(t *testing.T) { + c := AcquireValuerCache() + defer c.Release() + + calls := 0 + extract := func() any { + calls++ + return "explorer.exe" + } + + f := Field{Name: fields.PsName, Value: fields.PsName.String()} + + c.populateValuer(f, extract) + c.populateValuer(f, extract) + + assert.Equal(t, "explorer.exe", c.valuer[f.String()]) + assert.Equal(t, 1, calls, "extract must be called exactly once on repeated access") +} + +func TestValuerCacheMiss(t *testing.T) { + c := AcquireValuerCache() + defer c.Release() + + f := Field{Name: fields.PsName, Value: fields.PsName.String()} + c.populateValuer(f, func() any { return "svchost.exe" }) + + assert.Equal(t, "svchost.exe", c.valuer[f.String()]) +} + +func TestValuerCacheDistinctFields(t *testing.T) { + c := AcquireValuerCache() + defer c.Release() + + f := Field{Name: fields.PsName, Value: fields.PsName.String()} + f1 := Field{Name: fields.FilePath, Value: fields.FilePath.String()} + + c.populateValuer(f, extractValueFunc) + c.populateValuer(f1, func() any { return `C:\Windows\System32\cmd.exe` }) + + // populate again to verify no overwrite + c.populateValuer(f, dontCallValuerFunc) + c.populateValuer(f1, dontCallValuerFunc) + + assert.Equal(t, "explorer.exe", c.valuer[f.String()]) + assert.Equal(t, `C:\Windows\System32\cmd.exe`, c.valuer[f1.String()]) +} + +func TestValuerCacheNilValue(t *testing.T) { + c := AcquireValuerCache() + defer c.Release() + + calls := 0 + c.populateValuer(Field{Name: fields.PsName, Value: fields.PsName.String()}, func() any { + calls++ + return nil + }) + + // nil is not cached in slots (v != nil check), so extract will be called again + c.populateValuer(Field{Name: fields.PsName, Value: 
fields.PsName.String()}, func() any { + calls++ + return nil + }) + + assert.Equal(t, 2, calls, "nil values are not cached, extract is called on every invocation") +} + +func TestValuerCacheReset(t *testing.T) { + c := AcquireValuerCache() + + c.populateValuer(Field{Name: fields.PsName, Value: fields.PsName.String()}, extractValueFunc) + c.Release() + + // simulate pool returning the same instance + c2 := AcquireValuerCache() + defer c2.Release() + + calls := 0 + c2.populateValuer(Field{Name: fields.PsName, Value: fields.PsName.String()}, func() any { + calls++ + return "notepad.exe" + }) + + assert.Equal(t, "notepad.exe", c2.valuer[fields.PsName.String()]) + assert.Equal(t, 1, calls, "slot must be cleared after Release") +} + +func TestValuerCacheExtractCalledOnceAcrossRules(t *testing.T) { + c := AcquireValuerCache() + defer c.Release() + + calls := 0 + extract := func() any { + calls++ + return uint32(1234) + } + + // simulate 20 rules all requesting the same field + for range 20 { + c.populateValuer(Field{Name: fields.PsPid, Value: fields.PsPid.String()}, extract) + } + + assert.Equal(t, uint32(1234), c.valuer[fields.PsPid.String()]) + assert.Equal(t, 1, calls, "extract must be called once regardless of rule count") +} + +func TestValuerCacheAllSlotsIndependent(t *testing.T) { + c := AcquireValuerCache() + defer c.Release() + + want := map[Field]any{ + {Name: fields.PsName, Value: fields.PsName.String()}: "explorer.exe", + {Name: fields.PsPid, Value: fields.PsPid.String()}: uint32(4), + {Name: fields.FilePath, Value: fields.FilePath.String()}: `C:\Windows\System32\cmd.exe`, + } + + for f, v := range want { + val := v + c.populateValuer(f, func() any { return val }) + } + + for f, expected := range want { + assert.Equal(t, expected, c.valuer[f.String()], "field %v", f) + } +} + +func TestValuerCachePoolReuse(t *testing.T) { + for i := range 10 { + c := AcquireValuerCache() + + calls := 0 + c.populateValuer(Field{Name: fields.PsName, Value: 
fields.PsName.String()}, func() any { + calls++ + return i + }) + + assert.Equal(t, i, c.valuer[fields.PsName.String()]) + assert.Equal(t, 1, calls, "event %d: stale slot from previous cycle", i) + + c.Release() + } +} + +func TestValuerCacheFieldWithoutID(t *testing.T) { + c := AcquireValuerCache() + defer c.Release() + + // fields with id == -1 skip the slot array but are still cached by name in the valuer map + calls := 0 + extract := func() any { + calls++ + return "value" + } + + c.populateValuer(Field{Name: fields.HandleID, Value: fields.HandleID.String()}, extract) + c.populateValuer(Field{Name: fields.HandleName, Value: fields.HandleName.String()}, extract) + + assert.Equal(t, 2, calls, "unknown fields (id == -1) must not be cached") + + c.populateValuer(Field{Name: fields.HandleID, Value: fields.HandleID.String()}, dontCallValuerFunc) + c.populateValuer(Field{Name: fields.HandleName, Value: fields.HandleName.String()}, dontCallValuerFunc) + + // repeated lookups must be served from the valuer map without calling extract + assert.Equal(t, "value", c.valuer[fields.HandleID.String()]) + assert.Equal(t, "value", c.valuer[fields.HandleName.String()]) +} + +func BenchmarkValuerCacheHit(b *testing.B) { + c := AcquireValuerCache() + defer c.Release() + + c.populateValuer(Field{Name: fields.PsName, Value: fields.PsName.String()}, extractValueFunc) + + b.ResetTimer() + b.ReportAllocs() + + for range b.N { + c.populateValuer(Field{Name: fields.PsName, Value: fields.PsName.String()}, dontCallValuerFunc) + } +} + +func BenchmarkValuerCacheMiss(b *testing.B) { + b.ReportAllocs() + + for range b.N { + c := AcquireValuerCache() + c.populateValuer(Field{Name: fields.PsName, Value: fields.PsName.String()}, extractValueFunc) + c.Release() + } +} + +func BenchmarkValuerCacheFullEvent(b *testing.B) { + // simulates 20 rules each requesting 3 fields on every event + fieldsUnderTest := []Field{ + {Name: fields.PsName, Value: fields.PsName.String()}, + {Name: fields.PsPid, Value: fields.PsPid.String()}, + {Name: fields.FilePath, Value:
fields.FilePath.String()}, + } + + b.ReportAllocs() + + for range b.N { + c := AcquireValuerCache() + for range 20 { + for _, f := range fieldsUnderTest { + field := f + c.populateValuer(field, func() any { return "value" }) + } + } + c.Release() + } +} diff --git a/pkg/rules/engine.go b/pkg/rules/engine.go index a3803d301..cd0f4974a 100644 --- a/pkg/rules/engine.go +++ b/pkg/rules/engine.go @@ -21,6 +21,9 @@ package rules import ( "expvar" "fmt" + "sync" + "time" + "github.com/rabbitstack/fibratus/pkg/config" "github.com/rabbitstack/fibratus/pkg/event" "github.com/rabbitstack/fibratus/pkg/filter" @@ -28,8 +31,6 @@ import ( "github.com/rabbitstack/fibratus/pkg/ps" "github.com/rabbitstack/fibratus/pkg/rules/action" log "github.com/sirupsen/logrus" - "sync" - "time" ) // RuleMatchFunc is rule match function definition. It accepts @@ -121,11 +122,11 @@ func (f *compiledFilter) isSequence() bool { return f.ss != nil } -func (f *compiledFilter) run(e *event.Event) bool { +func (f *compiledFilter) eval(e *event.Event, valuer *filter.ValuerCache) bool { if f.ss != nil { - return f.ss.runSequence(e) + return f.ss.evalSequence(e, valuer) } - return f.filter.Run(e) + return f.filter.EvalWithValuer(e, valuer) } // NewEngine builds a fresh rules engine instance. 
@@ -235,9 +236,14 @@ func (e *Engine) ProcessEvent(evt *event.Event) (bool, error) { filters := e.filters.collect(evt) + // acquire valuer cache + valuer := filter.AcquireValuerCache() + defer valuer.Release() + + // assert event against compiled ruleset var matches bool for _, f := range filters { - match := f.run(evt) + match := f.eval(evt, valuer) if !match { continue } @@ -274,6 +280,7 @@ func (e *Engine) processActions() error { defer e.clearMatches() e.mmu.Lock() defer e.mmu.Unlock() + for _, m := range e.matches { f, evts := m.ctx.Filter, m.ctx.Events filterMatches.Add(f.Name, 1) diff --git a/pkg/rules/sequence.go b/pkg/rules/sequence.go index 97887e116..dfc3d05f7 100644 --- a/pkg/rules/sequence.go +++ b/pkg/rules/sequence.go @@ -436,7 +436,7 @@ func (s *sequenceState) scheduleMaxSpanDeadline(seqID fsm.State, maxSpan time.Du s.spanDeadlines[seqID] = t } -func (s *sequenceState) runSequence(e *event.Event) bool { +func (s *sequenceState) evalSequence(e *event.Event, v *filter.ValuerCache) bool { for i, expr := range s.seq.Expressions { // only try to evaluate the expression // if upstream expressions have matched @@ -453,7 +453,7 @@ func (s *sequenceState) runSequence(e *event.Event) bool { // against the current event, mark it as // out-of-order and store in partials list s.mu.RLock() - ok := expr.IsEvaluable(e) && s.filter.RunSequence(e, i, s.partials, true) + ok := expr.IsEvaluable(e) && s.filter.EvalSequence(e, v, i, s.partials, true) s.mu.RUnlock() if ok { s.addPartial(i, e, true) @@ -462,9 +462,8 @@ func (s *sequenceState) runSequence(e *event.Event) bool { } s.mu.RLock() - matches := expr.IsEvaluable(e) && s.filter.RunSequence(e, i, s.partials, false) + matches := expr.IsEvaluable(e) && s.filter.EvalSequence(e, v, i, s.partials, false) s.mu.RUnlock() - if !matches { continue } @@ -500,10 +499,14 @@ func (s *sequenceState) runSequence(e *event.Event) bool { if evt.PS == nil { _, evt.PS = s.psnap.Find(evt.PID) } - matches = s.filter.RunSequence(evt, 
seqID, s.partials, false) + + v := filter.AcquireValuerCache() + defer v.Release() + matches = s.filter.EvalSequence(evt, v, seqID, s.partials, false) if !matches { continue } + // transition the state machine err := s.matchTransition(seqID, evt) if err != nil { diff --git a/pkg/rules/sequence_test.go b/pkg/rules/sequence_test.go index bf3966157..eb6397f55 100644 --- a/pkg/rules/sequence_test.go +++ b/pkg/rules/sequence_test.go @@ -37,6 +37,13 @@ import ( "golang.org/x/sys/windows/registry" ) +func runSequence(ss *sequenceState, e *event.Event) bool { + valuer := filter.AcquireValuerCache() + defer valuer.Release() + matches := ss.evalSequence(e, valuer) + return matches +} + func TestSequenceState(t *testing.T) { log.SetLevel(log.DebugLevel) @@ -264,7 +271,7 @@ func TestSimpleSequence(t *testing.T) { for i, tt := range tests { t.Run(strconv.Itoa(i), func(t *testing.T) { for idx, e := range tt.evts { - assert.Equal(t, tt.matches[idx], ss.runSequence(e)) + assert.Equal(t, tt.matches[idx], runSequence(ss, e)) } }) } @@ -318,8 +325,8 @@ func TestSimpleSequenceMultiplePartials(t *testing.T) { }, Metadata: map[event.MetadataKey]any{"foo": "bar", "fooz": "barzz"}, } - require.False(t, ss.runSequence(e1)) - require.False(t, ss.runSequence(e2)) + require.False(t, runSequence(ss, e1)) + require.False(t, runSequence(ss, e2)) } // expression matched multiple partials @@ -365,11 +372,11 @@ func TestSimpleSequenceMultiplePartials(t *testing.T) { Metadata: map[event.MetadataKey]any{"foo": "bar", "fooz": "barzz"}, } - require.False(t, ss.runSequence(e1)) + require.False(t, runSequence(ss, e1)) // expression matched the partial that satisfies the sequence link assert.Len(t, ss.partials[0], 6) assert.Len(t, ss.partials[1], 0) - require.True(t, ss.runSequence(e2)) + require.True(t, runSequence(ss, e2)) assert.Len(t, ss.partials[1], 1) require.Len(t, ss.matches, 2) @@ -452,11 +459,11 @@ func TestUnconstrainedSequenceMatches(t *testing.T) { Metadata: map[event.MetadataKey]any{"foo": 
"bar", "fooz": "barzz"}, } - require.False(t, ss.runSequence(e1)) - require.False(t, ss.runSequence(e2)) + require.False(t, runSequence(ss, e1)) + require.False(t, runSequence(ss, e2)) assert.Len(t, ss.partials[0], 2) assert.Len(t, ss.partials[1], 0) - require.True(t, ss.runSequence(e3)) + require.True(t, runSequence(ss, e3)) assert.Len(t, ss.partials[1], 1) require.Len(t, ss.matches, 2) @@ -495,7 +502,8 @@ func TestSimpleSequenceDeadline(t *testing.T) { }, Metadata: map[event.MetadataKey]any{"foo": "bar", "fooz": "barzz"}, } - require.False(t, ss.runSequence(e1)) + + require.False(t, runSequence(ss, e1)) e2 := &event.Event{ Type: event.CreateFile, @@ -513,8 +521,9 @@ func TestSimpleSequenceDeadline(t *testing.T) { }, Metadata: map[event.MetadataKey]any{"foo": "bar", "fooz": "barzz"}, } + time.Sleep(time.Millisecond * 110) - require.False(t, ss.runSequence(e2)) + require.False(t, runSequence(ss, e2)) require.Equal(t, sequenceInitialState, ss.currentState()) assert.Len(t, ss.partials, 0) @@ -523,17 +532,17 @@ func TestSimpleSequenceDeadline(t *testing.T) { // to the initial state, which means we should // be able to match the sequence if we reinsert // the events - require.False(t, ss.runSequence(e1)) - require.True(t, ss.runSequence(e2)) + require.False(t, runSequence(ss, e1)) + require.True(t, runSequence(ss, e2)) ss.clearLocked() require.Equal(t, sequenceInitialState, ss.currentState()) assert.Len(t, ss.partials, 0) // assert the events again with the delay less than max span - require.False(t, ss.runSequence(e1)) + require.False(t, runSequence(ss, e1)) time.Sleep(time.Millisecond * 85) - require.True(t, ss.runSequence(e2)) + require.True(t, runSequence(ss, e2)) } func TestSequenceMultiLinks(t *testing.T) { @@ -565,7 +574,8 @@ func TestSequenceMultiLinks(t *testing.T) { }, Metadata: map[event.MetadataKey]any{"foo": "bar", "fooz": "barzz"}, } - require.False(t, ss.runSequence(e1)) + + require.False(t, runSequence(ss, e1)) e2 := &event.Event{ Type: 
event.CreateFile, @@ -583,7 +593,8 @@ func TestSequenceMultiLinks(t *testing.T) { }, Metadata: map[event.MetadataKey]any{"foo": "bar", "fooz": "barzz"}, } - require.True(t, ss.runSequence(e2)) + + require.True(t, runSequence(ss, e2)) } func TestComplexSequence(t *testing.T) { @@ -592,10 +603,10 @@ func TestComplexSequence(t *testing.T) { c := &config.FilterConfig{Name: "Phishing dropper outbound communication"} f := filter.New(` sequence - maxspan 1h - |evt.name = 'CreateProcess' and ps.name in ('firefox.exe', 'chrome.exe', 'edge.exe')| by ps.pid - |evt.name = 'CreateFile' and file.operation = 'CREATE' and file.extension = '.exe'| by ps.pid - |evt.name in ('Send', 'Connect')| by ps.pid + maxspan 1h + |evt.name = 'CreateProcess' and ps.name in ('firefox.exe', 'chrome.exe', 'edge.exe')| by ps.pid + |evt.name = 'CreateFile' and file.operation = 'CREATE' and file.extension = '.exe'| by ps.pid + |evt.name in ('Send', 'Connect')| by ps.pid `, &config.Config{EventSource: config.EventSourceConfig{EnableFileIOEvents: true}, Filters: &config.Filters{}}) require.NoError(t, f.Compile()) @@ -619,7 +630,8 @@ func TestComplexSequence(t *testing.T) { }, Metadata: map[event.MetadataKey]any{"foo": "bar", "fooz": "barzz"}, } - require.False(t, ss.runSequence(e1)) + + require.False(t, runSequence(ss, e1)) e2 := &event.Event{ Seq: 2, @@ -640,7 +652,8 @@ func TestComplexSequence(t *testing.T) { }, Metadata: map[event.MetadataKey]any{"foo": "bar", "fooz": "barzz"}, } - require.False(t, ss.runSequence(e2)) + + require.False(t, runSequence(ss, e2)) assert.Len(t, ss.partials[0], 1) assert.Len(t, ss.partials[1], 1) @@ -665,7 +678,7 @@ func TestComplexSequence(t *testing.T) { } time.Sleep(time.Millisecond * 30) - require.True(t, ss.runSequence(e3)) + require.True(t, runSequence(ss, e3)) time.Sleep(time.Millisecond * 50) @@ -674,10 +687,10 @@ func TestComplexSequence(t *testing.T) { // FSM should transition from terminal to initial state require.Equal(t, sequenceInitialState, 
ss.currentState()) - require.False(t, ss.runSequence(e1)) - require.False(t, ss.runSequence(e2)) + require.False(t, runSequence(ss, e1)) + require.False(t, runSequence(ss, e2)) time.Sleep(time.Millisecond * 15) - require.True(t, ss.runSequence(e3)) + require.True(t, runSequence(ss, e3)) } func TestSequenceOOO(t *testing.T) { @@ -686,9 +699,9 @@ func TestSequenceOOO(t *testing.T) { c := &config.FilterConfig{Name: "LSASS memory dumping via legitimate or offensive tools"} f := filter.New(` sequence - maxspan 2m - |evt.name = 'OpenProcess' and evt.arg[exe] imatches '?:\\Windows\\System32\\lsass.exe'| by ps.uuid - |evt.name = 'CreateFile' and file.operation = 'CREATE' and file.extension = '.dmp'| by ps.uuid + maxspan 2m + |evt.name = 'OpenProcess' and evt.arg[exe] imatches '?:\\Windows\\System32\\lsass.exe'| by ps.uuid + |evt.name = 'CreateFile' and file.operation = 'CREATE' and file.extension = '.dmp'| by ps.uuid `, &config.Config{EventSource: config.EventSourceConfig{EnableFileIOEvents: true}, Filters: &config.Filters{}}) require.NoError(t, f.Compile()) @@ -711,7 +724,8 @@ func TestSequenceOOO(t *testing.T) { }, Metadata: map[event.MetadataKey]any{"foo": "bar", "fooz": "barzz"}, } - require.False(t, ss.runSequence(e1)) + + require.False(t, runSequence(ss, e1)) require.Len(t, ss.partials[1], 1) assert.True(t, ss.partials[1][0].ContainsMeta(event.RuleSequenceOOOKey)) @@ -733,7 +747,7 @@ func TestSequenceOOO(t *testing.T) { Metadata: map[event.MetadataKey]any{"foo": "bar", "fooz": "barzz"}, } - require.True(t, ss.runSequence(e2)) + require.True(t, runSequence(ss, e2)) assert.Len(t, ss.partials[0], 1) assert.False(t, ss.partials[1][0].ContainsMeta(event.RuleSequenceOOOKey)) } @@ -746,9 +760,9 @@ func TestSequenceGC(t *testing.T) { c := &config.FilterConfig{Name: "LSASS memory dumping via legitimate or offensive tools"} f := filter.New(` sequence - by ps.uuid - |evt.name = 'OpenProcess' and evt.arg[exe] imatches '?:\\Windows\\System32\\lsass.exe'| - |evt.name = 
'CreateFile' and file.operation = 'CREATE' and file.extension = '.dmp'| + by ps.uuid + |evt.name = 'OpenProcess' and evt.arg[exe] imatches '?:\\Windows\\System32\\lsass.exe'| + |evt.name = 'CreateFile' and file.operation = 'CREATE' and file.extension = '.dmp'| `, &config.Config{EventSource: config.EventSourceConfig{EnableFileIOEvents: true}, Filters: &config.Filters{}}) require.NoError(t, f.Compile()) @@ -772,7 +786,7 @@ func TestSequenceGC(t *testing.T) { Metadata: map[event.MetadataKey]any{"foo": "bar", "fooz": "barzz"}, } - require.False(t, ss.runSequence(e)) + require.False(t, runSequence(ss, e)) assert.Len(t, ss.partials[0], 1) time.Sleep(time.Second) @@ -910,13 +924,13 @@ func TestSequenceExpire(t *testing.T) { if evt.IsTerminateProcess() { ss.expire(evt) } else { - ss.runSequence(evt) + runSequence(ss, evt) } } require.Equal(t, tt.wants, ss.inExpired.Load()) require.Len(t, ss.partials, 0) - ss.runSequence(tt.evts[0]) + runSequence(ss, tt.evts[0]) require.False(t, ss.inExpired.Load()) }) } @@ -1010,10 +1024,10 @@ func TestSequenceBoundFields(t *testing.T) { Metadata: map[event.MetadataKey]any{"foo": "bar", "fooz": "barzz"}, } - require.False(t, ss.runSequence(e1)) - require.False(t, ss.runSequence(e2)) - require.False(t, ss.runSequence(e3)) - require.True(t, ss.runSequence(e4)) + require.False(t, runSequence(ss, e1)) + require.False(t, runSequence(ss, e2)) + require.False(t, runSequence(ss, e3)) + require.True(t, runSequence(ss, e4)) } func TestSequenceBoundFieldsWithFunctions(t *testing.T) { @@ -1078,8 +1092,8 @@ func TestSequenceBoundFieldsWithFunctions(t *testing.T) { require.NoError(t, key.SetStringsValue("Notification Packages", []string{"secli", "passwdflt"})) - require.False(t, ss.runSequence(e1)) - require.True(t, ss.runSequence(e2)) + require.False(t, runSequence(ss, e1)) + require.True(t, runSequence(ss, e2)) } func TestIsExpressionEvaluable(t *testing.T) {