Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Access to nonconstant fixed columns. #2368

Merged
merged 9 commits into from
Jan 23, 2025
20 changes: 11 additions & 9 deletions executor/src/witgen/jit/block_machine_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ use std::collections::HashSet;

use bit_vec::BitVec;
use itertools::Itertools;
use powdr_ast::analyzed::AlgebraicReference;
use powdr_ast::analyzed::{PolyID, PolynomialType};
use powdr_number::FieldElement;

use crate::witgen::{jit::processor::Processor, machines::MachineParts, FixedData};

use super::{
effect::Effect,
variable::Variable,
variable::{Cell, Variable},
witgen_inference::{CanProcessCall, FixedEvaluator, WitgenInference},
};

Expand Down Expand Up @@ -123,17 +123,19 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
}

impl<T: FieldElement> FixedEvaluator<T> for &BlockMachineProcessor<'_, T> {
fn evaluate(&self, var: &AlgebraicReference, row_offset: i32) -> Option<T> {
assert!(var.is_fixed());
let values = self.fixed_data.fixed_cols[&var.poly_id].values_max_size();
fn evaluate(&self, fixed_cell: &Cell) -> Option<T> {
let poly_id = PolyID {
id: fixed_cell.id,
ptype: PolynomialType::Constant,
};
let values = self.fixed_data.fixed_cols[&poly_id].values_max_size();

// By assumption of the block machine, all fixed columns are cyclic with a period of <block_size>.
// An exception might be the first and last row.
assert!(row_offset >= -1);
assert!(fixed_cell.row_offset >= -1);
assert!(self.block_size >= 1);
// The current row is guaranteed to be at least 1.
let current_row = (2 * self.block_size as i32 + row_offset) as usize;
let row = current_row + var.next as usize;
// The row is guaranteed to be at least 1.
let row = (2 * self.block_size as i32 + fixed_cell.row_offset) as usize;

assert!(values.len() >= self.block_size * 4);

Expand Down
135 changes: 127 additions & 8 deletions executor/src/witgen/jit/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ use std::{cmp::Ordering, ffi::c_void, mem, sync::Arc};

use itertools::Itertools;
use libloading::Library;
use powdr_ast::indent;
use powdr_ast::{
analyzed::{PolyID, PolynomialType},
indent,
};
use powdr_number::{FieldElement, KnownField};

use crate::witgen::{
Expand All @@ -11,7 +14,7 @@ use crate::witgen::{
profiling::{record_end, record_start},
LookupCell,
},
QueryCallback,
FixedData, QueryCallback,
};

use super::{
Expand All @@ -35,6 +38,7 @@ impl<T: FieldElement> WitgenFunction<T> {
/// This function always succeeds (unless it panics).
pub fn call<Q: QueryCallback<T>>(
&self,
fixed_data: &FixedData<'_, T>,
mutable_state: &MutableState<'_, T, Q>,
params: &mut [LookupCell<T>],
mut data: CompactDataRef<'_, T>,
Expand All @@ -48,10 +52,26 @@ impl<T: FieldElement> WitgenFunction<T> {
params: params.into(),
mutable_state: mutable_state as *const _ as *const c_void,
call_machine: call_machine::<T, Q>,
fixed_data: fixed_data as *const _ as *const c_void,
get_fixed_value: get_fixed_value::<T>,
});
}
}

extern "C" fn get_fixed_value<T: FieldElement>(
fixed_data: *const c_void,
column: u64,
row: u64,
) -> T {
let fixed_data = unsafe { &*(fixed_data as *const FixedData<'_, T>) };
let poly_id = PolyID {
id: column,
ptype: PolynomialType::Constant,
};
// TODO which size?
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it always work to use the largest one?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should be fine, if some outer code still decides when to down-size and handles the corner cases of the last row. We do most of the computation assuming we'll pick the largest size, end then downscale and full up the remaining rows.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But that's not the issue - the question is: Can we pick the largest size and assume it's the same value in the other sizes?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's the assumption we currently go with already ;) We could assert that as well.

fixed_data.fixed_cols[&poly_id].values_max_size()[row as usize]
}

extern "C" fn call_machine<T: FieldElement, Q: QueryCallback<T>>(
mutable_state: *const c_void,
identity_id: u64,
Expand Down Expand Up @@ -105,6 +125,11 @@ struct WitgenFunctionParams<'a, T: 'a> {
mutable_state: *const c_void,
/// A callback to call submachines.
call_machine: extern "C" fn(*const c_void, u64, MutSlice<LookupCell<'_, T>>) -> bool,
/// A pointer to the "fixed data".
fixed_data: *const c_void,
/// A callback to retrieve values from fixed columns.
/// The parameters are: fixed data pointer, fixed column id, row number.
get_fixed_value: extern "C" fn(*const c_void, u64, u64) -> T,
}

#[repr(C)]
Expand Down Expand Up @@ -150,13 +175,35 @@ fn witgen_code<T: FieldElement>(
format!("get(data, row_offset, {}, {})", c.row_offset, c.id)
}
Variable::Param(i) => format!("get_param(params, {i})"),
Variable::FixedColumn(_) => panic!("Fixed columns should not be known inputs."),
Variable::MachineCallParam(_) => {
unreachable!("Machine call variables should not be pre-known.")
}
};
format!(" let {var_name} = {value};")
})
.format("\n");

// Pre-load all the fixed columns so that we can treat them as
// plain variables later.
let load_fixed = effects
.iter()
.flat_map(|e| e.referenced_variables())
.filter_map(|v| match v {
Variable::FixedColumn(c) => Some((v, c)),
_ => None,
})
.unique()
.map(|(var, cell)| {
format!(
" let {} = get_fixed_value(fixed_data, {}, (row_offset + {}));",
variable_to_string(var),
cell.id,
cell.row_offset,
)
})
.format("\n");

let main_code = format_effects(effects);
let vars_known = effects
.iter()
Expand All @@ -173,6 +220,7 @@ fn witgen_code<T: FieldElement>(
cell.row_offset, cell.id,
)),
Variable::Param(i) => Some(format!(" set_param(params, {i}, {value});")),
Variable::FixedColumn(_) => panic!("Fixed columns should not be written to."),
Variable::MachineCallParam(_) => {
// This is just an internal variable.
None
Expand All @@ -186,7 +234,7 @@ fn witgen_code<T: FieldElement>(
.iter()
.filter_map(|var| match var {
Variable::Cell(cell) => Some(cell),
Variable::Param(_) | Variable::MachineCallParam(_) => None,
Variable::Param(_) | Variable::FixedColumn(_) | Variable::MachineCallParam(_) => None,
})
.map(|cell| {
format!(
Expand All @@ -205,19 +253,28 @@ extern "C" fn witgen(
row_offset,
params,
mutable_state,
call_machine
call_machine,
fixed_data,
get_fixed_value,
}}: WitgenFunctionParams<FieldElement>,
) {{
let known = known_to_slice(known, data.len);
let data = data.to_mut_slice();
let params = params.to_mut_slice();

// Pre-load fixed column values into local variables
{load_fixed}

// Load all known inputs into local variables
{load_known_inputs}

// Perform the main computations
{main_code}

// Store the newly derived witness cell values
{store_values}

// Store the "known" flags
{store_known}
}}
"#
Expand Down Expand Up @@ -373,6 +430,14 @@ fn variable_to_string(v: &Variable) -> String {
format_row_offset(cell.row_offset)
),
Variable::Param(i) => format!("p_{i}"),
Variable::FixedColumn(cell) => {
format!(
"f_{}_{}_{}",
escape_column_name(&cell.column_name),
cell.id,
cell.row_offset
)
}
Variable::MachineCallParam(call_var) => {
format!(
"call_var_{}_{}_{}",
Expand Down Expand Up @@ -461,6 +526,8 @@ fn util_code<T: FieldElement>(first_column_id: u64, column_count: usize) -> Resu
#[cfg(test)]
mod tests {

use std::ptr::null;

use pretty_assertions::assert_eq;
use test_log::test;

Expand Down Expand Up @@ -546,8 +613,8 @@ mod tests {
let known_inputs = vec![a0.clone()];
let code = witgen_code(&known_inputs, &effects);
assert_eq!(
code,
"
code,
"
#[no_mangle]
extern \"C\" fn witgen(
WitgenFunctionParams{
Expand All @@ -556,15 +623,22 @@ extern \"C\" fn witgen(
row_offset,
params,
mutable_state,
call_machine
call_machine,
fixed_data,
get_fixed_value,
}: WitgenFunctionParams<FieldElement>,
) {
let known = known_to_slice(known, data.len);
let data = data.to_mut_slice();
let params = params.to_mut_slice();

// Pre-load fixed column values into local variables


// Load all known inputs into local variables
let c_a_2_0 = get(data, row_offset, 0, 2);

// Perform the main computations
let c_x_0_0 = (FieldElement::from(7) * c_a_2_0);
let call_var_7_1_0 = c_x_0_0;
let mut call_var_7_1_1 = FieldElement::default();
Expand All @@ -573,16 +647,18 @@ extern \"C\" fn witgen(
let c_y_1_1 = (c_y_1_m1 + c_x_0_0);
assert!(c_y_1_m1 == c_x_0_0);

// Store the newly derived witness cell values
set(data, row_offset, 0, 0, c_x_0_0);
set(data, row_offset, -1, 1, c_y_1_m1);
set(data, row_offset, 1, 1, c_y_1_1);

// Store the \"known\" flags
set_known(known, row_offset, 0, 0);
set_known(known, row_offset, -1, 1);
set_known(known, row_offset, 1, 1);
}
"
);
);
}

extern "C" fn no_call_machine(
Expand All @@ -593,6 +669,10 @@ extern \"C\" fn witgen(
false
}

extern "C" fn get_fixed_data_test(_: *const c_void, col_id: u64, row: u64) -> GoldilocksField {
GoldilocksField::from(col_id * 2000 + row)
}

fn witgen_fun_params<'a>(
data: &mut [GoldilocksField],
known: &mut [u32],
Expand All @@ -604,6 +684,8 @@ extern \"C\" fn witgen(
params: Default::default(),
mutable_state: std::ptr::null(),
call_machine: no_call_machine,
fixed_data: null(),
get_fixed_value: get_fixed_data_test,
}
}

Expand Down Expand Up @@ -655,6 +737,8 @@ extern \"C\" fn witgen(
params: Default::default(),
mutable_state: std::ptr::null(),
call_machine: no_call_machine,
fixed_data: null(),
get_fixed_value: get_fixed_data_test,
};
(f2.function)(params2);
assert_eq!(data[0], GoldilocksField::from(7));
Expand Down Expand Up @@ -746,6 +830,8 @@ extern \"C\" fn witgen(
params: params.as_mut_slice().into(),
mutable_state: std::ptr::null(),
call_machine: no_call_machine,
fixed_data: null(),
get_fixed_value: get_fixed_data_test,
};
(f.function)(params);
assert_eq!(y_val, GoldilocksField::from(7 * 2));
Expand All @@ -768,6 +854,33 @@ extern \"C\" fn witgen(
assert!(code.contains(&format!("let c_x_1_0 = (c_a_0_0 & {large_num:#x});")));
}

