1
1
use super :: SimpleSignature ;
2
- use crate :: types:: checkpoint:: EpochId ;
2
+ use crate :: types:: { checkpoint:: EpochId , u256 :: U256 } ;
3
3
4
4
/// An zk login authenticator with all the necessary fields.
5
5
#[ derive( Debug , Clone , PartialEq , Eq ) ]
@@ -19,7 +19,7 @@ pub struct ZkLoginInputs {
19
19
proof_points : ZkLoginProof ,
20
20
iss_base64_details : Claim ,
21
21
header_base64 : String ,
22
- address_seed : String ,
22
+ address_seed : AddressSeed ,
23
23
// #[serde(skip)]
24
24
// jwt_details: JwtDetails,
25
25
}
@@ -73,8 +73,7 @@ pub type CircomG2 = Vec<Vec<String>>;
73
73
#[ derive( Clone , Debug , PartialEq , Eq ) ]
74
74
pub struct ZkLoginPublicIdentifier {
75
75
iss : String ,
76
- //TODO bigint support
77
- address_seed : [ u8 ; 32 ] ,
76
+ address_seed : AddressSeed ,
78
77
}
79
78
80
79
/// Struct that contains info for a JWK. A list of them for different kids can
@@ -109,6 +108,107 @@ pub struct JwkId {
109
108
pub kid : String ,
110
109
}
111
110
111
+ #[ derive( Clone , Debug , PartialEq , Eq ) ]
112
+ pub struct AddressSeed ( [ u8 ; 32 ] ) ;
113
+
114
+ impl AddressSeed {
115
+ pub fn unpadded ( & self ) -> & [ u8 ] {
116
+ let mut buf = self . 0 . as_slice ( ) ;
117
+
118
+ while !buf. is_empty ( ) && buf[ 0 ] == 0 {
119
+ buf = & buf[ 1 ..] ;
120
+ }
121
+
122
+ // If the value is '0' then just return a slice of length 1 of the final byte
123
+ if buf. is_empty ( ) {
124
+ & self . 0 [ 31 ..]
125
+ } else {
126
+ buf
127
+ }
128
+ }
129
+
130
+ pub fn padded ( & self ) -> & [ u8 ] {
131
+ & self . 0
132
+ }
133
+ }
134
+
135
+ impl std:: fmt:: Display for AddressSeed {
136
+ fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
137
+ let u256 = U256 :: from_be ( U256 :: from_digits ( self . 0 ) ) ;
138
+ let radix10 = u256. to_str_radix ( 10 ) ;
139
+ f. write_str ( & radix10)
140
+ }
141
+ }
142
+
143
+ #[ derive( Debug ) ]
144
+ pub struct AddressSeedParseError ( bnum:: errors:: ParseIntError ) ;
145
+
146
+ impl std:: fmt:: Display for AddressSeedParseError {
147
+ fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
148
+ write ! ( f, "unable to parse radix10 encoded value {}" , self . 0 )
149
+ }
150
+ }
151
+
152
+ impl std:: error:: Error for AddressSeedParseError { }
153
+
154
+ impl std:: str:: FromStr for AddressSeed {
155
+ type Err = AddressSeedParseError ;
156
+
157
+ fn from_str ( s : & str ) -> Result < Self , Self :: Err > {
158
+ let u256 = U256 :: from_str_radix ( s, 10 ) . map_err ( AddressSeedParseError ) ?;
159
+ let be = u256. to_be ( ) ;
160
+ Ok ( Self ( * be. digits ( ) ) )
161
+ }
162
+ }
163
+
164
+ #[ cfg( test) ]
165
+ mod test {
166
+ use super :: AddressSeed ;
167
+ use num_bigint:: BigUint ;
168
+ use proptest:: prelude:: * ;
169
+ use std:: str:: FromStr ;
170
+
171
+ #[ cfg( target_arch = "wasm32" ) ]
172
+ use wasm_bindgen_test:: wasm_bindgen_test as test;
173
+
174
+ #[ test]
175
+ fn unpadded_slice ( ) {
176
+ let seed = AddressSeed ( [ 0 ; 32 ] ) ;
177
+ let zero: [ u8 ; 1 ] = [ 0 ] ;
178
+ assert_eq ! ( seed. unpadded( ) , zero. as_slice( ) ) ;
179
+
180
+ let mut seed = AddressSeed ( [ 1 ; 32 ] ) ;
181
+ seed. 0 [ 0 ] = 0 ;
182
+ assert_eq ! ( seed. unpadded( ) , [ 1 ; 31 ] . as_slice( ) ) ;
183
+ }
184
+
185
+ proptest ! {
186
+ #[ test]
187
+ fn dont_crash_on_large_inputs(
188
+ bytes in proptest:: collection:: vec( any:: <u8 >( ) , 33 ..1024 )
189
+ ) {
190
+ let big_int = BigUint :: from_bytes_be( & bytes) ;
191
+ let radix10 = big_int. to_str_radix( 10 ) ;
192
+
193
+ // doesn't crash
194
+ let _ = AddressSeed :: from_str( & radix10) ;
195
+ }
196
+
197
+ #[ test]
198
+ fn valid_address_seeds(
199
+ bytes in proptest:: collection:: vec( any:: <u8 >( ) , 1 ..=32 )
200
+ ) {
201
+ let big_int = BigUint :: from_bytes_be( & bytes) ;
202
+ let radix10 = big_int. to_str_radix( 10 ) ;
203
+
204
+ let seed = AddressSeed :: from_str( & radix10) . unwrap( ) ;
205
+ assert_eq!( radix10, seed. to_string( ) ) ;
206
+ // Ensure unpadded doesn't crash
207
+ seed. unpadded( ) ;
208
+ }
209
+ }
210
+ }
211
+
112
212
#[ cfg( feature = "serde" ) ]
113
213
#[ cfg_attr( doc_cfg, doc( cfg( feature = "serde" ) ) ) ]
114
214
mod serialization {
@@ -121,6 +221,7 @@ mod serialization {
121
221
use serde:: Serializer ;
122
222
use serde_with:: Bytes ;
123
223
use serde_with:: DeserializeAs ;
224
+ use serde_with:: SerializeAs ;
124
225
use std:: borrow:: Cow ;
125
226
126
227
// Serialized format is: iss_bytes_len || iss_bytes || padded_32_byte_address_seed.
@@ -133,16 +234,11 @@ mod serialization {
133
234
#[ derive( serde_derive:: Serialize ) ]
134
235
struct Readable < ' a > {
135
236
iss : & ' a str ,
136
- //TODO this needs to be encoded as a Decimal u256 instead of in base64
137
- #[ cfg_attr(
138
- feature = "serde" ,
139
- serde( with = "::serde_with::As::<crate::types::crypto::Base64Array32>" )
140
- ) ]
141
- address_seed : [ u8 ; 32 ] ,
237
+ address_seed : & ' a AddressSeed ,
142
238
}
143
239
let readable = Readable {
144
240
iss : & self . iss ,
145
- address_seed : self . address_seed ,
241
+ address_seed : & self . address_seed ,
146
242
} ;
147
243
readable. serialize ( serializer)
148
244
} else {
@@ -151,7 +247,7 @@ mod serialization {
151
247
buf. push ( iss_bytes. len ( ) as u8 ) ;
152
248
buf. extend ( iss_bytes) ;
153
249
154
- buf. extend ( & self . address_seed ) ;
250
+ buf. extend ( & self . address_seed . 0 ) ;
155
251
156
252
serializer. serialize_bytes ( & buf)
157
253
}
@@ -167,12 +263,7 @@ mod serialization {
167
263
#[ derive( serde_derive:: Deserialize ) ]
168
264
struct Readable {
169
265
iss : String ,
170
- //TODO this needs to be encoded as a Decimal u256 instead of in base64
171
- #[ cfg_attr(
172
- feature = "serde" ,
173
- serde( with = "::serde_with::As::<crate::types::crypto::Base64Array32>" )
174
- ) ]
175
- address_seed : [ u8 ; 32 ] ,
266
+ address_seed : AddressSeed ,
176
267
}
177
268
178
269
let Readable { iss, address_seed } = Deserialize :: deserialize ( deserializer) ?;
@@ -188,8 +279,9 @@ mod serialization {
188
279
. get ( ( 1 + iss_len as usize ) ..)
189
280
. ok_or_else ( || serde:: de:: Error :: custom ( "invalid zklogin public identifier" ) ) ?;
190
281
191
- let address_seed =
192
- <[ u8 ; 32 ] >:: try_from ( address_seed_bytes) . map_err ( serde:: de:: Error :: custom) ?;
282
+ let address_seed = <[ u8 ; 32 ] >:: try_from ( address_seed_bytes)
283
+ . map_err ( serde:: de:: Error :: custom)
284
+ . map ( AddressSeed ) ?;
193
285
194
286
Ok ( Self {
195
287
iss : iss. into ( ) ,
@@ -282,4 +374,23 @@ mod serialization {
282
374
} )
283
375
}
284
376
}
377
+
378
+ // AddressSeed's serialized format is as a radix10 encoded string
379
+ impl Serialize for AddressSeed {
380
+ fn serialize < S > ( & self , serializer : S ) -> Result < S :: Ok , S :: Error >
381
+ where
382
+ S : serde:: Serializer ,
383
+ {
384
+ serde_with:: DisplayFromStr :: serialize_as ( self , serializer)
385
+ }
386
+ }
387
+
388
+ impl < ' de > Deserialize < ' de > for AddressSeed {
389
+ fn deserialize < D > ( deserializer : D ) -> Result < Self , D :: Error >
390
+ where
391
+ D : Deserializer < ' de > ,
392
+ {
393
+ serde_with:: DisplayFromStr :: deserialize_as ( deserializer)
394
+ }
395
+ }
285
396
}
0 commit comments