Skip to content

Commit 1ad168e

Browse files
committed
refactor: extra types C + R are never plumbed as *C / *R
1 parent 72744ce commit 1ad168e

File tree

6 files changed

+96
-55
lines changed

6 files changed

+96
-55
lines changed

libevm/hookstest/stub.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ type Stub struct {
3232
// Register is a convenience wrapper for registering s as both the
3333
// [params.ChainConfigHooks] and [params.RulesHooks] via [Register].
3434
func (s *Stub) Register(tb testing.TB) {
35-
Register(tb, params.Extras[Stub, Stub]{
35+
Register(tb, params.Extras[*Stub, *Stub]{
3636
NewRules: func(_ *params.ChainConfig, _ *params.Rules, _ *Stub, blockNum *big.Int, isMerge bool, timestamp uint64) *Stub {
3737
return s
3838
},

params/config.libevm.go

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -25,25 +25,27 @@ type Extras[C ChainConfigHooks, R RulesHooks] struct {
2525
// NewRules, if non-nil is called at the end of [ChainConfig.Rules] with the
2626
// newly created [Rules] and other context from the method call. Its
2727
// returned value will be the extra payload of the [Rules]. If NewRules is
28-
// nil then so too will the [Rules] extra payload be a nil `*R`.
28+
// nil then so too will the [Rules] extra payload be a zero-value `R`.
2929
//
3030
// NewRules MAY modify the [Rules] but MUST NOT modify the [ChainConfig].
31-
NewRules func(_ *ChainConfig, _ *Rules, _ *C, blockNum *big.Int, isMerge bool, timestamp uint64) *R
31+
// TODO(arr4n): add the [Rules] to the return signature to make it clearer
32+
// that the caller can modify the generated Rules.
33+
NewRules func(_ *ChainConfig, _ *Rules, _ C, blockNum *big.Int, isMerge bool, timestamp uint64) R
3234
}
3335

3436
// RegisterExtras registers the types `C` and `R` such that they are carried as
3537
// extra payloads in [ChainConfig] and [Rules] structs, respectively. It is
3638
// expected to be called in an `init()` function and MUST NOT be called more
37-
// than once. Both `C` and `R` MUST be structs.
39+
// than once. Both `C` and `R` MUST be structs or pointers to structs.
3840
//
3941
// After registration, JSON unmarshalling of a [ChainConfig] will create a new
40-
// `*C` and unmarshal the JSON key "extra" into it. Conversely, JSON marshalling
41-
// will populate the "extra" key with the contents of the `*C`. Both the
42+
// `C` and unmarshal the JSON key "extra" into it. Conversely, JSON marshalling
43+
// will populate the "extra" key with the contents of the `C`. Both the
4244
// [json.Marshaler] and [json.Unmarshaler] interfaces are honoured if
4345
// implemented by `C` and/or `R.`
4446
//
4547
// Calls to [ChainConfig.Rules] will call the `NewRules` function of the
46-
// registered [Extras] to create a new `*R`.
48+
// registered [Extras] to create a new `R`.
4749
//
4850
// The payloads can be accessed via the [ExtraPayloadGetter.FromChainConfig] and
4951
// [ExtraPayloadGetter.FromRules] methods of the getter returned by
@@ -54,8 +56,8 @@ func RegisterExtras[C ChainConfigHooks, R RulesHooks](e Extras[C, R]) ExtraPaylo
5456
if registeredExtras != nil {
5557
panic("re-registration of Extras")
5658
}
57-
mustBeStruct[C]()
58-
mustBeStruct[R]()
59+
mustBeStructOrPointerToOne[C]()
60+
mustBeStructOrPointerToOne[R]()
5961

6062
getter := e.getter()
6163
registeredExtras = &extraConstructors{
@@ -108,27 +110,34 @@ type extraConstructors struct {
108110

109111
func (e *Extras[C, R]) newForRules(c *ChainConfig, r *Rules, blockNum *big.Int, isMerge bool, timestamp uint64) *pseudo.Type {
110112
if e.NewRules == nil {
111-
return registeredExtras.rules.NilPointer()
113+
return registeredExtras.rules.Zero()
112114
}
113115
rExtra := e.NewRules(c, r, e.getter().FromChainConfig(c), blockNum, isMerge, timestamp)
114116
return pseudo.From(rExtra).Type
115117
}
116118

117119
func (*Extras[C, R]) getter() (g ExtraPayloadGetter[C, R]) { return }
118120

119-
// mustBeStruct panics if `T` isn't a struct.
120-
func mustBeStruct[T any]() {
121+
// mustBeStructOrPointerToOne panics if `T` isn't a struct or a *struct.
122+
func mustBeStructOrPointerToOne[T any]() {
121123
var x T
122-
if k := reflect.TypeOf(x).Kind(); k != reflect.Struct {
123-
panic(notStructMessage[T]())
124+
switch t := reflect.TypeOf(x); t.Kind() {
125+
case reflect.Struct:
126+
return
127+
case reflect.Pointer:
128+
if t.Elem().Kind() == reflect.Struct {
129+
return
130+
}
124131
}
132+
panic(notStructMessage[T]())
125133
}
126134

127-
// notStructMessage returns the message with which [mustBeStruct] might panic.
128-
// It exists to avoid change-detector tests should the message contents change.
135+
// notStructMessage returns the message with which [mustBeStructOrPointerToOne]
136+
// might panic. It exists to avoid change-detector tests should the message
137+
// contents change.
129138
func notStructMessage[T any]() string {
130139
var x T
131-
return fmt.Sprintf("%T is not a struct", x)
140+
return fmt.Sprintf("%T is not a struct nor a pointer to a struct", x)
132141
}
133142

134143
// An ExtraPayloadGettter provides strongly typed access to the extra payloads
@@ -139,33 +148,25 @@ type ExtraPayloadGetter[C ChainConfigHooks, R RulesHooks] struct {
139148
}
140149

141150
// FromChainConfig returns the ChainConfig's extra payload.
142-
func (ExtraPayloadGetter[C, R]) FromChainConfig(c *ChainConfig) *C {
143-
return pseudo.MustNewValue[*C](c.extraPayload()).Get()
151+
func (ExtraPayloadGetter[C, R]) FromChainConfig(c *ChainConfig) C {
152+
return pseudo.MustNewValue[C](c.extraPayload()).Get()
144153
}
145154

146155
// hooksFromChainConfig is equivalent to FromChainConfig(), but returns an
147156
// interface instead of the concrete type implementing it; this allows it to be
148-
// used in non-generic code. If the concrete-type value is nil (typically
149-
// because no [Extras] were registered) a [noopHooks] is returned so it can be
150-
// used without nil checks.
157+
// used in non-generic code.
151158
func (e ExtraPayloadGetter[C, R]) hooksFromChainConfig(c *ChainConfig) ChainConfigHooks {
152-
if h := e.FromChainConfig(c); h != nil {
153-
return *h
154-
}
155-
return NOOPHooks{}
159+
return e.FromChainConfig(c)
156160
}
157161

158162
// FromRules returns the Rules' extra payload.
159-
func (ExtraPayloadGetter[C, R]) FromRules(r *Rules) *R {
160-
return pseudo.MustNewValue[*R](r.extraPayload()).Get()
163+
func (ExtraPayloadGetter[C, R]) FromRules(r *Rules) R {
164+
return pseudo.MustNewValue[R](r.extraPayload()).Get()
161165
}
162166

163167
// hooksFromRules is the [RulesHooks] equivalent of hooksFromChainConfig().
164168
func (e ExtraPayloadGetter[C, R]) hooksFromRules(r *Rules) RulesHooks {
165-
if h := e.FromRules(r); h != nil {
166-
return *h
167-
}
168-
return NOOPHooks{}
169+
return e.FromRules(r)
169170
}
170171

171172
// addRulesExtra is called at the end of [ChainConfig.Rules]; it exists to
@@ -189,7 +190,7 @@ func (c *ChainConfig) extraPayload() *pseudo.Type {
189190
panic(fmt.Sprintf("%T.ExtraPayload() called before RegisterExtras()", c))
190191
}
191192
if c.extra == nil {
192-
c.extra = registeredExtras.chainConfig.NilPointer()
193+
c.extra = registeredExtras.chainConfig.Zero()
193194
}
194195
return c.extra
195196
}
@@ -201,7 +202,7 @@ func (r *Rules) extraPayload() *pseudo.Type {
201202
panic(fmt.Sprintf("%T.ExtraPayload() called before RegisterExtras()", r))
202203
}
203204
if r.extra == nil {
204-
r.extra = registeredExtras.rules.NilPointer()
205+
r.extra = registeredExtras.rules.Zero()
205206
}
206207
return r.extra
207208
}

params/config.libevm_test.go

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,26 +50,36 @@ func TestRegisterExtras(t *testing.T) {
5050
name: "Rules payload copied from ChainConfig payload",
5151
register: func() {
5252
RegisterExtras(Extras[ccExtraA, rulesExtraA]{
53-
NewRules: func(cc *ChainConfig, r *Rules, ex *ccExtraA, _ *big.Int, _ bool, _ uint64) *rulesExtraA {
54-
return &rulesExtraA{
53+
NewRules: func(cc *ChainConfig, r *Rules, ex ccExtraA, _ *big.Int, _ bool, _ uint64) rulesExtraA {
54+
return rulesExtraA{
5555
A: ex.A,
5656
}
5757
},
5858
})
5959
},
60-
ccExtra: pseudo.From(&ccExtraA{
60+
ccExtra: pseudo.From(ccExtraA{
6161
A: "hello",
6262
}).Type,
63-
wantRulesExtra: &rulesExtraA{
63+
wantRulesExtra: rulesExtraA{
6464
A: "hello",
6565
},
6666
},
6767
{
68-
name: "no NewForRules() function results in typed but nil pointer",
68+
name: "no NewForRules() function results in zero value",
6969
register: func() {
7070
RegisterExtras(Extras[ccExtraB, rulesExtraB]{})
7171
},
72-
ccExtra: pseudo.From(&ccExtraB{
72+
ccExtra: pseudo.From(ccExtraB{
73+
B: "world",
74+
}).Type,
75+
wantRulesExtra: rulesExtraB{},
76+
},
77+
{
78+
name: "no NewForRules() function results in nil pointer",
79+
register: func() {
80+
RegisterExtras(Extras[ccExtraB, *rulesExtraB]{})
81+
},
82+
ccExtra: pseudo.From(ccExtraB{
7383
B: "world",
7484
}).Type,
7585
wantRulesExtra: (*rulesExtraB)(nil),
@@ -79,10 +89,10 @@ func TestRegisterExtras(t *testing.T) {
7989
register: func() {
8090
RegisterExtras(Extras[rawJSON, struct{ RulesHooks }]{})
8191
},
82-
ccExtra: pseudo.From(&rawJSON{
92+
ccExtra: pseudo.From(rawJSON{
8393
RawMessage: []byte(`"hello, world"`),
8494
}).Type,
85-
wantRulesExtra: (*struct{ RulesHooks })(nil),
95+
wantRulesExtra: struct{ RulesHooks }{},
8696
},
8797
}
8898

@@ -131,7 +141,7 @@ func TestExtrasPanic(t *testing.T) {
131141

132142
assertPanics(
133143
t, func() {
134-
mustBeStruct[int]()
144+
mustBeStructOrPointerToOne[int]()
135145
},
136146
notStructMessage[int](),
137147
)

params/example.libevm_test.go

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ var getter params.ExtraPayloadGetter[ChainConfigExtra, RulesExtra]
4040
// constructRulesExtra acts as an adjunct to the [params.ChainConfig.Rules]
4141
// method. Its primary purpose is to construct the extra payload for the
4242
// [params.Rules] but it MAY also modify the [params.Rules].
43-
func constructRulesExtra(c *params.ChainConfig, r *params.Rules, cEx *ChainConfigExtra, blockNum *big.Int, isMerge bool, timestamp uint64) *RulesExtra {
44-
return &RulesExtra{
43+
func constructRulesExtra(c *params.ChainConfig, r *params.Rules, cEx ChainConfigExtra, blockNum *big.Int, isMerge bool, timestamp uint64) RulesExtra {
44+
return RulesExtra{
4545
IsMyFork: cEx.MyForkTime != nil && *cEx.MyForkTime <= timestamp,
4646
timestamp: timestamp,
4747
}
@@ -66,12 +66,12 @@ type RulesExtra struct {
6666
}
6767

6868
// FromChainConfig returns the extra payload carried by the ChainConfig.
69-
func FromChainConfig(c *params.ChainConfig) *ChainConfigExtra {
69+
func FromChainConfig(c *params.ChainConfig) ChainConfigExtra {
7070
return getter.FromChainConfig(c)
7171
}
7272

7373
// FromRules returns the extra payload carried by the Rules.
74-
func FromRules(r *params.Rules) *RulesExtra {
74+
func FromRules(r *params.Rules) RulesExtra {
7575
return getter.FromRules(r)
7676
}
7777

@@ -137,16 +137,14 @@ func ExampleExtraPayloadGetter() {
137137
fmt.Println("Chain ID", config.ChainID) // original geth fields work as expected
138138

139139
ccExtra := FromChainConfig(config) // extraparams.FromChainConfig() in practice
140-
if ccExtra != nil && ccExtra.MyForkTime != nil {
140+
if ccExtra.MyForkTime != nil {
141141
fmt.Println("Fork time", *ccExtra.MyForkTime)
142142
}
143143

144144
for _, time := range []uint64{forkTime - 1, forkTime, forkTime + 1} {
145145
rules := config.Rules(nil, false, time)
146146
rExtra := FromRules(&rules) // extraparams.FromRules() in practice
147-
if rExtra != nil {
148-
fmt.Printf("IsMyFork at %v: %t\n", rExtra.timestamp, rExtra.IsMyFork)
149-
}
147+
fmt.Printf("IsMyFork at %v: %t\n", rExtra.timestamp, rExtra.IsMyFork)
150148
}
151149

152150
// Output:

params/json.libevm.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ func (c *ChainConfig) UnmarshalJSON(data []byte) error {
3030
return c.unmarshalJSONWithExtra(data)
3131

3232
case reg != nil && reg.reuseJSONRoot: // although the latter is redundant, it's clearer
33-
c.extra = reg.chainConfig.NilPointer()
33+
c.extra = reg.chainConfig.Zero()
3434
if err := json.Unmarshal(data, c.extra); err != nil {
3535
c.extra = nil
3636
return err
@@ -47,7 +47,7 @@ func (c *ChainConfig) UnmarshalJSON(data []byte) error {
4747
func (c *ChainConfig) unmarshalJSONWithExtra(data []byte) error {
4848
cc := &chainConfigWithExportedExtra{
4949
chainConfigWithoutMethods: (*chainConfigWithoutMethods)(c),
50-
Extra: registeredExtras.chainConfig.NilPointer(),
50+
Extra: registeredExtras.chainConfig.Zero(),
5151
}
5252
if err := json.Unmarshal(data, cc); err != nil {
5353
return err

params/json.libevm_test.go

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ func TestChainConfigJSONRoundTrip(t *testing.T) {
4040
},
4141
},
4242
{
43-
name: "reuse top-level JSON",
43+
name: "reuse top-level JSON with non-pointer",
4444
register: func() {
4545
RegisterExtras(Extras[rootJSONChainConfigExtra, NOOPHooks]{
4646
ReuseJSONRoot: true,
@@ -50,13 +50,29 @@ func TestChainConfigJSONRoundTrip(t *testing.T) {
5050
"chainId": 5678,
5151
"foo": "hello"
5252
}`,
53+
want: &ChainConfig{
54+
ChainID: big.NewInt(5678),
55+
extra: pseudo.From(rootJSONChainConfigExtra{TopLevelFoo: "hello"}).Type,
56+
},
57+
},
58+
{
59+
name: "reuse top-level JSON with pointer",
60+
register: func() {
61+
RegisterExtras(Extras[*rootJSONChainConfigExtra, NOOPHooks]{
62+
ReuseJSONRoot: true,
63+
})
64+
},
65+
jsonInput: `{
66+
"chainId": 5678,
67+
"foo": "hello"
68+
}`,
5369
want: &ChainConfig{
5470
ChainID: big.NewInt(5678),
5571
extra: pseudo.From(&rootJSONChainConfigExtra{TopLevelFoo: "hello"}).Type,
5672
},
5773
},
5874
{
59-
name: "nested JSON",
75+
name: "nested JSON with non-pointer",
6076
register: func() {
6177
RegisterExtras(Extras[nestedChainConfigExtra, NOOPHooks]{
6278
ReuseJSONRoot: false, // explicit zero value only for tests
@@ -66,6 +82,22 @@ func TestChainConfigJSONRoundTrip(t *testing.T) {
6682
"chainId": 42,
6783
"extra": {"foo": "world"}
6884
}`,
85+
want: &ChainConfig{
86+
ChainID: big.NewInt(42),
87+
extra: pseudo.From(nestedChainConfigExtra{NestedFoo: "world"}).Type,
88+
},
89+
},
90+
{
91+
name: "nested JSON with pointer",
92+
register: func() {
93+
RegisterExtras(Extras[*nestedChainConfigExtra, NOOPHooks]{
94+
ReuseJSONRoot: false, // explicit zero value only for tests
95+
})
96+
},
97+
jsonInput: `{
98+
"chainId": 42,
99+
"extra": {"foo": "world"}
100+
}`,
69101
want: &ChainConfig{
70102
ChainID: big.NewInt(42),
71103
extra: pseudo.From(&nestedChainConfigExtra{NestedFoo: "world"}).Type,

0 commit comments

Comments
 (0)