Skip to content

Commit d31803a

Browse files
authored
refactor: params extra types are zero values not nil pointers by default (#13)
* refactor: extra types `C` + `R` are never plumbed as `*C` / `*R` * refactor: force use of `pseudo.Constructor.Zero()` instead of `NilPointer()` * feat: `pseudo.PointerTo()` * feat: `params.ExtraPayloadGetter[C,R].PointerFromChainConfig(...) *C` and `Rules => *R` equiv * test: shallow copy of `ChainConfig`/`Rules` includes extras
1 parent 72744ce commit d31803a

File tree

8 files changed

+229
-63
lines changed

8 files changed

+229
-63
lines changed

libevm/hookstest/stub.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ type Stub struct {
3232
// Register is a convenience wrapper for registering s as both the
3333
// [params.ChainConfigHooks] and [params.RulesHooks] via [Register].
3434
func (s *Stub) Register(tb testing.TB) {
35-
Register(tb, params.Extras[Stub, Stub]{
35+
Register(tb, params.Extras[*Stub, *Stub]{
3636
NewRules: func(_ *params.ChainConfig, _ *params.Rules, _ *Stub, blockNum *big.Int, isMerge bool, timestamp uint64) *Stub {
3737
return s
3838
},

libevm/pseudo/type.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,27 @@ func Zero[T any]() *Pseudo[T] {
6060
return From[T](x)
6161
}
6262

63+
// PointerTo is equivalent to [From] called with a pointer to the payload
64+
// carried by `t`. It first confirms that the payload is of type `T`.
65+
func PointerTo[T any](t *Type) (*Pseudo[*T], error) {
66+
c, ok := t.val.(*concrete[T])
67+
if !ok {
68+
var want *T
69+
return nil, fmt.Errorf("cannot create *Pseudo[%T] from *Type carrying %T", want, t.val.get())
70+
}
71+
return From(&c.val), nil
72+
}
73+
74+
// MustPointerTo is equivalent to [PointerTo] except that it panics instead of
75+
// returning an error.
76+
func MustPointerTo[T any](t *Type) *Pseudo[*T] {
77+
p, err := PointerTo[T](t)
78+
if err != nil {
79+
panic(err)
80+
}
81+
return p
82+
}
83+
6384
// Interface returns the wrapped value as an `any`, equivalent to
6485
// [reflect.Value.Interface]. Prefer [Value.Get].
6586
func (t *Type) Interface() any { return t.val.get() }

libevm/pseudo/type_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,26 @@ func ExamplePseudo_TypeAndValue() {
7777
_ = typ
7878
_ = val
7979
}
80+
81+
func TestPointer(t *testing.T) {
82+
type carrier struct {
83+
payload int
84+
}
85+
86+
typ, val := From(carrier{42}).TypeAndValue()
87+
88+
t.Run("invalid type", func(t *testing.T) {
89+
_, err := PointerTo[int](typ)
90+
require.Errorf(t, err, "PointerTo[int](%T)", carrier{})
91+
})
92+
93+
t.Run("valid type", func(t *testing.T) {
94+
ptrVal := MustPointerTo[carrier](typ).Value
95+
96+
assert.Equal(t, 42, val.Get().payload, "before setting via pointer")
97+
var ptr *carrier = ptrVal.Get()
98+
ptr.payload = 314159
99+
assert.Equal(t, 314159, val.Get().payload, "after setting via pointer")
100+
})
101+
102+
}

params/config.libevm.go

Lines changed: 54 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -25,25 +25,27 @@ type Extras[C ChainConfigHooks, R RulesHooks] struct {
2525
// NewRules, if non-nil is called at the end of [ChainConfig.Rules] with the
2626
// newly created [Rules] and other context from the method call. Its
2727
// returned value will be the extra payload of the [Rules]. If NewRules is
28-
// nil then so too will the [Rules] extra payload be a nil `*R`.
28+
// nil then so too will the [Rules] extra payload be a zero-value `R`.
2929
//
3030
// NewRules MAY modify the [Rules] but MUST NOT modify the [ChainConfig].
31-
NewRules func(_ *ChainConfig, _ *Rules, _ *C, blockNum *big.Int, isMerge bool, timestamp uint64) *R
31+
// TODO(arr4n): add the [Rules] to the return signature to make it clearer
32+
// that the caller can modify the generated Rules.
33+
NewRules func(_ *ChainConfig, _ *Rules, _ C, blockNum *big.Int, isMerge bool, timestamp uint64) R
3234
}
3335

3436
// RegisterExtras registers the types `C` and `R` such that they are carried as
3537
// extra payloads in [ChainConfig] and [Rules] structs, respectively. It is
3638
// expected to be called in an `init()` function and MUST NOT be called more
37-
// than once. Both `C` and `R` MUST be structs.
39+
// than once. Both `C` and `R` MUST be structs or pointers to structs.
3840
//
3941
// After registration, JSON unmarshalling of a [ChainConfig] will create a new
40-
// `*C` and unmarshal the JSON key "extra" into it. Conversely, JSON marshalling
41-
// will populate the "extra" key with the contents of the `*C`. Both the
42+
// `C` and unmarshal the JSON key "extra" into it. Conversely, JSON marshalling
43+
// will populate the "extra" key with the contents of the `C`. Both the
4244
// [json.Marshaler] and [json.Unmarshaler] interfaces are honoured if
4345
// implemented by `C` and/or `R.`
4446
//
4547
// Calls to [ChainConfig.Rules] will call the `NewRules` function of the
46-
// registered [Extras] to create a new `*R`.
48+
// registered [Extras] to create a new `R`.
4749
//
4850
// The payloads can be accessed via the [ExtraPayloadGetter.FromChainConfig] and
4951
// [ExtraPayloadGetter.FromRules] methods of the getter returned by
@@ -54,16 +56,16 @@ func RegisterExtras[C ChainConfigHooks, R RulesHooks](e Extras[C, R]) ExtraPaylo
5456
if registeredExtras != nil {
5557
panic("re-registration of Extras")
5658
}
57-
mustBeStruct[C]()
58-
mustBeStruct[R]()
59+
mustBeStructOrPointerToOne[C]()
60+
mustBeStructOrPointerToOne[R]()
5961

6062
getter := e.getter()
6163
registeredExtras = &extraConstructors{
62-
chainConfig: pseudo.NewConstructor[C](),
63-
rules: pseudo.NewConstructor[R](),
64-
reuseJSONRoot: e.ReuseJSONRoot,
65-
newForRules: e.newForRules,
66-
getter: getter,
64+
newChainConfig: pseudo.NewConstructor[C]().Zero,
65+
newRules: pseudo.NewConstructor[R]().Zero,
66+
reuseJSONRoot: e.ReuseJSONRoot,
67+
newForRules: e.newForRules,
68+
getter: getter,
6769
}
6870
return getter
6971
}
@@ -95,9 +97,9 @@ func TestOnlyClearRegisteredExtras() {
9597
var registeredExtras *extraConstructors
9698

9799
type extraConstructors struct {
98-
chainConfig, rules pseudo.Constructor
99-
reuseJSONRoot bool
100-
newForRules func(_ *ChainConfig, _ *Rules, blockNum *big.Int, isMerge bool, timestamp uint64) *pseudo.Type
100+
newChainConfig, newRules func() *pseudo.Type
101+
reuseJSONRoot bool
102+
newForRules func(_ *ChainConfig, _ *Rules, blockNum *big.Int, isMerge bool, timestamp uint64) *pseudo.Type
101103
// use top-level hooksFrom<X>() functions instead of these as they handle
102104
// instances where no [Extras] were registered.
103105
getter interface {
@@ -108,27 +110,34 @@ type extraConstructors struct {
108110

109111
func (e *Extras[C, R]) newForRules(c *ChainConfig, r *Rules, blockNum *big.Int, isMerge bool, timestamp uint64) *pseudo.Type {
110112
if e.NewRules == nil {
111-
return registeredExtras.rules.NilPointer()
113+
return registeredExtras.newRules()
112114
}
113115
rExtra := e.NewRules(c, r, e.getter().FromChainConfig(c), blockNum, isMerge, timestamp)
114116
return pseudo.From(rExtra).Type
115117
}
116118

117119
func (*Extras[C, R]) getter() (g ExtraPayloadGetter[C, R]) { return }
118120

119-
// mustBeStruct panics if `T` isn't a struct.
120-
func mustBeStruct[T any]() {
121+
// mustBeStructOrPointerToOne panics if `T` isn't a struct or a *struct.
122+
func mustBeStructOrPointerToOne[T any]() {
121123
var x T
122-
if k := reflect.TypeOf(x).Kind(); k != reflect.Struct {
123-
panic(notStructMessage[T]())
124+
switch t := reflect.TypeOf(x); t.Kind() {
125+
case reflect.Struct:
126+
return
127+
case reflect.Pointer:
128+
if t.Elem().Kind() == reflect.Struct {
129+
return
130+
}
124131
}
132+
panic(notStructMessage[T]())
125133
}
126134

127-
// notStructMessage returns the message with which [mustBeStruct] might panic.
128-
// It exists to avoid change-detector tests should the message contents change.
135+
// notStructMessage returns the message with which [mustBeStructOrPointerToOne]
136+
// might panic. It exists to avoid change-detector tests should the message
137+
// contents change.
129138
func notStructMessage[T any]() string {
130139
var x T
131-
return fmt.Sprintf("%T is not a struct", x)
140+
return fmt.Sprintf("%T is not a struct nor a pointer to a struct", x)
132141
}
133142

134143
// An ExtraPayloadGettter provides strongly typed access to the extra payloads
@@ -139,33 +148,37 @@ type ExtraPayloadGetter[C ChainConfigHooks, R RulesHooks] struct {
139148
}
140149

141150
// FromChainConfig returns the ChainConfig's extra payload.
142-
func (ExtraPayloadGetter[C, R]) FromChainConfig(c *ChainConfig) *C {
143-
return pseudo.MustNewValue[*C](c.extraPayload()).Get()
151+
func (ExtraPayloadGetter[C, R]) FromChainConfig(c *ChainConfig) C {
152+
return pseudo.MustNewValue[C](c.extraPayload()).Get()
153+
}
154+
155+
// PointerFromChainConfig returns a pointer to the ChainConfig's extra payload.
156+
// This is guaranteed to be non-nil.
157+
func (ExtraPayloadGetter[C, R]) PointerFromChainConfig(c *ChainConfig) *C {
158+
return pseudo.MustPointerTo[C](c.extraPayload()).Value.Get()
144159
}
145160

146161
// hooksFromChainConfig is equivalent to FromChainConfig(), but returns an
147162
// interface instead of the concrete type implementing it; this allows it to be
148-
// used in non-generic code. If the concrete-type value is nil (typically
149-
// because no [Extras] were registered) a [noopHooks] is returned so it can be
150-
// used without nil checks.
163+
// used in non-generic code.
151164
func (e ExtraPayloadGetter[C, R]) hooksFromChainConfig(c *ChainConfig) ChainConfigHooks {
152-
if h := e.FromChainConfig(c); h != nil {
153-
return *h
154-
}
155-
return NOOPHooks{}
165+
return e.FromChainConfig(c)
156166
}
157167

158168
// FromRules returns the Rules' extra payload.
159-
func (ExtraPayloadGetter[C, R]) FromRules(r *Rules) *R {
160-
return pseudo.MustNewValue[*R](r.extraPayload()).Get()
169+
func (ExtraPayloadGetter[C, R]) FromRules(r *Rules) R {
170+
return pseudo.MustNewValue[R](r.extraPayload()).Get()
171+
}
172+
173+
// PointerFromRules returns a pointer to the Rules's extra payload. This is
174+
// guaranteed to be non-nil.
175+
func (ExtraPayloadGetter[C, R]) PointerFromRules(r *Rules) *R {
176+
return pseudo.MustPointerTo[R](r.extraPayload()).Value.Get()
161177
}
162178

163179
// hooksFromRules is the [RulesHooks] equivalent of hooksFromChainConfig().
164180
func (e ExtraPayloadGetter[C, R]) hooksFromRules(r *Rules) RulesHooks {
165-
if h := e.FromRules(r); h != nil {
166-
return *h
167-
}
168-
return NOOPHooks{}
181+
return e.FromRules(r)
169182
}
170183

171184
// addRulesExtra is called at the end of [ChainConfig.Rules]; it exists to
@@ -189,7 +202,7 @@ func (c *ChainConfig) extraPayload() *pseudo.Type {
189202
panic(fmt.Sprintf("%T.ExtraPayload() called before RegisterExtras()", c))
190203
}
191204
if c.extra == nil {
192-
c.extra = registeredExtras.chainConfig.NilPointer()
205+
c.extra = registeredExtras.newChainConfig()
193206
}
194207
return c.extra
195208
}
@@ -201,7 +214,7 @@ func (r *Rules) extraPayload() *pseudo.Type {
201214
panic(fmt.Sprintf("%T.ExtraPayload() called before RegisterExtras()", r))
202215
}
203216
if r.extra == nil {
204-
r.extra = registeredExtras.rules.NilPointer()
217+
r.extra = registeredExtras.newRules()
205218
}
206219
return r.extra
207220
}

params/config.libevm_test.go

Lines changed: 88 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,26 +50,36 @@ func TestRegisterExtras(t *testing.T) {
5050
name: "Rules payload copied from ChainConfig payload",
5151
register: func() {
5252
RegisterExtras(Extras[ccExtraA, rulesExtraA]{
53-
NewRules: func(cc *ChainConfig, r *Rules, ex *ccExtraA, _ *big.Int, _ bool, _ uint64) *rulesExtraA {
54-
return &rulesExtraA{
53+
NewRules: func(cc *ChainConfig, r *Rules, ex ccExtraA, _ *big.Int, _ bool, _ uint64) rulesExtraA {
54+
return rulesExtraA{
5555
A: ex.A,
5656
}
5757
},
5858
})
5959
},
60-
ccExtra: pseudo.From(&ccExtraA{
60+
ccExtra: pseudo.From(ccExtraA{
6161
A: "hello",
6262
}).Type,
63-
wantRulesExtra: &rulesExtraA{
63+
wantRulesExtra: rulesExtraA{
6464
A: "hello",
6565
},
6666
},
6767
{
68-
name: "no NewForRules() function results in typed but nil pointer",
68+
name: "no NewForRules() function results in zero value",
6969
register: func() {
7070
RegisterExtras(Extras[ccExtraB, rulesExtraB]{})
7171
},
72-
ccExtra: pseudo.From(&ccExtraB{
72+
ccExtra: pseudo.From(ccExtraB{
73+
B: "world",
74+
}).Type,
75+
wantRulesExtra: rulesExtraB{},
76+
},
77+
{
78+
name: "no NewForRules() function results in nil pointer",
79+
register: func() {
80+
RegisterExtras(Extras[ccExtraB, *rulesExtraB]{})
81+
},
82+
ccExtra: pseudo.From(ccExtraB{
7383
B: "world",
7484
}).Type,
7585
wantRulesExtra: (*rulesExtraB)(nil),
@@ -79,10 +89,10 @@ func TestRegisterExtras(t *testing.T) {
7989
register: func() {
8090
RegisterExtras(Extras[rawJSON, struct{ RulesHooks }]{})
8191
},
82-
ccExtra: pseudo.From(&rawJSON{
92+
ccExtra: pseudo.From(rawJSON{
8393
RawMessage: []byte(`"hello, world"`),
8494
}).Type,
85-
wantRulesExtra: (*struct{ RulesHooks })(nil),
95+
wantRulesExtra: struct{ RulesHooks }{},
8696
},
8797
}
8898

@@ -111,6 +121,75 @@ func TestRegisterExtras(t *testing.T) {
111121
}
112122
}
113123

124+
func TestModificationOfZeroExtras(t *testing.T) {
125+
type (
126+
ccExtra struct {
127+
X int
128+
NOOPHooks
129+
}
130+
rulesExtra struct {
131+
X int
132+
NOOPHooks
133+
}
134+
)
135+
136+
TestOnlyClearRegisteredExtras()
137+
t.Cleanup(TestOnlyClearRegisteredExtras)
138+
getter := RegisterExtras(Extras[ccExtra, rulesExtra]{})
139+
140+
config := new(ChainConfig)
141+
rules := new(Rules)
142+
// These assertion helpers are defined before any modifications so that the
143+
// closure is demonstrably over the original zero values.
144+
assertChainConfigExtra := func(t *testing.T, want ccExtra, msg string) {
145+
t.Helper()
146+
assert.Equalf(t, want, getter.FromChainConfig(config), "%T: "+msg, &config)
147+
}
148+
assertRulesExtra := func(t *testing.T, want rulesExtra, msg string) {
149+
t.Helper()
150+
assert.Equalf(t, want, getter.FromRules(rules), "%T: "+msg, &rules)
151+
}
152+
153+
assertChainConfigExtra(t, ccExtra{}, "zero value")
154+
assertRulesExtra(t, rulesExtra{}, "zero value")
155+
156+
const answer = 42
157+
getter.PointerFromChainConfig(config).X = answer
158+
assertChainConfigExtra(t, ccExtra{X: answer}, "after setting via pointer field")
159+
160+
const pi = 314159
161+
getter.PointerFromRules(rules).X = pi
162+
assertRulesExtra(t, rulesExtra{X: pi}, "after setting via pointer field")
163+
164+
ccReplace := ccExtra{X: 142857}
165+
*getter.PointerFromChainConfig(config) = ccReplace
166+
assertChainConfigExtra(t, ccReplace, "after replacement of entire extra via `*pointer = x`")
167+
168+
rulesReplace := rulesExtra{X: 18101986}
169+
*getter.PointerFromRules(rules) = rulesReplace
170+
assertRulesExtra(t, rulesReplace, "after replacement of entire extra via `*pointer = x`")
171+
172+
if t.Failed() {
173+
// The test of shallow copying is now guaranteed to fail.
174+
return
175+
}
176+
t.Run("shallow copy", func(t *testing.T) {
177+
ccCopy := *config
178+
rCopy := *rules
179+
180+
assert.Equal(t, getter.FromChainConfig(&ccCopy), ccReplace, "ChainConfig extras copied")
181+
assert.Equal(t, getter.FromRules(&rCopy), rulesReplace, "Rules extras copied")
182+
183+
const seqUp = 123456789
184+
getter.PointerFromChainConfig(&ccCopy).X = seqUp
185+
assertChainConfigExtra(t, ccExtra{X: seqUp}, "original changed because copy only shallow")
186+
187+
const seqDown = 987654321
188+
getter.PointerFromRules(&rCopy).X = seqDown
189+
assertRulesExtra(t, rulesExtra{X: seqDown}, "original changed because copy only shallow")
190+
})
191+
}
192+
114193
func TestExtrasPanic(t *testing.T) {
115194
TestOnlyClearRegisteredExtras()
116195
defer TestOnlyClearRegisteredExtras()
@@ -131,7 +210,7 @@ func TestExtrasPanic(t *testing.T) {
131210

132211
assertPanics(
133212
t, func() {
134-
mustBeStruct[int]()
213+
mustBeStructOrPointerToOne[int]()
135214
},
136215
notStructMessage[int](),
137216
)

0 commit comments

Comments
 (0)