Skip to content

Commit 9c724d2

Browse files
committed
feat: Add calls to custom functions
Compile the function body with a scope that contains exactly the function parameters as variables (plus the previously defined aliases and functions). Witness expressions are not allowed in functions because calling a function multiple times would result in two witness nodes with the same name in the compiled Simplicity program. Witness expressions can be used in the main function only. An example for how function calls are implemented: fn f(p1: t1, p2: t2, ...) { body } fn main() { let out: Out = f(v1, v2, ...); } This code is inlined by the Simfony compiler to become: fn main() { let out: Out = { let (p1, p2, ...): (t1, t2, ...) = (v1, v2, ...); body } }
1 parent 39649dd commit 9c724d2

File tree

6 files changed

+210
-11
lines changed

6 files changed

+210
-11
lines changed

src/ast.rs

Lines changed: 160 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,49 @@ pub enum CallName {
233233
UnwrapRight(ResolvedType),
234234
/// [`Option::unwrap`].
235235
Unwrap,
236+
/// A custom function that was defined previously.
237+
///
238+
/// We effectively copy the function body into every call of the function.
239+
/// We use [`Arc`] for cheap clones during this process.
240+
Custom(CustomFunction),
241+
}
242+
243+
/// Definition of a custom function.
244+
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
245+
pub struct CustomFunction {
246+
params: Arc<[FunctionParam]>,
247+
body: Arc<Expression>,
248+
}
249+
250+
impl CustomFunction {
251+
/// Access the identifiers of the parameters of the function.
252+
pub fn params(&self) -> &[FunctionParam] {
253+
&self.params
254+
}
255+
256+
/// Access the body of the function.
257+
pub fn body(&self) -> &Expression {
258+
&self.body
259+
}
260+
}
261+
262+
/// Parameter of a function.
263+
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
264+
pub struct FunctionParam {
265+
identifier: Identifier,
266+
ty: ResolvedType,
267+
}
268+
269+
impl FunctionParam {
270+
/// Access the identifier of the parameter.
271+
pub fn identifier(&self) -> &Identifier {
272+
&self.identifier
273+
}
274+
275+
/// Access the type of the parameter.
276+
pub fn ty(&self) -> &ResolvedType {
277+
&self.ty
278+
}
236279
}
237280

