@@ -25,25 +25,27 @@ type Extras[C ChainConfigHooks, R RulesHooks] struct {
25
25
// NewRules, if non-nil is called at the end of [ChainConfig.Rules] with the
26
26
// newly created [Rules] and other context from the method call. Its
27
27
// returned value will be the extra payload of the [Rules]. If NewRules is
28
- // nil then so too will the [Rules] extra payload be a nil `* R`.
28
+ // nil then so too will the [Rules] extra payload be a zero-value ` R`.
29
29
//
30
30
// NewRules MAY modify the [Rules] but MUST NOT modify the [ChainConfig].
31
- NewRules func (_ * ChainConfig , _ * Rules , _ * C , blockNum * big.Int , isMerge bool , timestamp uint64 ) * R
31
+ // TODO(arr4n): add the [Rules] to the return signature to make it clearer
32
+ // that the caller can modify the generated Rules.
33
+ NewRules func (_ * ChainConfig , _ * Rules , _ C , blockNum * big.Int , isMerge bool , timestamp uint64 ) R
32
34
}
33
35
34
36
// RegisterExtras registers the types `C` and `R` such that they are carried as
35
37
// extra payloads in [ChainConfig] and [Rules] structs, respectively. It is
36
38
// expected to be called in an `init()` function and MUST NOT be called more
37
- // than once. Both `C` and `R` MUST be structs.
39
+ // than once. Both `C` and `R` MUST be structs or pointers to structs .
38
40
//
39
41
// After registration, JSON unmarshalling of a [ChainConfig] will create a new
40
- // `* C` and unmarshal the JSON key "extra" into it. Conversely, JSON marshalling
41
- // will populate the "extra" key with the contents of the `* C`. Both the
42
+ // `C` and unmarshal the JSON key "extra" into it. Conversely, JSON marshalling
43
+ // will populate the "extra" key with the contents of the `C`. Both the
42
44
// [json.Marshaler] and [json.Unmarshaler] interfaces are honoured if
43
45
// implemented by `C` and/or `R.`
44
46
//
45
47
// Calls to [ChainConfig.Rules] will call the `NewRules` function of the
46
- // registered [Extras] to create a new `* R`.
48
+ // registered [Extras] to create a new `R`.
47
49
//
48
50
// The payloads can be accessed via the [ExtraPayloadGetter.FromChainConfig] and
49
51
// [ExtraPayloadGetter.FromRules] methods of the getter returned by
@@ -54,16 +56,16 @@ func RegisterExtras[C ChainConfigHooks, R RulesHooks](e Extras[C, R]) ExtraPaylo
54
56
if registeredExtras != nil {
55
57
panic ("re-registration of Extras" )
56
58
}
57
- mustBeStruct [C ]()
58
- mustBeStruct [R ]()
59
+ mustBeStructOrPointerToOne [C ]()
60
+ mustBeStructOrPointerToOne [R ]()
59
61
60
62
getter := e .getter ()
61
63
registeredExtras = & extraConstructors {
62
- chainConfig : pseudo .NewConstructor [C ](),
63
- rules : pseudo .NewConstructor [R ](),
64
- reuseJSONRoot : e .ReuseJSONRoot ,
65
- newForRules : e .newForRules ,
66
- getter : getter ,
64
+ newChainConfig : pseudo .NewConstructor [C ](). Zero ,
65
+ newRules : pseudo .NewConstructor [R ](). Zero ,
66
+ reuseJSONRoot : e .ReuseJSONRoot ,
67
+ newForRules : e .newForRules ,
68
+ getter : getter ,
67
69
}
68
70
return getter
69
71
}
@@ -95,9 +97,9 @@ func TestOnlyClearRegisteredExtras() {
95
97
var registeredExtras * extraConstructors
96
98
97
99
type extraConstructors struct {
98
- chainConfig , rules pseudo.Constructor
99
- reuseJSONRoot bool
100
- newForRules func (_ * ChainConfig , _ * Rules , blockNum * big.Int , isMerge bool , timestamp uint64 ) * pseudo.Type
100
+ newChainConfig , newRules func () * pseudo.Type
101
+ reuseJSONRoot bool
102
+ newForRules func (_ * ChainConfig , _ * Rules , blockNum * big.Int , isMerge bool , timestamp uint64 ) * pseudo.Type
101
103
// use top-level hooksFrom<X>() functions instead of these as they handle
102
104
// instances where no [Extras] were registered.
103
105
getter interface {
@@ -108,27 +110,34 @@ type extraConstructors struct {
108
110
109
111
func (e * Extras [C , R ]) newForRules (c * ChainConfig , r * Rules , blockNum * big.Int , isMerge bool , timestamp uint64 ) * pseudo.Type {
110
112
if e .NewRules == nil {
111
- return registeredExtras .rules . NilPointer ()
113
+ return registeredExtras .newRules ()
112
114
}
113
115
rExtra := e .NewRules (c , r , e .getter ().FromChainConfig (c ), blockNum , isMerge , timestamp )
114
116
return pseudo .From (rExtra ).Type
115
117
}
116
118
117
119
func (* Extras [C , R ]) getter () (g ExtraPayloadGetter [C , R ]) { return }
118
120
119
- // mustBeStruct panics if `T` isn't a struct.
120
- func mustBeStruct [T any ]() {
121
+ // mustBeStructOrPointerToOne panics if `T` isn't a struct or a * struct.
122
+ func mustBeStructOrPointerToOne [T any ]() {
121
123
var x T
122
- if k := reflect .TypeOf (x ).Kind (); k != reflect .Struct {
123
- panic (notStructMessage [T ]())
124
+ switch t := reflect .TypeOf (x ); t .Kind () {
125
+ case reflect .Struct :
126
+ return
127
+ case reflect .Pointer :
128
+ if t .Elem ().Kind () == reflect .Struct {
129
+ return
130
+ }
124
131
}
132
+ panic (notStructMessage [T ]())
125
133
}
126
134
127
- // notStructMessage returns the message with which [mustBeStruct] might panic.
128
- // It exists to avoid change-detector tests should the message contents change.
135
+ // notStructMessage returns the message with which [mustBeStructOrPointerToOne]
136
+ // might panic. It exists to avoid change-detector tests should the message
137
+ // contents change.
129
138
func notStructMessage [T any ]() string {
130
139
var x T
131
- return fmt .Sprintf ("%T is not a struct" , x )
140
+ return fmt .Sprintf ("%T is not a struct nor a pointer to a struct " , x )
132
141
}
133
142
134
143
// An ExtraPayloadGettter provides strongly typed access to the extra payloads
@@ -139,33 +148,37 @@ type ExtraPayloadGetter[C ChainConfigHooks, R RulesHooks] struct {
139
148
}
140
149
141
150
// FromChainConfig returns the ChainConfig's extra payload.
142
- func (ExtraPayloadGetter [C , R ]) FromChainConfig (c * ChainConfig ) * C {
143
- return pseudo.MustNewValue [* C ](c .extraPayload ()).Get ()
151
+ func (ExtraPayloadGetter [C , R ]) FromChainConfig (c * ChainConfig ) C {
152
+ return pseudo.MustNewValue [C ](c .extraPayload ()).Get ()
153
+ }
154
+
155
+ // PointerFromChainConfig returns a pointer to the ChainConfig's extra payload.
156
+ // This is guaranteed to be non-nil.
157
+ func (ExtraPayloadGetter [C , R ]) PointerFromChainConfig (c * ChainConfig ) * C {
158
+ return pseudo.MustPointerTo [C ](c .extraPayload ()).Value .Get ()
144
159
}
145
160
146
161
// hooksFromChainConfig is equivalent to FromChainConfig(), but returns an
147
162
// interface instead of the concrete type implementing it; this allows it to be
148
- // used in non-generic code. If the concrete-type value is nil (typically
149
- // because no [Extras] were registered) a [noopHooks] is returned so it can be
150
- // used without nil checks.
163
+ // used in non-generic code.
151
164
func (e ExtraPayloadGetter [C , R ]) hooksFromChainConfig (c * ChainConfig ) ChainConfigHooks {
152
- if h := e .FromChainConfig (c ); h != nil {
153
- return * h
154
- }
155
- return NOOPHooks {}
165
+ return e .FromChainConfig (c )
156
166
}
157
167
158
168
// FromRules returns the Rules' extra payload.
159
- func (ExtraPayloadGetter [C , R ]) FromRules (r * Rules ) * R {
160
- return pseudo.MustNewValue [* R ](r .extraPayload ()).Get ()
169
+ func (ExtraPayloadGetter [C , R ]) FromRules (r * Rules ) R {
170
+ return pseudo.MustNewValue [R ](r .extraPayload ()).Get ()
171
+ }
172
+
173
+ // PointerFromRules returns a pointer to the Rules's extra payload. This is
174
+ // guaranteed to be non-nil.
175
+ func (ExtraPayloadGetter [C , R ]) PointerFromRules (r * Rules ) * R {
176
+ return pseudo.MustPointerTo [R ](r .extraPayload ()).Value .Get ()
161
177
}
162
178
163
179
// hooksFromRules is the [RulesHooks] equivalent of hooksFromChainConfig().
164
180
func (e ExtraPayloadGetter [C , R ]) hooksFromRules (r * Rules ) RulesHooks {
165
- if h := e .FromRules (r ); h != nil {
166
- return * h
167
- }
168
- return NOOPHooks {}
181
+ return e .FromRules (r )
169
182
}
170
183
171
184
// addRulesExtra is called at the end of [ChainConfig.Rules]; it exists to
@@ -189,7 +202,7 @@ func (c *ChainConfig) extraPayload() *pseudo.Type {
189
202
panic (fmt .Sprintf ("%T.ExtraPayload() called before RegisterExtras()" , c ))
190
203
}
191
204
if c .extra == nil {
192
- c .extra = registeredExtras .chainConfig . NilPointer ()
205
+ c .extra = registeredExtras .newChainConfig ()
193
206
}
194
207
return c .extra
195
208
}
@@ -201,7 +214,7 @@ func (r *Rules) extraPayload() *pseudo.Type {
201
214
panic (fmt .Sprintf ("%T.ExtraPayload() called before RegisterExtras()" , r ))
202
215
}
203
216
if r .extra == nil {
204
- r .extra = registeredExtras .rules . NilPointer ()
217
+ r .extra = registeredExtras .newRules ()
205
218
}
206
219
return r .extra
207
220
}
0 commit comments