@@ -233,6 +233,49 @@ pub enum CallName {
233
233
UnwrapRight ( ResolvedType ) ,
234
234
/// [`Option::unwrap`].
235
235
Unwrap ,
236
+ /// A custom function that was defined previously.
237
+ ///
238
+ /// We effectively copy the function body into every call of the function.
239
+ /// We use [`Arc`] for cheap clones during this process.
240
+ Custom ( CustomFunction ) ,
241
+ }
242
+
243
+ /// Definition of a custom function.
244
+ #[ derive( Clone , Debug , Eq , PartialEq , Hash ) ]
245
+ pub struct CustomFunction {
246
+ params : Arc < [ FunctionParam ] > ,
247
+ body : Arc < Expression > ,
248
+ }
249
+
250
+ impl CustomFunction {
251
+ /// Access the identifiers of the parameters of the function.
252
+ pub fn params ( & self ) -> & [ FunctionParam ] {
253
+ & self . params
254
+ }
255
+
256
+ /// Access the body of the function.
257
+ pub fn body ( & self ) -> & Expression {
258
+ & self . body
259
+ }
260
+ }
261
+
262
+ /// Parameter of a function.
263
+ #[ derive( Clone , Debug , Eq , PartialEq , Hash ) ]
264
+ pub struct FunctionParam {
265
+ identifier : Identifier ,
266
+ ty : ResolvedType ,
267
+ }
268
+
269
+ impl FunctionParam {
270
+ /// Access the identifier of the parameter.
271
+ pub fn identifier ( & self ) -> & Identifier {
272
+ & self . identifier
273
+ }
274
+
275
+ /// Access the type of the parameter.
276
+ pub fn ty ( & self ) -> & ResolvedType {
277
+ & self . ty
278
+ }
236
279
}
237
280
238
281
/// Match expression.
@@ -286,11 +329,14 @@ impl MatchArm {
286
329
/// 1. Assigning types to each variable
287
330
/// 2. Resolving type aliases
288
331
/// 3. Assigning types to each witness expression
332
+ /// 4. Resolving calls to custom functions
289
333
#[ derive( Clone , Debug , Eq , PartialEq , Default ) ]
290
334
struct Scope {
291
335
variables : Vec < HashMap < Identifier , ResolvedType > > ,
292
336
aliases : HashMap < Identifier , ResolvedType > ,
293
337
witnesses : HashMap < WitnessName , ResolvedType > ,
338
+ functions : HashMap < FunctionName , CustomFunction > ,
339
+ is_main : bool ,
294
340
}
295
341
296
342
impl Scope {
@@ -304,6 +350,19 @@ impl Scope {
304
350
self . variables . push ( HashMap :: new ( ) ) ;
305
351
}
306
352
353
+ /// Push the scope of the main function onto the stack.
354
+ ///
355
+ /// ## Panics
356
+ ///
357
+ /// - The current scope is already inside the main function.
358
+ /// - The current scope is not topmost.
359
+ pub fn push_main_scope ( & mut self ) {
360
+ assert ! ( !self . is_main, "Already inside main function" ) ;
361
+ assert ! ( self . is_topmost( ) , "Current scope is not topmost" ) ;
362
+ self . push_scope ( ) ;
363
+ self . is_main = true ;
364
+ }
365
+
307
366
/// Pop the current scope from the stack.
308
367
///
309
368
/// ## Panics
@@ -313,6 +372,22 @@ impl Scope {
313
372
self . variables . pop ( ) . expect ( "Stack is empty" ) ;
314
373
}
315
374
375
+ /// Pop the scope of the main function from the stack.
376
+ ///
377
+ /// ## Panics
378
+ ///
379
+ /// - The current scope is not inside the main function.
380
+ /// - The current scope is not nested in the topmost scope.
381
+ pub fn pop_main_scope ( & mut self ) {
382
+ assert ! ( self . is_main, "Current scope is not inside main function" ) ;
383
+ self . pop_scope ( ) ;
384
+ self . is_main = false ;
385
+ assert ! (
386
+ self . is_topmost( ) ,
387
+ "Current scope is not nested in topmost scope"
388
+ )
389
+ }
390
+
316
391
/// Push a variable onto the current stack.
317
392
///
318
393
/// ## Panics
@@ -359,9 +434,13 @@ impl Scope {
359
434
///
360
435
/// ## Errors
361
436
///
362
- /// The witness name has already been defined somewhere else in the program .
363
- /// Witness names may be used at most throughout the entire program.
437
+ /// - The current scope is not inside the main function .
438
+ /// - The witness name has already been defined somewhere else in the program.
364
439
pub fn insert_witness ( & mut self , name : WitnessName , ty : ResolvedType ) -> Result < ( ) , Error > {
440
+ if !self . is_main {
441
+ return Err ( Error :: WitnessOutsideMain ) ;
442
+ }
443
+
365
444
match self . witnesses . entry ( name. clone ( ) ) {
366
445
Entry :: Occupied ( _) => Err ( Error :: WitnessReused ( name) ) ,
367
446
Entry :: Vacant ( entry) => {
@@ -377,6 +456,30 @@ impl Scope {
377
456
pub fn into_witnesses ( self ) -> HashMap < WitnessName , ResolvedType > {
378
457
self . witnesses
379
458
}
459
+
460
+ /// Insert a custom function into the global map.
461
+ ///
462
+ /// ## Errors
463
+ ///
464
+ /// The function has already been defined.
465
+ pub fn insert_function (
466
+ & mut self ,
467
+ name : FunctionName ,
468
+ function : CustomFunction ,
469
+ ) -> Result < ( ) , Error > {
470
+ match self . functions . entry ( name. clone ( ) ) {
471
+ Entry :: Occupied ( _) => Err ( Error :: FunctionRedefined ( name) ) ,
472
+ Entry :: Vacant ( entry) => {
473
+ entry. insert ( function) ;
474
+ Ok ( ( ) )
475
+ }
476
+ }
477
+ }
478
+
479
+ /// Get the definition of a custom function.
480
+ pub fn get_function ( & self , name : & FunctionName ) -> Option < & CustomFunction > {
481
+ self . functions . get ( name)
482
+ }
380
483
}
381
484
382
485
/// Part of the abstract syntax tree that can be generated from a precursor in the parse tree.
@@ -452,9 +555,35 @@ impl AbstractSyntaxTree for Function {
452
555
assert ! ( ty. is_unit( ) , "Function definitions cannot return anything" ) ;
453
556
assert ! ( scope. is_topmost( ) , "Items live in the topmost scope only" ) ;
454
557
455
- // TODO: Handle custom functions once we can call them
456
- // Skip custom functions because we cannot call them with the current grammar
457
558
if from. name ( ) . as_inner ( ) != "main" {
559
+ let params = from
560
+ . params ( )
561
+ . iter ( )
562
+ . map ( |param| {
563
+ let identifier = param. identifier ( ) . clone ( ) ;
564
+ let ty = scope. resolve ( param. ty ( ) ) ?;
565
+ Ok ( FunctionParam { identifier, ty } )
566
+ } )
567
+ . collect :: < Result < Arc < [ FunctionParam ] > , Error > > ( )
568
+ . with_span ( from) ?;
569
+ let ret = from
570
+ . ret ( )
571
+ . as_ref ( )
572
+ . map ( |aliased| scope. resolve ( aliased) . with_span ( from) )
573
+ . transpose ( ) ?
574
+ . unwrap_or_else ( ResolvedType :: unit) ;
575
+ scope. push_scope ( ) ;
576
+ for param in params. iter ( ) {
577
+ scope. insert_variable ( param. identifier ( ) . clone ( ) , param. ty ( ) . clone ( ) ) ;
578
+ }
579
+ let body = Expression :: analyze ( from. body ( ) , & ret, scope) . map ( Arc :: new) ?;
580
+ scope. pop_scope ( ) ;
581
+ debug_assert ! ( scope. is_topmost( ) ) ;
582
+ let function = CustomFunction { params, body } ;
583
+ scope
584
+ . insert_function ( from. name ( ) . clone ( ) , function)
585
+ . with_span ( from) ?;
586
+
458
587
return Ok ( Self :: Custom ) ;
459
588
}
460
589
@@ -468,10 +597,9 @@ impl AbstractSyntaxTree for Function {
468
597
}
469
598
}
470
599
471
- scope. push_scope ( ) ;
600
+ scope. push_main_scope ( ) ;
472
601
let body = Expression :: analyze ( from. body ( ) , ty, scope) ?;
473
- scope. pop_scope ( ) ;
474
- debug_assert ! ( scope. is_topmost( ) ) ;
602
+ scope. pop_main_scope ( ) ;
475
603
Ok ( Self :: Main ( body) )
476
604
}
477
605
}
@@ -771,6 +899,25 @@ impl AbstractSyntaxTree for Call {
771
899
scope,
772
900
) ?] )
773
901
}
902
+ CallName :: Custom ( function) => {
903
+ if from. args . len ( ) != function. params ( ) . len ( ) {
904
+ return Err ( Error :: InvalidNumberOfArguments (
905
+ function. params ( ) . len ( ) ,
906
+ from. args . len ( ) ,
907
+ ) )
908
+ . with_span ( from) ;
909
+ }
910
+ let out_ty = function. body ( ) . ty ( ) ;
911
+ if ty != out_ty {
912
+ return Err ( Error :: ExpressionTypeMismatch ( ty. clone ( ) , out_ty. clone ( ) ) )
913
+ . with_span ( from) ;
914
+ }
915
+ from. args
916
+ . iter ( )
917
+ . zip ( function. params . iter ( ) . map ( FunctionParam :: ty) )
918
+ . map ( |( arg_parse, arg_ty) | Expression :: analyze ( arg_parse, arg_ty, scope) )
919
+ . collect :: < Result < Arc < [ Expression ] > , RichError > > ( ) ?
920
+ }
774
921
} ;
775
922
776
923
Ok ( Self {
@@ -804,6 +951,12 @@ impl AbstractSyntaxTree for CallName {
804
951
. map ( Self :: UnwrapRight )
805
952
. with_span ( from) ,
806
953
parse:: CallName :: Unwrap => Ok ( Self :: Unwrap ) ,
954
+ parse:: CallName :: Custom ( name) => scope
955
+ . get_function ( name)
956
+ . cloned ( )
957
+ . map ( Self :: Custom )
958
+ . ok_or ( Error :: FunctionUndefined ( name. clone ( ) ) )
959
+ . with_span ( from) ,
807
960
}
808
961
}
809
962
}
0 commit comments