Skip to content

Commit fd54bf0

Browse files
committed
merge target
2 parents 908238f + be32fff commit fd54bf0

20 files changed

+566
-160
lines changed

ast/src/analyzed/mod.rs

+1-9
Original file line numberDiff line numberDiff line change
@@ -91,15 +91,7 @@ impl<T> Analyzed<T> {
9191

9292
/// Returns the set of all referenced challenges in this [`Analyzed<T>`].
9393
pub fn challenges(&self) -> BTreeSet<&Challenge> {
94-
self.identities
95-
.iter()
96-
.flat_map(|identity| identity.all_children())
97-
.chain(
98-
// Note: we iterate on a `HashMap` so the ordering is not guaranteed, but this is ok since we're building another map.
99-
self.intermediate_columns
100-
.values()
101-
.flat_map(|(_, def)| def.iter().flat_map(|d| d.all_children())),
102-
)
94+
self.all_children()
10395
.filter_map(|expr| match expr {
10496
AlgebraicExpression::Challenge(challenge) => Some(challenge),
10597
_ => None,

backend/src/plonky3/stark.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -431,13 +431,13 @@ mod tests {
431431
}
432432

433433
#[test]
434-
fn add() {
434+
fn mul() {
435435
let content = r#"
436436
namespace Add(8);
437437
col witness x;
438438
col witness y;
439439
col witness z;
440-
x + y = z;
440+
x * y = z;
441441
"#;
442442
run_test(content);
443443
}

backend/src/stwo/circuit_builder.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -312,8 +312,8 @@ impl FrameworkEval for PowdrEval {
312312
}
313313
}
314314

315-
// This function creates a list of the names of the constant polynomials that have next references
316-
// Note that the anaylsis should also dereference next references to intermediate polynomials
315+
/// This function creates a list of the names of the constant polynomials that have next references
316+
/// Note that the anaylsis should also dereference next references to intermediate polynomials
317317
pub fn get_constant_with_next_list(analyzed: &Analyzed<M31>) -> HashSet<String> {
318318
let intermediate_definitions = analyzed.intermediate_definitions();
319319
let cache = &mut BTreeMap::new();

executor/src/witgen/jit/compiler.rs

+170-9
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use crate::witgen::{
1414
finalizable_data::{ColumnLayout, CompactDataRef},
1515
mutable_state::MutableState,
1616
},
17-
jit::prover_function_heuristics::ProverFunctionComputation,
17+
jit::prover_function_heuristics::{ProverFunctionComputation, QueryType},
1818
machines::{
1919
profiling::{record_end, record_start},
2020
LookupCell,
@@ -60,6 +60,8 @@ impl<T: FieldElement> WitgenFunction<T> {
6060
call_machine: call_machine::<T, Q>,
6161
fixed_data: fixed_data as *const _ as *const c_void,
6262
get_fixed_value: get_fixed_value::<T>,
63+
input_from_channel: input_from_channel::<T, Q>,
64+
output_to_channel: output_to_channel::<T, Q>,
6365
});
6466
}
6567
}
@@ -88,6 +90,30 @@ extern "C" fn call_machine<T: FieldElement, Q: QueryCallback<T>>(
8890
.unwrap()
8991
}
9092

93+
extern "C" fn input_from_channel<T: FieldElement, Q: QueryCallback<T>>(
94+
mutable_state: *const c_void,
95+
channel: u32,
96+
index: u64,
97+
) -> T {
98+
let mutable_state = unsafe { &*(mutable_state as *const MutableState<T, Q>) };
99+
// TODO what is the proper error handling?
100+
// TODO What to do for Ok(None)?
101+
(mutable_state.query_callback())(&format!("Input({channel},{index})"))
102+
.unwrap()
103+
.unwrap()
104+
}
105+
106+
extern "C" fn output_to_channel<T: FieldElement, Q: QueryCallback<T>>(
107+
mutable_state: *const c_void,
108+
fd: u32,
109+
elem: T,
110+
) {
111+
let mutable_state = unsafe { &*(mutable_state as *const MutableState<T, Q>) };
112+
(mutable_state.query_callback())(&format!("Output({fd},{elem})"))
113+
.unwrap()
114+
.unwrap();
115+
}
116+
91117
/// Compile the given inferred effects into machine code and load it.
92118
pub fn compile_effects<T: FieldElement, D: DefinitionFetcher>(
93119
definitions: &D,
@@ -154,6 +180,12 @@ struct WitgenFunctionParams<'a, T: 'a> {
154180
/// A callback to retrieve values from fixed columns.
155181
/// The parameters are: fixed data pointer, fixed column id, row number.
156182
get_fixed_value: extern "C" fn(*const c_void, u64, u64) -> T,
183+
/// A callback to retrieve a prover-provided value from a channel
184+
/// The parameters are: mutable state pointer, channel number, index.
185+
input_from_channel: extern "C" fn(*const c_void, u32, u64) -> T,
186+
/// A callback to output a value to a channel.
187+
/// The parameters are: mutable state pointer, channel number, value.
188+
output_to_channel: extern "C" fn(*const c_void, u32, T),
157189
}
158190

159191
#[repr(C)]
@@ -290,6 +322,8 @@ extern "C" fn witgen(
290322
call_machine,
291323
fixed_data,
292324
get_fixed_value,
325+
input_from_channel,
326+
output_to_channel,
293327
}}: WitgenFunctionParams<FieldElement>,
294328
) {{
295329
let known = known_to_slice(known, data.len);
@@ -385,7 +419,7 @@ fn format_effect<T: FieldElement>(effect: &Effect<T, Variable>, is_top_level: bo
385419
inputs,
386420
}) => {
387421
format!(
388-
"{}[{}] = prover_function_{function_index}(row_offset + {row_offset}, &[{}]);",
422+
"{}[{}] = prover_function_{function_index}(mutable_state, input_from_channel, output_to_channel, row_offset + {row_offset}, &[{}]);",
389423
if is_top_level { "let " } else { "" },
390424
targets.iter().map(variable_to_string).format(", "),
391425
inputs.iter().map(variable_to_string).format(", ")
@@ -489,13 +523,13 @@ fn variable_to_string(v: &Variable) -> String {
489523
"f_{}_{}_{}",
490524
escape_column_name(&cell.column_name),
491525
cell.id,
492-
cell.row_offset
526+
format_row_offset(cell.row_offset)
493527
),
494528
Variable::IntermediateCell(cell) => format!(
495529
"i_{}_{}_{}",
496530
escape_column_name(&cell.column_name),
497531
cell.id,
498-
cell.row_offset
532+
format_row_offset(cell.row_offset)
499533
),
500534
Variable::MachineCallParam(call_var) => {
501535
format!(
@@ -544,7 +578,7 @@ fn prover_function_code<T: FieldElement, D: DefinitionFetcher>(
544578
f: &ProverFunction<'_, T>,
545579
codegen: &mut CodeGenerator<'_, T, D>,
546580
) -> Result<String, String> {
547-
let code = match f.computation {
581+
let code = match &f.computation {
548582
ProverFunctionComputation::ComputeFrom(code) => format!(
549583
"({}).call(args.to_vec().into())",
550584
codegen.generate_code_for_expression(code)?
@@ -553,6 +587,28 @@ fn prover_function_code<T: FieldElement, D: DefinitionFetcher>(
553587
assert!(!f.compute_multi);
554588
format!("({}).call()", codegen.generate_code_for_expression(code)?)
555589
}
590+
ProverFunctionComputation::HandleQueryInputOutput(branches) => {
591+
let indent = " ";
592+
// We assign zero in the "no match" case. The correct behaviour would be to
593+
// not assign anything, but it should work for all our use-cases.
594+
format!(
595+
"match IntType::from(args[0]) {{\n{}\n{indent}_ => 0.into(),\n }}",
596+
branches
597+
.iter()
598+
.map(|(value, query_type)| {
599+
let result = match query_type {
600+
QueryType::Input => {
601+
"input_from_channel(mutable_state, IntType::from(args[0]) as u32, IntType::from(args[1]) as u64),"
602+
}
603+
QueryType::Output => {
604+
"{ output_to_channel(mutable_state, IntType::from(args[0]) as u32, args[1]); 0.into() },"
605+
}
606+
};
607+
format!("{indent}{value} => {result}")
608+
})
609+
.format("\n")
610+
)
611+
}
556612
};
557613
let code = if f.compute_multi {
558614
format!("({code}).as_slice().try_into().unwrap()")
@@ -563,10 +619,16 @@ fn prover_function_code<T: FieldElement, D: DefinitionFetcher>(
563619
let length = f.target.len();
564620
let index = f.index;
565621
Ok(format!(
566-
"fn prover_function_{index}(i: u64, args: &[FieldElement]) -> [FieldElement; {length}] {{\n\
567-
let i: ibig::IBig = i.into();\n\
568-
{code}
569-
}}"
622+
r#"fn prover_function_{index}(
623+
mutable_state: *const std::ffi::c_void,
624+
input_from_channel: extern "C" fn(*const std::ffi::c_void, u32, u64) -> FieldElement,
625+
output_to_channel: extern "C" fn(*const std::ffi::c_void, u32, FieldElement),
626+
i: u64,
627+
args: &[FieldElement]
628+
) -> [FieldElement; {length}] {{
629+
let i: ibig::IBig = i.into();
630+
{code}
631+
}}"#
570632
))
571633
}
572634

@@ -575,12 +637,14 @@ mod tests {
575637

576638
use std::ptr::null;
577639

640+
use powdr_ast::analyzed::AlgebraicReference;
578641
use powdr_ast::analyzed::FunctionValueDefinition;
579642
use pretty_assertions::assert_eq;
580643
use test_log::test;
581644

582645
use powdr_number::GoldilocksField;
583646

647+
use crate::witgen::jit::prover_function_heuristics::QueryType;
584648
use crate::witgen::jit::variable::Cell;
585649
use crate::witgen::jit::variable::MachineCallVariable;
586650
use crate::witgen::range_constraints::RangeConstraint;
@@ -698,6 +762,8 @@ extern \"C\" fn witgen(
698762
call_machine,
699763
fixed_data,
700764
get_fixed_value,
765+
input_from_channel,
766+
output_to_channel,
701767
}: WitgenFunctionParams<FieldElement>,
702768
) {
703769
let known = known_to_slice(known, data.len);
@@ -745,6 +811,12 @@ extern \"C\" fn witgen(
745811
GoldilocksField::from(col_id * 2000 + row)
746812
}
747813

814+
extern "C" fn input_from_channel_test(_: *const c_void, _: u32, _: u64) -> GoldilocksField {
815+
GoldilocksField::from(117)
816+
}
817+
818+
extern "C" fn output_to_channel_test(_: *const c_void, _: u32, _: GoldilocksField) {}
819+
748820
fn witgen_fun_params<'a>(
749821
data: &mut [GoldilocksField],
750822
known: &mut [u32],
@@ -758,6 +830,8 @@ extern \"C\" fn witgen(
758830
call_machine: no_call_machine,
759831
fixed_data: null(),
760832
get_fixed_value: get_fixed_data_test,
833+
input_from_channel: input_from_channel_test,
834+
output_to_channel: output_to_channel_test,
761835
}
762836
}
763837

@@ -775,6 +849,8 @@ extern \"C\" fn witgen(
775849
call_machine: no_call_machine,
776850
fixed_data: null(),
777851
get_fixed_value: get_fixed_data_test,
852+
input_from_channel: input_from_channel_test,
853+
output_to_channel: output_to_channel_test,
778854
}
779855
}
780856

@@ -828,6 +904,8 @@ extern \"C\" fn witgen(
828904
call_machine: no_call_machine,
829905
fixed_data: null(),
830906
get_fixed_value: get_fixed_data_test,
907+
input_from_channel: input_from_channel_test,
908+
output_to_channel: output_to_channel_test,
831909
};
832910
(f2.function)(params2);
833911
assert_eq!(data[0], GoldilocksField::from(7));
@@ -1005,6 +1083,8 @@ extern \"C\" fn witgen(
10051083
call_machine: mock_call_machine,
10061084
fixed_data: null(),
10071085
get_fixed_value: get_fixed_data_test,
1086+
input_from_channel: input_from_channel_test,
1087+
output_to_channel: output_to_channel_test,
10081088
};
10091089
(f.function)(params);
10101090
assert_eq!(data[0], GoldilocksField::from(9));
@@ -1072,4 +1152,85 @@ extern \"C\" fn witgen(
10721152
}";
10731153
assert_eq!(format_effects(&[branch_effect]), expectation);
10741154
}
1155+
1156+
#[test]
1157+
fn handle_query_prover_function() {
1158+
fn to_algebraic_ref(name: &str, id: u64) -> AlgebraicReference {
1159+
AlgebraicReference {
1160+
name: name.to_string(),
1161+
poly_id: PolyID {
1162+
id,
1163+
ptype: PolynomialType::Committed,
1164+
},
1165+
next: false,
1166+
}
1167+
}
1168+
1169+
let x = cell("x", 0, 0);
1170+
let y = cell("y", 1, 0);
1171+
let z = cell("z", 2, 0);
1172+
let effects = vec![Effect::ProverFunctionCall(ProverFunctionCall {
1173+
targets: vec![x.clone()],
1174+
function_index: 0,
1175+
row_offset: 0,
1176+
inputs: vec![y.clone(), z.clone()],
1177+
})];
1178+
let known_inputs = vec![y.clone(), z.clone()];
1179+
let prover_function = ProverFunction {
1180+
index: 0,
1181+
target: vec![to_algebraic_ref("x", 0)],
1182+
compute_multi: false,
1183+
computation: ProverFunctionComputation::HandleQueryInputOutput(
1184+
[(7, QueryType::Input), (8, QueryType::Output)]
1185+
.into_iter()
1186+
.collect(),
1187+
),
1188+
condition: None,
1189+
input_columns: vec![to_algebraic_ref("y", 1), to_algebraic_ref("z", 2)],
1190+
};
1191+
let f = super::compile_effects(
1192+
&NoDefinitions,
1193+
ColumnLayout {
1194+
column_count: 3,
1195+
first_column_id: 0,
1196+
},
1197+
&known_inputs,
1198+
&effects,
1199+
vec![prover_function],
1200+
)
1201+
.unwrap();
1202+
1203+
let mut data = vec![
1204+
GoldilocksField::from(0),
1205+
GoldilocksField::from(7),
1206+
GoldilocksField::from(2),
1207+
];
1208+
let mut known = vec![0; 1];
1209+
(f.function)(witgen_fun_params(&mut data, &mut known));
1210+
assert_eq!(data[0], GoldilocksField::from(117));
1211+
assert_eq!(data[1], GoldilocksField::from(7));
1212+
assert_eq!(data[2], GoldilocksField::from(2));
1213+
1214+
let mut data = vec![
1215+
GoldilocksField::from(0),
1216+
GoldilocksField::from(8),
1217+
GoldilocksField::from(2),
1218+
];
1219+
let mut known = vec![0; 1];
1220+
(f.function)(witgen_fun_params(&mut data, &mut known));
1221+
assert_eq!(data[0], GoldilocksField::from(0));
1222+
assert_eq!(data[1], GoldilocksField::from(8));
1223+
assert_eq!(data[2], GoldilocksField::from(2));
1224+
1225+
let mut data = vec![
1226+
GoldilocksField::from(0),
1227+
GoldilocksField::from(9),
1228+
GoldilocksField::from(2),
1229+
];
1230+
let mut known = vec![0; 1];
1231+
(f.function)(witgen_fun_params(&mut data, &mut known));
1232+
assert_eq!(data[0], GoldilocksField::from(0));
1233+
assert_eq!(data[1], GoldilocksField::from(9));
1234+
assert_eq!(data[2], GoldilocksField::from(2));
1235+
}
10751236
}

executor/src/witgen/jit/includes/interface.rs

+2
Original file line numberDiff line numberDiff line change
@@ -104,4 +104,6 @@ pub struct WitgenFunctionParams<'a, T: 'a> {
104104
call_machine: extern "C" fn(*const std::ffi::c_void, u64, MutSlice<LookupCell<'_, T>>) -> bool,
105105
fixed_data: *const std::ffi::c_void,
106106
get_fixed_value: extern "C" fn(*const std::ffi::c_void, u64, u64) -> T,
107+
input_from_channel: extern "C" fn(*const std::ffi::c_void, u32, u64) -> T,
108+
output_to_channel: extern "C" fn(*const std::ffi::c_void, u32, T),
107109
}

0 commit comments

Comments
 (0)