Skip to content

Commit e580030

Browse files
authored
(chore): rework NewContext and NewContextWithBudget (#90)
Part of APPSEC-52238 - Add an error return value - Make NewContext and NewContextWithBudget handle methods
1 parent 5da8da6 commit e580030

File tree

4 files changed

+92
-68
lines changed

4 files changed

+92
-68
lines changed

context.go

Lines changed: 1 addition & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66
package waf
77

88
import (
9-
"github.com/DataDog/go-libddwaf/v2/timer"
109
"sync"
1110
"time"
1211

1312
"github.com/DataDog/go-libddwaf/v2/errors"
1413
"github.com/DataDog/go-libddwaf/v2/internal/bindings"
1514
"github.com/DataDog/go-libddwaf/v2/internal/unsafe"
15+
"github.com/DataDog/go-libddwaf/v2/timer"
1616

1717
"sync/atomic"
1818
)
@@ -42,40 +42,6 @@ type Context struct {
4242
truncations map[TruncationReason][]int
4343
}
4444

45-
// NewContext returns a new WAF context of to the given WAF handle.
46-
// A nil value is returned when the WAF handle was released or when the
47-
// WAF context couldn't be created.
48-
// handle. A nil value is returned when the WAF handle can no longer be used
49-
// or the WAF context couldn't be created.
50-
func NewContext(handle *Handle) *Context {
51-
return NewContextWithBudget(handle, timer.UnlimitedBudget)
52-
}
53-
54-
// NewContextWithBudget returns a new WAF context of to the given WAF handle.
55-
// A nil value is returned when the WAF handle was released or when the
56-
// WAF context couldn't be created.
57-
// handle. A nil value is returned when the WAF handle can no longer be used
58-
// or the WAF context couldn't be created.
59-
func NewContextWithBudget(handle *Handle, budget time.Duration) *Context {
60-
// Handle has been released
61-
if !handle.retain() {
62-
return nil
63-
}
64-
65-
cContext := wafLib.WafContextInit(handle.cHandle)
66-
if cContext == 0 {
67-
handle.release() // We couldn't get a context, so we no longer have an implicit reference to the Handle in it...
68-
return nil
69-
}
70-
71-
timer, err := timer.NewTreeTimer(timer.WithBudget(budget), timer.WithComponents(wafRunTag))
72-
if err != nil {
73-
return nil
74-
}
75-
76-
return &Context{handle: handle, cContext: cContext, timer: timer, metrics: metricsStore{data: make(map[string]time.Duration, 5)}}
77-
}
78-
7945
// RunAddressData provides address data to the Context.Run method. If a given key is present in both
8046
// RunAddressData.Persistent and RunAddressData.Ephemeral, the value from RunAddressData.Persistent will take precedence.
8147
type RunAddressData struct {

encoder_decoder_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ func wafTest(t *testing.T, obj *bindings.WafObject) {
2828
waf, err := newDefaultHandle(newArachniTestRule([]ruleInput{{Address: "my.input"}, {Address: "my.other.input"}}, nil))
2929
require.NoError(t, err)
3030
defer waf.Close()
31-
wafCtx := NewContext(waf)
31+
wafCtx, err := waf.NewContext()
32+
require.NoError(t, err)
3233
require.NotNil(t, wafCtx)
3334
defer wafCtx.Close()
3435
_, err = wafCtx.Run(RunAddressData{

handle.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@ package waf
88
import (
99
"errors"
1010
"fmt"
11+
"time"
1112

1213
wafErrors "github.com/DataDog/go-libddwaf/v2/errors"
1314
"github.com/DataDog/go-libddwaf/v2/internal/bindings"
1415
"github.com/DataDog/go-libddwaf/v2/internal/unsafe"
16+
"github.com/DataDog/go-libddwaf/v2/timer"
1517

1618
"sync/atomic"
1719
)
@@ -104,6 +106,36 @@ func NewHandle(rules any, keyObfuscatorRegex string, valueObfuscatorRegex string
104106
return handle, nil
105107
}
106108

109+
// NewContext returns a new WAF context for the given WAF handle.
110+
// A nil value is returned when the WAF handle was released or when the
111+
// WAF context couldn't be created.
112+
func (handle *Handle) NewContext() (*Context, error) {
113+
return handle.NewContextWithBudget(timer.UnlimitedBudget)
114+
}
115+
116+
// NewContextWithBudget returns a new WAF context for the given WAF handle.
117+
// A nil value is returned when the WAF handle was released or when the
118+
// WAF context couldn't be created.
119+
func (handle *Handle) NewContextWithBudget(budget time.Duration) (*Context, error) {
120+
// Handle has been released
121+
if !handle.retain() {
122+
return nil, fmt.Errorf("handle was released")
123+
}
124+
125+
cContext := wafLib.WafContextInit(handle.cHandle)
126+
if cContext == 0 {
127+
handle.release() // We couldn't get a context, so we no longer have an implicit reference to the Handle in it...
128+
return nil, fmt.Errorf("could not get C context")
129+
}
130+
131+
timer, err := timer.NewTreeTimer(timer.WithBudget(budget), timer.WithComponents(wafRunTag))
132+
if err != nil {
133+
return nil, err
134+
}
135+
136+
return &Context{handle: handle, cContext: cContext, timer: timer, metrics: metricsStore{data: make(map[string]time.Duration, 5)}}, nil
137+
}
138+
107139
// Diagnostics returns the rules initialization metrics for the current WAF handle
108140
func (handle *Handle) Diagnostics() Diagnostics {
109141
return handle.diagnostics

waf_test.go

Lines changed: 57 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,8 @@ func TestUpdateWAF(t *testing.T) {
274274
require.NotNil(t, waf)
275275
defer waf.Close()
276276

277-
wafCtx := NewContext(waf)
277+
wafCtx, err := waf.NewContext()
278+
require.NoError(t, err)
278279
defer wafCtx.Close()
279280

280281
// Matches
@@ -295,7 +296,8 @@ func TestUpdateWAF(t *testing.T) {
295296
require.NotNil(t, waf2)
296297
defer waf2.Close()
297298

298-
wafCtx2 := NewContext(waf2)
299+
wafCtx2, err := waf2.NewContext()
300+
require.NoError(t, err)
299301
defer wafCtx2.Close()
300302

301303
// Matches & Block
@@ -369,11 +371,12 @@ func TestTimeout(t *testing.T) {
369371
}
370372

371373
t.Run("not-empty-metrics-match", func(t *testing.T) {
372-
context := NewContextWithBudget(waf, time.Hour)
374+
context, err := waf.NewContextWithBudget(time.Hour)
375+
require.NoError(t, err)
373376
require.NotNil(t, context)
374377
defer context.Close()
375378

376-
_, err := context.Run(RunAddressData{Persistent: normalValue, Ephemeral: normalValue}, 0)
379+
_, err = context.Run(RunAddressData{Persistent: normalValue, Ephemeral: normalValue}, 0)
377380
require.NoError(t, err)
378381
require.NotEmpty(t, context.Stats())
379382
require.NotZero(t, context.Stats().Timers["_dd.appsec.waf.decode"])
@@ -383,11 +386,12 @@ func TestTimeout(t *testing.T) {
383386
})
384387

385388
t.Run("not-empty-metrics-no-match", func(t *testing.T) {
386-
context := NewContextWithBudget(waf, time.Hour)
389+
context, err := waf.NewContextWithBudget(time.Hour)
390+
require.NoError(t, err)
387391
require.NotNil(t, context)
388392
defer context.Close()
389393

390-
_, err := context.Run(RunAddressData{Persistent: map[string]any{"my.input": "curl/7.88"}}, 0)
394+
_, err = context.Run(RunAddressData{Persistent: map[string]any{"my.input": "curl/7.88"}}, 0)
391395
require.NoError(t, err)
392396
require.NotEmpty(t, context.Stats())
393397
require.NotZero(t, context.Stats().Timers["_dd.appsec.waf.decode"])
@@ -397,34 +401,35 @@ func TestTimeout(t *testing.T) {
397401
})
398402

399403
t.Run("timeout-persistent-encoder", func(t *testing.T) {
400-
context := NewContextWithBudget(waf, time.Millisecond)
404+
context, err := waf.NewContextWithBudget(time.Millisecond)
405+
require.NoError(t, err)
401406
require.NotNil(t, context)
402407
defer context.Close()
403408

404-
_, err := context.Run(RunAddressData{Persistent: largeValue}, 0)
409+
_, err = context.Run(RunAddressData{Persistent: largeValue}, 0)
405410
require.Equal(t, errors.ErrTimeout, err)
406411
require.GreaterOrEqual(t, context.Stats().Timers["_dd.appsec.waf.duration_ext"], time.Millisecond)
407412
require.GreaterOrEqual(t, context.Stats().Timers["_dd.appsec.waf.encode"], time.Millisecond)
408413
})
409414

410415
t.Run("timeout-ephemeral-encoder", func(t *testing.T) {
411-
context := NewContextWithBudget(waf, time.Millisecond)
416+
context, err := waf.NewContextWithBudget(time.Millisecond)
417+
require.NoError(t, err)
412418
require.NotNil(t, context)
413419
defer context.Close()
414420

415-
_, err := context.Run(RunAddressData{Ephemeral: largeValue}, 0)
421+
_, err = context.Run(RunAddressData{Ephemeral: largeValue}, 0)
416422
require.Equal(t, errors.ErrTimeout, err)
417423
require.GreaterOrEqual(t, context.Stats().Timers["_dd.appsec.waf.duration_ext"], time.Millisecond)
418424
require.GreaterOrEqual(t, context.Stats().Timers["_dd.appsec.waf.encode"], time.Millisecond)
419425
})
420426

421427
t.Run("many-runs", func(t *testing.T) {
422-
context := NewContextWithBudget(waf, time.Millisecond)
428+
context, err := waf.NewContextWithBudget(time.Millisecond)
429+
require.NoError(t, err)
423430
require.NotNil(t, context)
424431
defer context.Close()
425432

426-
var err error
427-
428433
for i := 0; i < 1000 && err != errors.ErrTimeout; i++ {
429434
_, err = context.Run(RunAddressData{Persistent: normalValue}, 0)
430435
}
@@ -441,7 +446,8 @@ func TestMatching(t *testing.T) {
441446

442447
require.Equal(t, []string{"my.input"}, waf.Addresses())
443448

444-
wafCtx := NewContext(waf)
449+
wafCtx, err := waf.NewContext()
450+
require.NoError(t, err)
445451
require.NotNil(t, wafCtx)
446452

447453
// Not matching because the address value doesn't match the rule
@@ -496,7 +502,9 @@ func TestMatching(t *testing.T) {
496502
wafCtx.Close()
497503
waf.Close()
498504
// Using the WAF instance after it was closed leads to a nil WAF context
499-
require.Nil(t, NewContext(waf))
505+
ctx, err := waf.NewContext()
506+
require.Nil(t, ctx)
507+
require.Error(t, err)
500508
}
501509

502510
func TestMatchingEphemeralAndPersistent(t *testing.T) {
@@ -505,7 +513,8 @@ func TestMatchingEphemeralAndPersistent(t *testing.T) {
505513
require.NoError(t, err)
506514
defer waf.Close()
507515

508-
wafCtx := NewContext(waf)
516+
wafCtx, err := waf.NewContext()
517+
require.NoError(t, err)
509518
require.NotNil(t, wafCtx)
510519
defer wafCtx.Close()
511520

@@ -557,7 +566,8 @@ func TestMatchingEphemeral(t *testing.T) {
557566
sort.Strings(addrs)
558567
require.Equal(t, []string{input1, input2}, addrs)
559568

560-
wafCtx := NewContext(waf)
569+
wafCtx, err := waf.NewContext()
570+
require.NoError(t, err)
561571
require.NotNil(t, wafCtx)
562572

563573
// Not matching because the address value doesn't match the rule
@@ -614,7 +624,9 @@ func TestMatchingEphemeral(t *testing.T) {
614624
wafCtx.Close()
615625
waf.Close()
616626
// Using the WAF instance after it was closed leads to a nil WAF context
617-
require.Nil(t, NewContext(waf))
627+
ctx, err := waf.NewContext()
628+
require.Nil(t, ctx)
629+
require.Error(t, err)
618630
}
619631

620632
func TestMatchingEphemeralOnly(t *testing.T) {
@@ -631,7 +643,8 @@ func TestMatchingEphemeralOnly(t *testing.T) {
631643
sort.Strings(addrs)
632644
require.Equal(t, []string{input1, input2}, addrs)
633645

634-
wafCtx := NewContext(waf)
646+
wafCtx, err := waf.NewContext()
647+
require.NoError(t, err)
635648
require.NotNil(t, wafCtx)
636649

637650
// Not matching because the address value doesn't match the rule
@@ -672,7 +685,9 @@ func TestMatchingEphemeralOnly(t *testing.T) {
672685
wafCtx.Close()
673686
waf.Close()
674687
// Using the WAF instance after it was closed leads to a nil WAF context
675-
require.Nil(t, NewContext(waf))
688+
ctx, err := waf.NewContext()
689+
require.Nil(t, ctx)
690+
require.Error(t, err)
676691
}
677692

678693
func TestActions(t *testing.T) {
@@ -684,7 +699,8 @@ func TestActions(t *testing.T) {
684699
require.NotNil(t, waf)
685700
defer waf.Close()
686701

687-
wafCtx := NewContext(waf)
702+
wafCtx, err := waf.NewContext()
703+
require.NoError(t, err)
688704
require.NotNil(t, wafCtx)
689705
defer wafCtx.Close()
690706

@@ -727,7 +743,8 @@ func TestConcurrency(t *testing.T) {
727743
require.NoError(t, err)
728744
defer waf.Close()
729745

730-
wafCtx := NewContext(waf)
746+
wafCtx, err := waf.NewContext()
747+
require.NoError(t, err)
731748
defer wafCtx.Close()
732749

733750
// User agents that won't match the rule so that it doesn't get pruned.
@@ -821,7 +838,8 @@ func TestConcurrency(t *testing.T) {
821838
startBarrier.Wait() // Sync the starts of the goroutines
822839
defer stopBarrier.Done() // Signal we are done when returning
823840

824-
wafCtx := NewContext(waf)
841+
wafCtx, err := waf.NewContext()
842+
require.NoError(t, err)
825843
defer wafCtx.Close()
826844

827845
for c := 0; c < nbRun; c++ {
@@ -890,8 +908,8 @@ func TestConcurrency(t *testing.T) {
890908
startBarrier.Wait() // Sync the starts of the goroutines
891909
defer stopBarrier.Done() // Signal we are done when returning
892910

893-
wafCtx := NewContext(waf)
894-
if wafCtx == nil {
911+
wafCtx, err := waf.NewContext()
912+
if wafCtx == nil || err != nil {
895913
return
896914
}
897915
wafCtx.Close()
@@ -923,7 +941,8 @@ func TestConcurrency(t *testing.T) {
923941
waf, err := newDefaultHandle(testArachniRule)
924942
require.NoError(t, err)
925943

926-
wafCtx := NewContext(waf)
944+
wafCtx, err := waf.NewContext()
945+
require.NoError(t, err)
927946
require.NotNil(t, wafCtx)
928947

929948
var startBarrier, stopBarrier sync.WaitGroup
@@ -1087,7 +1106,8 @@ func TestMetrics(t *testing.T) {
10871106
})
10881107

10891108
t.Run("RunDuration", func(t *testing.T) {
1090-
wafCtx := NewContext(waf)
1109+
wafCtx, err := waf.NewContext()
1110+
require.NoError(t, err)
10911111
require.NotNil(t, wafCtx)
10921112
defer wafCtx.Close()
10931113
// Craft matching data to force work on the WAF
@@ -1113,7 +1133,8 @@ func TestMetrics(t *testing.T) {
11131133
})
11141134

11151135
t.Run("Timeouts", func(t *testing.T) {
1116-
wafCtx := NewContextWithBudget(waf, time.Nanosecond)
1136+
wafCtx, err := waf.NewContextWithBudget(time.Nanosecond)
1137+
require.NoError(t, err)
11171138
require.NotNil(t, wafCtx)
11181139
defer wafCtx.Close()
11191140
// Craft matching data to force work on the WAF
@@ -1138,7 +1159,8 @@ func TestObfuscatorConfig(t *testing.T) {
11381159
waf, err := NewHandle(rule, "key", "")
11391160
require.NoError(t, err)
11401161
defer waf.Close()
1141-
wafCtx := NewContext(waf)
1162+
wafCtx, err := waf.NewContext()
1163+
require.NoError(t, err)
11421164
require.NotNil(t, wafCtx)
11431165
defer wafCtx.Close()
11441166
data := map[string]interface{}{
@@ -1160,7 +1182,8 @@ func TestObfuscatorConfig(t *testing.T) {
11601182
waf, err := NewHandle(rule, "", "sensitive")
11611183
require.NoError(t, err)
11621184
defer waf.Close()
1163-
wafCtx := NewContext(waf)
1185+
wafCtx, err := waf.NewContext()
1186+
require.NoError(t, err)
11641187
require.NotNil(t, wafCtx)
11651188
defer wafCtx.Close()
11661189
data := map[string]interface{}{
@@ -1182,7 +1205,8 @@ func TestObfuscatorConfig(t *testing.T) {
11821205
waf, err := NewHandle(rule, "", "")
11831206
require.NoError(t, err)
11841207
defer waf.Close()
1185-
wafCtx := NewContext(waf)
1208+
wafCtx, err := waf.NewContext()
1209+
require.NoError(t, err)
11861210
require.NotNil(t, wafCtx)
11871211
defer wafCtx.Close()
11881212
data := map[string]interface{}{
@@ -1206,7 +1230,8 @@ func TestTruncationInformation(t *testing.T) {
12061230
require.NoError(t, err)
12071231
defer waf.Close()
12081232

1209-
ctx := NewContext(waf)
1233+
ctx, err := waf.NewContext()
1234+
require.NoError(t, err)
12101235
defer ctx.Close()
12111236

12121237
extra := rand.Intn(10) + 1 // Random int between 1 and 10

0 commit comments

Comments
 (0)