@@ -14,7 +14,7 @@ use crate::witgen::{
14
14
finalizable_data:: { ColumnLayout , CompactDataRef } ,
15
15
mutable_state:: MutableState ,
16
16
} ,
17
- jit:: prover_function_heuristics:: ProverFunctionComputation ,
17
+ jit:: prover_function_heuristics:: { ProverFunctionComputation , QueryType } ,
18
18
machines:: {
19
19
profiling:: { record_end, record_start} ,
20
20
LookupCell ,
@@ -60,6 +60,8 @@ impl<T: FieldElement> WitgenFunction<T> {
60
60
call_machine : call_machine :: < T , Q > ,
61
61
fixed_data : fixed_data as * const _ as * const c_void ,
62
62
get_fixed_value : get_fixed_value :: < T > ,
63
+ input_from_channel : input_from_channel :: < T , Q > ,
64
+ output_to_channel : output_to_channel :: < T , Q > ,
63
65
} ) ;
64
66
}
65
67
}
@@ -88,6 +90,30 @@ extern "C" fn call_machine<T: FieldElement, Q: QueryCallback<T>>(
88
90
. unwrap ( )
89
91
}
90
92
93
+ extern "C" fn input_from_channel < T : FieldElement , Q : QueryCallback < T > > (
94
+ mutable_state : * const c_void ,
95
+ channel : u32 ,
96
+ index : u64 ,
97
+ ) -> T {
98
+ let mutable_state = unsafe { & * ( mutable_state as * const MutableState < T , Q > ) } ;
99
+ // TODO what is the proper error handling?
100
+ // TODO What to do for Ok(None)?
101
+ ( mutable_state. query_callback ( ) ) ( & format ! ( "Input({channel},{index})" ) )
102
+ . unwrap ( )
103
+ . unwrap ( )
104
+ }
105
+
106
+ extern "C" fn output_to_channel < T : FieldElement , Q : QueryCallback < T > > (
107
+ mutable_state : * const c_void ,
108
+ fd : u32 ,
109
+ elem : T ,
110
+ ) {
111
+ let mutable_state = unsafe { & * ( mutable_state as * const MutableState < T , Q > ) } ;
112
+ ( mutable_state. query_callback ( ) ) ( & format ! ( "Output({fd},{elem})" ) )
113
+ . unwrap ( )
114
+ . unwrap ( ) ;
115
+ }
116
+
91
117
/// Compile the given inferred effects into machine code and load it.
92
118
pub fn compile_effects < T : FieldElement , D : DefinitionFetcher > (
93
119
definitions : & D ,
@@ -154,6 +180,12 @@ struct WitgenFunctionParams<'a, T: 'a> {
154
180
/// A callback to retrieve values from fixed columns.
155
181
/// The parameters are: fixed data pointer, fixed column id, row number.
156
182
get_fixed_value : extern "C" fn ( * const c_void , u64 , u64 ) -> T ,
183
+ /// A callback to retrieve a prover-provided value from a channel
184
+ /// The parameters are: mutable state pointer, channel number, index.
185
+ input_from_channel : extern "C" fn ( * const c_void , u32 , u64 ) -> T ,
186
+ /// A callback to output a value to a channel.
187
+ /// The parameters are: mutable state pointer, channel number, value.
188
+ output_to_channel : extern "C" fn ( * const c_void , u32 , T ) ,
157
189
}
158
190
159
191
#[ repr( C ) ]
@@ -290,6 +322,8 @@ extern "C" fn witgen(
290
322
call_machine,
291
323
fixed_data,
292
324
get_fixed_value,
325
+ input_from_channel,
326
+ output_to_channel,
293
327
}}: WitgenFunctionParams<FieldElement>,
294
328
) {{
295
329
let known = known_to_slice(known, data.len);
@@ -385,7 +419,7 @@ fn format_effect<T: FieldElement>(effect: &Effect<T, Variable>, is_top_level: bo
385
419
inputs,
386
420
} ) => {
387
421
format ! (
388
- "{}[{}] = prover_function_{function_index}(row_offset + {row_offset}, &[{}]);" ,
422
+ "{}[{}] = prover_function_{function_index}(mutable_state, input_from_channel, output_to_channel, row_offset + {row_offset}, &[{}]);" ,
389
423
if is_top_level { "let " } else { "" } ,
390
424
targets. iter( ) . map( variable_to_string) . format( ", " ) ,
391
425
inputs. iter( ) . map( variable_to_string) . format( ", " )
@@ -489,13 +523,13 @@ fn variable_to_string(v: &Variable) -> String {
489
523
"f_{}_{}_{}" ,
490
524
escape_column_name( & cell. column_name) ,
491
525
cell. id,
492
- cell. row_offset
526
+ format_row_offset ( cell. row_offset)
493
527
) ,
494
528
Variable :: IntermediateCell ( cell) => format ! (
495
529
"i_{}_{}_{}" ,
496
530
escape_column_name( & cell. column_name) ,
497
531
cell. id,
498
- cell. row_offset
532
+ format_row_offset ( cell. row_offset)
499
533
) ,
500
534
Variable :: MachineCallParam ( call_var) => {
501
535
format ! (
@@ -544,7 +578,7 @@ fn prover_function_code<T: FieldElement, D: DefinitionFetcher>(
544
578
f : & ProverFunction < ' _ , T > ,
545
579
codegen : & mut CodeGenerator < ' _ , T , D > ,
546
580
) -> Result < String , String > {
547
- let code = match f. computation {
581
+ let code = match & f. computation {
548
582
ProverFunctionComputation :: ComputeFrom ( code) => format ! (
549
583
"({}).call(args.to_vec().into())" ,
550
584
codegen. generate_code_for_expression( code) ?
@@ -553,6 +587,28 @@ fn prover_function_code<T: FieldElement, D: DefinitionFetcher>(
553
587
assert ! ( !f. compute_multi) ;
554
588
format ! ( "({}).call()" , codegen. generate_code_for_expression( code) ?)
555
589
}
590
+ ProverFunctionComputation :: HandleQueryInputOutput ( branches) => {
591
+ let indent = " " ;
592
+ // We assign zero in the "no match" case. The correct behaviour would be to
593
+ // not assign anything, but it should work for all our use-cases.
594
+ format ! (
595
+ "match IntType::from(args[0]) {{\n {}\n {indent}_ => 0.into(),\n }}" ,
596
+ branches
597
+ . iter( )
598
+ . map( |( value, query_type) | {
599
+ let result = match query_type {
600
+ QueryType :: Input => {
601
+ "input_from_channel(mutable_state, IntType::from(args[0]) as u32, IntType::from(args[1]) as u64),"
602
+ }
603
+ QueryType :: Output => {
604
+ "{ output_to_channel(mutable_state, IntType::from(args[0]) as u32, args[1]); 0.into() },"
605
+ }
606
+ } ;
607
+ format!( "{indent}{value} => {result}" )
608
+ } )
609
+ . format( "\n " )
610
+ )
611
+ }
556
612
} ;
557
613
let code = if f. compute_multi {
558
614
format ! ( "({code}).as_slice().try_into().unwrap()" )
@@ -563,10 +619,16 @@ fn prover_function_code<T: FieldElement, D: DefinitionFetcher>(
563
619
let length = f. target . len ( ) ;
564
620
let index = f. index ;
565
621
Ok ( format ! (
566
- "fn prover_function_{index}(i: u64, args: &[FieldElement]) -> [FieldElement; {length}] {{\n \
567
- let i: ibig::IBig = i.into();\n \
568
- {code}
569
- }}"
622
+ r#"fn prover_function_{index}(
623
+ mutable_state: *const std::ffi::c_void,
624
+ input_from_channel: extern "C" fn(*const std::ffi::c_void, u32, u64) -> FieldElement,
625
+ output_to_channel: extern "C" fn(*const std::ffi::c_void, u32, FieldElement),
626
+ i: u64,
627
+ args: &[FieldElement]
628
+ ) -> [FieldElement; {length}] {{
629
+ let i: ibig::IBig = i.into();
630
+ {code}
631
+ }}"#
570
632
) )
571
633
}
572
634
@@ -575,12 +637,14 @@ mod tests {
575
637
576
638
use std:: ptr:: null;
577
639
640
+ use powdr_ast:: analyzed:: AlgebraicReference ;
578
641
use powdr_ast:: analyzed:: FunctionValueDefinition ;
579
642
use pretty_assertions:: assert_eq;
580
643
use test_log:: test;
581
644
582
645
use powdr_number:: GoldilocksField ;
583
646
647
+ use crate :: witgen:: jit:: prover_function_heuristics:: QueryType ;
584
648
use crate :: witgen:: jit:: variable:: Cell ;
585
649
use crate :: witgen:: jit:: variable:: MachineCallVariable ;
586
650
use crate :: witgen:: range_constraints:: RangeConstraint ;
@@ -698,6 +762,8 @@ extern \"C\" fn witgen(
698
762
call_machine,
699
763
fixed_data,
700
764
get_fixed_value,
765
+ input_from_channel,
766
+ output_to_channel,
701
767
}: WitgenFunctionParams<FieldElement>,
702
768
) {
703
769
let known = known_to_slice(known, data.len);
@@ -745,6 +811,12 @@ extern \"C\" fn witgen(
745
811
GoldilocksField :: from ( col_id * 2000 + row)
746
812
}
747
813
814
+ extern "C" fn input_from_channel_test ( _: * const c_void , _: u32 , _: u64 ) -> GoldilocksField {
815
+ GoldilocksField :: from ( 117 )
816
+ }
817
+
818
+ extern "C" fn output_to_channel_test ( _: * const c_void , _: u32 , _: GoldilocksField ) { }
819
+
748
820
fn witgen_fun_params < ' a > (
749
821
data : & mut [ GoldilocksField ] ,
750
822
known : & mut [ u32 ] ,
@@ -758,6 +830,8 @@ extern \"C\" fn witgen(
758
830
call_machine : no_call_machine,
759
831
fixed_data : null ( ) ,
760
832
get_fixed_value : get_fixed_data_test,
833
+ input_from_channel : input_from_channel_test,
834
+ output_to_channel : output_to_channel_test,
761
835
}
762
836
}
763
837
@@ -775,6 +849,8 @@ extern \"C\" fn witgen(
775
849
call_machine : no_call_machine,
776
850
fixed_data : null ( ) ,
777
851
get_fixed_value : get_fixed_data_test,
852
+ input_from_channel : input_from_channel_test,
853
+ output_to_channel : output_to_channel_test,
778
854
}
779
855
}
780
856
@@ -828,6 +904,8 @@ extern \"C\" fn witgen(
828
904
call_machine : no_call_machine,
829
905
fixed_data : null ( ) ,
830
906
get_fixed_value : get_fixed_data_test,
907
+ input_from_channel : input_from_channel_test,
908
+ output_to_channel : output_to_channel_test,
831
909
} ;
832
910
( f2. function ) ( params2) ;
833
911
assert_eq ! ( data[ 0 ] , GoldilocksField :: from( 7 ) ) ;
@@ -1005,6 +1083,8 @@ extern \"C\" fn witgen(
1005
1083
call_machine : mock_call_machine,
1006
1084
fixed_data : null ( ) ,
1007
1085
get_fixed_value : get_fixed_data_test,
1086
+ input_from_channel : input_from_channel_test,
1087
+ output_to_channel : output_to_channel_test,
1008
1088
} ;
1009
1089
( f. function ) ( params) ;
1010
1090
assert_eq ! ( data[ 0 ] , GoldilocksField :: from( 9 ) ) ;
@@ -1072,4 +1152,85 @@ extern \"C\" fn witgen(
1072
1152
}" ;
1073
1153
assert_eq ! ( format_effects( & [ branch_effect] ) , expectation) ;
1074
1154
}
1155
+
1156
+ #[ test]
1157
+ fn handle_query_prover_function ( ) {
1158
+ fn to_algebraic_ref ( name : & str , id : u64 ) -> AlgebraicReference {
1159
+ AlgebraicReference {
1160
+ name : name. to_string ( ) ,
1161
+ poly_id : PolyID {
1162
+ id,
1163
+ ptype : PolynomialType :: Committed ,
1164
+ } ,
1165
+ next : false ,
1166
+ }
1167
+ }
1168
+
1169
+ let x = cell ( "x" , 0 , 0 ) ;
1170
+ let y = cell ( "y" , 1 , 0 ) ;
1171
+ let z = cell ( "z" , 2 , 0 ) ;
1172
+ let effects = vec ! [ Effect :: ProverFunctionCall ( ProverFunctionCall {
1173
+ targets: vec![ x. clone( ) ] ,
1174
+ function_index: 0 ,
1175
+ row_offset: 0 ,
1176
+ inputs: vec![ y. clone( ) , z. clone( ) ] ,
1177
+ } ) ] ;
1178
+ let known_inputs = vec ! [ y. clone( ) , z. clone( ) ] ;
1179
+ let prover_function = ProverFunction {
1180
+ index : 0 ,
1181
+ target : vec ! [ to_algebraic_ref( "x" , 0 ) ] ,
1182
+ compute_multi : false ,
1183
+ computation : ProverFunctionComputation :: HandleQueryInputOutput (
1184
+ [ ( 7 , QueryType :: Input ) , ( 8 , QueryType :: Output ) ]
1185
+ . into_iter ( )
1186
+ . collect ( ) ,
1187
+ ) ,
1188
+ condition : None ,
1189
+ input_columns : vec ! [ to_algebraic_ref( "y" , 1 ) , to_algebraic_ref( "z" , 2 ) ] ,
1190
+ } ;
1191
+ let f = super :: compile_effects (
1192
+ & NoDefinitions ,
1193
+ ColumnLayout {
1194
+ column_count : 3 ,
1195
+ first_column_id : 0 ,
1196
+ } ,
1197
+ & known_inputs,
1198
+ & effects,
1199
+ vec ! [ prover_function] ,
1200
+ )
1201
+ . unwrap ( ) ;
1202
+
1203
+ let mut data = vec ! [
1204
+ GoldilocksField :: from( 0 ) ,
1205
+ GoldilocksField :: from( 7 ) ,
1206
+ GoldilocksField :: from( 2 ) ,
1207
+ ] ;
1208
+ let mut known = vec ! [ 0 ; 1 ] ;
1209
+ ( f. function ) ( witgen_fun_params ( & mut data, & mut known) ) ;
1210
+ assert_eq ! ( data[ 0 ] , GoldilocksField :: from( 117 ) ) ;
1211
+ assert_eq ! ( data[ 1 ] , GoldilocksField :: from( 7 ) ) ;
1212
+ assert_eq ! ( data[ 2 ] , GoldilocksField :: from( 2 ) ) ;
1213
+
1214
+ let mut data = vec ! [
1215
+ GoldilocksField :: from( 0 ) ,
1216
+ GoldilocksField :: from( 8 ) ,
1217
+ GoldilocksField :: from( 2 ) ,
1218
+ ] ;
1219
+ let mut known = vec ! [ 0 ; 1 ] ;
1220
+ ( f. function ) ( witgen_fun_params ( & mut data, & mut known) ) ;
1221
+ assert_eq ! ( data[ 0 ] , GoldilocksField :: from( 0 ) ) ;
1222
+ assert_eq ! ( data[ 1 ] , GoldilocksField :: from( 8 ) ) ;
1223
+ assert_eq ! ( data[ 2 ] , GoldilocksField :: from( 2 ) ) ;
1224
+
1225
+ let mut data = vec ! [
1226
+ GoldilocksField :: from( 0 ) ,
1227
+ GoldilocksField :: from( 9 ) ,
1228
+ GoldilocksField :: from( 2 ) ,
1229
+ ] ;
1230
+ let mut known = vec ! [ 0 ; 1 ] ;
1231
+ ( f. function ) ( witgen_fun_params ( & mut data, & mut known) ) ;
1232
+ assert_eq ! ( data[ 0 ] , GoldilocksField :: from( 0 ) ) ;
1233
+ assert_eq ! ( data[ 1 ] , GoldilocksField :: from( 9 ) ) ;
1234
+ assert_eq ! ( data[ 2 ] , GoldilocksField :: from( 2 ) ) ;
1235
+ }
1075
1236
}
0 commit comments