#[test]
fn fixed_column_access() {
let a = cell("a", 0, 0);
let x = Variable::FixedColumn(Cell {
column_name: "X".to_string(),
id: 15,
row_offset: 6,
});
let effects = vec![assignment(&a, symbol(&x))];
let f = compile_effects(0, 1, &[], &effects).unwrap();
let mut data = vec![7.into()];
let mut known = vec![0];
let mut params = vec![];
let params = WitgenFunctionParams {
data: data.as_mut_slice().into(),
known: known.as_mut_ptr(),
row_offset: 0,
params: params.as_mut_slice().into(),
mutable_state: std::ptr::null(),
call_machine: no_call_machine,
fixed_data: null(),
get_fixed_value: get_fixed_data_test,
};
(f.function)(params);
assert_eq!(data[0], GoldilocksField::from(30006));
}

extern "C" fn mock_call_machine(
_: *const c_void,
id: u64,
Expand Down Expand Up @@ -820,6 +933,8 @@ extern \"C\" fn witgen(
params: Default::default(),
mutable_state: std::ptr::null(),
call_machine: mock_call_machine,
fixed_data: null(),
get_fixed_value: get_fixed_data_test,
};
(f.function)(params);
assert_eq!(data[0], GoldilocksField::from(9));
Expand Down Expand Up @@ -854,6 +969,8 @@ extern \"C\" fn witgen(
params: params.as_mut_slice().into(),
mutable_state: std::ptr::null(),
call_machine: no_call_machine,
fixed_data: null(),
get_fixed_value: get_fixed_data_test,
};
(f.function)(params);
assert_eq!(y_val, GoldilocksField::from(8));
Expand All @@ -867,6 +984,8 @@ extern \"C\" fn witgen(
params: params.as_mut_slice().into(),
mutable_state: std::ptr::null(),
call_machine: no_call_machine,
fixed_data: null(),
get_fixed_value: get_fixed_data_test,
};
(f.function)(params);
assert_eq!(y_val, GoldilocksField::from(4));
Expand Down
4 changes: 3 additions & 1 deletion executor/src/witgen/jit/function_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ struct CacheKey {
}

pub struct FunctionCache<'a, T: FieldElement> {
fixed_data: &'a FixedData<'a, T>,
/// The processor that generates the JIT code
processor: BlockMachineProcessor<'a, T>,
/// The cache of JIT functions. If the entry is None, we attempted to generate the function
Expand All @@ -47,6 +48,7 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
BlockMachineProcessor::new(fixed_data, parts.clone(), block_size, latch_row);

FunctionCache {
fixed_data,
processor,
column_layout: metadata,
witgen_functions: HashMap::new(),
Expand Down Expand Up @@ -167,7 +169,7 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
.expect("Need to call compile_cached() first!")
.as_ref()
.expect("compile_cached() returned false!");
f.call(mutable_state, values, data);
f.call(self.fixed_data, mutable_state, values, data);

Ok(true)
}
Expand Down
Loading
Loading