238281
/// Match expression.
@@ -286,11 +329,14 @@ impl MatchArm {
286329
/// 1. Assigning types to each variable
287330
/// 2. Resolving type aliases
288331
/// 3. Assigning types to each witness expression
332+
/// 4. Resolving calls to custom functions
289333
#[derive(Clone, Debug, Eq, PartialEq, Default)]
290334
struct Scope {
291335
variables: Vec<HashMap<Identifier, ResolvedType>>,
292336
aliases: HashMap<Identifier, ResolvedType>,
293337
witnesses: HashMap<WitnessName, ResolvedType>,
338+
functions: HashMap<FunctionName, CustomFunction>,
339+
is_main: bool,
294340
}
295341

296342
impl Scope {
@@ -304,6 +350,19 @@ impl Scope {
304350
self.variables.push(HashMap::new());
305351
}
306352

353+
/// Push the scope of the main function onto the stack.
354+
///
355+
/// ## Panics
356+
///
357+
/// - The current scope is already inside the main function.
358+
/// - The current scope is not topmost.
359+
pub fn push_main_scope(&mut self) {
360+
assert!(!self.is_main, "Already inside main function");
361+
assert!(self.is_topmost(), "Current scope is not topmost");
362+
self.push_scope();
363+
self.is_main = true;
364+
}
365+
307366
/// Pop the current scope from the stack.
308367
///
309368
/// ## Panics
@@ -313,6 +372,22 @@ impl Scope {
313372
self.variables.pop().expect("Stack is empty");
314373
}
315374

375+
/// Pop the scope of the main function from the stack.
376+
///
377+
/// ## Panics
378+
///
379+
/// - The current scope is not inside the main function.
380+
/// - The current scope is not nested in the topmost scope.
381+
pub fn pop_main_scope(&mut self) {
382+
assert!(self.is_main, "Current scope is not inside main function");
383+
self.pop_scope();
384+
self.is_main = false;
385+
assert!(
386+
self.is_topmost(),
387+
"Current scope is not nested in topmost scope"
388+
)
389+
}
390+
316391
/// Push a variable onto the current stack.
317392
///
318393
/// ## Panics
@@ -359,9 +434,13 @@ impl Scope {
359434
///
360435
/// ## Errors
361436
///
362-
/// The witness name has already been defined somewhere else in the program.
363-
/// Witness names may be used at most throughout the entire program.
437+
/// - The current scope is not inside the main function.
438+
/// - The witness name has already been defined somewhere else in the program.
364439
pub fn insert_witness(&mut self, name: WitnessName, ty: ResolvedType) -> Result<(), Error> {
440+
if !self.is_main {
441+
return Err(Error::WitnessOutsideMain);
442+
}
443+
365444
match self.witnesses.entry(name.clone()) {
366445
Entry::Occupied(_) => Err(Error::WitnessReused(name)),
367446
Entry::Vacant(entry) => {
@@ -377,6 +456,30 @@ impl Scope {
377456
pub fn into_witnesses(self) -> HashMap<WitnessName, ResolvedType> {
378457
self.witnesses
379458
}
459+
460+
/// Insert a custom function into the global map.
461+
///
462+
/// ## Errors
463+
///
464+
/// The function has already been defined.
465+
pub fn insert_function(
466+
&mut self,
467+
name: FunctionName,
468+
function: CustomFunction,
469+
) -> Result<(), Error> {
470+
match self.functions.entry(name.clone()) {
471+
Entry::Occupied(_) => Err(Error::FunctionRedefined(name)),
472+
Entry::Vacant(entry) => {
473+
entry.insert(function);
474+
Ok(())
475+
}
476+
}
477+
}
478+
479+
/// Get the definition of a custom function.
480+
pub fn get_function(&self, name: &FunctionName) -> Option<&CustomFunction> {
481+
self.functions.get(name)
482+
}
380483
}
381484

382485
/// Part of the abstract syntax tree that can be generated from a precursor in the parse tree.
@@ -452,9 +555,35 @@ impl AbstractSyntaxTree for Function {
452555
assert!(ty.is_unit(), "Function definitions cannot return anything");
453556
assert!(scope.is_topmost(), "Items live in the topmost scope only");
454557

455-
// TODO: Handle custom functions once we can call them
456-
// Skip custom functions because we cannot call them with the current grammar
457558
if from.name().as_inner() != "main" {
559+
let params = from
560+
.params()
561+
.iter()
562+
.map(|param| {
563+
let identifier = param.identifier().clone();
564+
let ty = scope.resolve(param.ty())?;
565+
Ok(FunctionParam { identifier, ty })
566+
})
567+
.collect::<Result<Arc<[FunctionParam]>, Error>>()
568+
.with_span(from)?;
569+
let ret = from
570+
.ret()
571+
.as_ref()
572+
.map(|aliased| scope.resolve(aliased).with_span(from))
573+
.transpose()?
574+
.unwrap_or_else(ResolvedType::unit);
575+
scope.push_scope();
576+
for param in params.iter() {
577+
scope.insert_variable(param.identifier().clone(), param.ty().clone());
578+
}
579+
let body = Expression::analyze(from.body(), &ret, scope).map(Arc::new)?;
580+
scope.pop_scope();
581+
debug_assert!(scope.is_topmost());
582+
let function = CustomFunction { params, body };
583+
scope
584+
.insert_function(from.name().clone(), function)
585+
.with_span(from)?;
586+
458587
return Ok(Self::Custom);
459588
}
460589

@@ -468,10 +597,9 @@ impl AbstractSyntaxTree for Function {
468597
}
469598
}
470599

471-
scope.push_scope();
600+
scope.push_main_scope();
472601
let body = Expression::analyze(from.body(), ty, scope)?;
473-
scope.pop_scope();
474-
debug_assert!(scope.is_topmost());
602+
scope.pop_main_scope();
475603
Ok(Self::Main(body))
476604
}
477605
}
@@ -771,6 +899,25 @@ impl AbstractSyntaxTree for Call {
771899
scope,
772900
)?])
773901
}
902+
CallName::Custom(function) => {
903+
if from.args.len() != function.params().len() {
904+
return Err(Error::InvalidNumberOfArguments(
905+
function.params().len(),
906+
from.args.len(),
907+
))
908+
.with_span(from);
909+
}
910+
let out_ty = function.body().ty();
911+
if ty != out_ty {
912+
return Err(Error::ExpressionTypeMismatch(ty.clone(), out_ty.clone()))
913+
.with_span(from);
914+
}
915+
from.args
916+
.iter()
917+
.zip(function.params.iter().map(FunctionParam::ty))
918+
.map(|(arg_parse, arg_ty)| Expression::analyze(arg_parse, arg_ty, scope))
919+
.collect::<Result<Arc<[Expression]>, RichError>>()?
920+
}
774921
};
775922

776923
Ok(Self {
@@ -804,6 +951,12 @@ impl AbstractSyntaxTree for CallName {
804951
.map(Self::UnwrapRight)
805952
.with_span(from),
806953
parse::CallName::Unwrap => Ok(Self::Unwrap),
954+
parse::CallName::Custom(name) => scope
955+
.get_function(name)
956+
.cloned()
957+
.map(Self::Custom)
958+
.ok_or(Error::FunctionUndefined(name.clone()))
959+
.with_span(from),
807960
}
808961
}
809962
}

src/compile.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use simplicity::{Cmr, FailEntropy};
88

99
use crate::array::{BTreeSlice, Partition};
1010
use crate::ast::{
11-
Call, CallName, Expression, ExpressionInner, Match, Program, SingleExpression,
11+
Call, CallName, Expression, ExpressionInner, FunctionParam, Match, Program, SingleExpression,
1212
SingleExpressionInner, Statement,
1313
};
1414
use crate::error::{Error, RichError, WithSpan};
@@ -301,6 +301,19 @@ impl Call {
301301
let get_inner = ProgNode::assertr_take(fail_cmr, &ProgNode::iden());
302302
ProgNode::comp(&right_and_unit, &get_inner).with_span(self)
303303
}
304+
CallName::Custom(function) => {
305+
let params_pattern = Pattern::tuple(
306+
function
307+
.params()
308+
.iter()
309+
.map(FunctionParam::identifier)
310+
.cloned()
311+
.map(Pattern::Identifier),
312+
);
313+
let mut function_scope = Scope::new(params_pattern);
314+
let body = function.body().compile(&mut function_scope)?;
315+
ProgNode::comp(&args, &body).with_span(self)
316+
}
304317
}
305318
}
306319
}

src/error.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ pub enum Error {
158158
MainNoOutput,
159159
MainRequired,
160160
FunctionRedefined(FunctionName),
161+
FunctionUndefined(FunctionName),
161162
InvalidNumberOfArguments(usize, usize),
162163
ExpressionTypeMismatch(ResolvedType, ResolvedType),
163164
ExpressionNotConstant,
@@ -168,6 +169,7 @@ pub enum Error {
168169
WitnessReused(WitnessName),
169170
WitnessTypeMismatch(WitnessName, ResolvedType, ResolvedType),
170171
WitnessReassigned(WitnessName),
172+
WitnessOutsideMain,
171173
}
172174

173175
#[rustfmt::skip]
@@ -226,6 +228,10 @@ impl fmt::Display for Error {
226228
f,
227229
"Function `{name}` was defined multiple times"
228230
),
231+
Error::FunctionUndefined(name) => write!(
232+
f,
233+
"Function `{name}` was called but not defined"
234+
),
229235
Error::InvalidNumberOfArguments(expected, found) => write!(
230236
f,
231237
"Expected {expected} arguments, found {found} arguments"
@@ -265,7 +271,11 @@ impl fmt::Display for Error {
265271
Error::WitnessReassigned(name) => write!(
266272
f,
267273
"Witness `{name}` has already been assigned a value"
268-
)
274+
),
275+
Error::WitnessOutsideMain => write!(
276+
f,
277+
"Witness expressions are not allowed outside the `main` function"
278+
),
269279
}
270280
}
271281
}

src/minimal.pest

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ witness_name = @{ (ASCII_ALPHANUMERIC | "_")+ }
1313
builtin_type = @{ "Either" | "Option" | "bool" | unsigned_type | "List" }
1414
reserved = @{ jet | builtin_type }
1515

16-
function_name = { identifier }
16+
builtin_function = @{ jet | "unwrap_left" | "unwrap_right" | "unwrap" }
17+
function_name = { !builtin_function ~ identifier }
1718
typed_identifier = { identifier ~ ":" ~ ty }
1819
function_params = { "(" ~ (typed_identifier ~ ("," ~ typed_identifier)*)? ~ ")" }
1920
function_return = { "->" ~ ty }
@@ -57,7 +58,7 @@ true_expr = @{ "true" }
5758
unwrap_left = ${ "unwrap_left::<" ~ ty ~ ">" }
5859
unwrap_right = ${ "unwrap_right::<" ~ ty ~ ">" }
5960
unwrap = @{ "unwrap" }
60-
call_name = ${ jet | unwrap_left | unwrap_right | unwrap }
61+
call_name = ${ jet | unwrap_left | unwrap_right | unwrap | function_name }
6162
call_args = { "(" ~ (expression ~ ("," ~ expression)*)? ~ ")" }
6263
call_expr = { call_name ~ call_args }
6364
unsigned_decimal = @{ (ASCII_DIGIT | "_")+ }

src/parse.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,8 @@ pub enum CallName {
278278
UnwrapRight(AliasedType),
279279
/// Some unwrap function.
280280
Unwrap,
281+
/// Name of a custom function.
282+
Custom(FunctionName),
281283
}
282284

283285
/// Name of a jet.
@@ -811,6 +813,7 @@ impl PestParse for CallName {
811813
AliasedType::parse(inner).map(Self::UnwrapRight)
812814
}
813815
Rule::unwrap => Ok(Self::Unwrap),
816+
Rule::function_name => FunctionName::parse(pair).map(Self::Custom),
814817
_ => panic!("Corrupt grammar"),
815818
}
816819
}

src/witness.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,4 +267,23 @@ mod tests {
267267
),
268268
}
269269
}
270+
271+
#[test]
272+
fn witness_outside_main() {
273+
let s = r#"fn f() -> u32 {
274+
witness("output_of_f")
275+
}
276+
277+
fn main() {
278+
jet_verify(jet_is_zero_32(f()));
279+
}"#;
280+
281+
match crate::compile(s) {
282+
Ok(_) => panic!("Witness outside main was falsely accepted"),
283+
Err(error) => {
284+
assert!(error
285+
.contains("Witness expressions are not allowed outside the `main` function"))
286+
}
287+
}
288+
}
270289
}

0 commit comments

Comments
 (0)