Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jit interpreter branch handling #2481

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
19 changes: 8 additions & 11 deletions executor/src/witgen/jit/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,15 @@ use super::{
variable::Variable,
};

pub struct WitgenFunction<T> {
pub struct CompiledFunction<T> {
// TODO We might want to pass arguments as direct function parameters
// (instead of a struct), so that
// they are stored in registers instead of the stack. Should be checked.
function: extern "C" fn(WitgenFunctionParams<T>),
_library: Arc<Library>,
}

impl<T: FieldElement> WitgenFunction<T> {
/// Call the witgen function to fill the data and "known" tables
/// given a slice of parameters.
/// The `row_offset` is the index inside `data` of the row considered to be "row zero".
/// This function always succeeds (unless it panics).
impl<T: FieldElement> CompiledFunction<T> {
pub fn call<Q: QueryCallback<T>>(
&self,
fixed_data: &FixedData<'_, T>,
Expand All @@ -51,7 +47,7 @@ impl<T: FieldElement> WitgenFunction<T> {
) {
let row_offset = data.row_offset.try_into().unwrap();
let (data, known) = data.as_mut_slices();
(self.function)(WitgenFunctionParams {
let params = WitgenFunctionParams {
data: data.into(),
known: known.as_mut_ptr(),
row_offset,
Expand All @@ -62,7 +58,8 @@ impl<T: FieldElement> WitgenFunction<T> {
get_fixed_value: get_fixed_value::<T>,
input_from_channel: input_from_channel::<T, Q>,
output_to_channel: output_to_channel::<T, Q>,
});
};
(self.function)(params);
}
}

Expand Down Expand Up @@ -119,7 +116,7 @@ pub fn compile_effects<T: FieldElement, D: DefinitionFetcher>(
known_inputs: &[Variable],
effects: &[Effect<T, Variable>],
prover_functions: Vec<ProverFunction<'_, T>>,
) -> Result<WitgenFunction<T>, String> {
) -> Result<CompiledFunction<T>, String> {
let utils = util_code::<T>()?;
let interface = interface_code(column_layout);
let mut codegen = CodeGenerator::<T, _>::new(definitions);
Expand Down Expand Up @@ -153,7 +150,7 @@ pub fn compile_effects<T: FieldElement, D: DefinitionFetcher>(

let library = Arc::new(unsafe { libloading::Library::new(&lib_path.path).unwrap() });
let witgen_fun = unsafe { library.get(b"witgen\0") }.unwrap();
Ok(WitgenFunction {
Ok(CompiledFunction {
function: *witgen_fun,
_library: library,
})
Expand Down Expand Up @@ -660,7 +657,7 @@ mod tests {
column_count: usize,
known_inputs: &[Variable],
effects: &[Effect<GoldilocksField, Variable>],
) -> Result<WitgenFunction<GoldilocksField>, String> {
) -> Result<CompiledFunction<GoldilocksField>, String> {
super::compile_effects(
&NoDefinitions,
ColumnLayout {
Expand Down
70 changes: 52 additions & 18 deletions executor/src/witgen/jit/function_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ use crate::witgen::{

use super::{
block_machine_processor::BlockMachineProcessor,
compiler::{compile_effects, WitgenFunction},
compiler::{compile_effects, CompiledFunction},
interpreter::EffectsInterpreter,
variable::Variable,
witgen_inference::CanProcessCall,
};
Expand All @@ -44,8 +45,36 @@ pub struct FunctionCache<'a, T: FieldElement> {
parts: MachineParts<'a, T>,
}

enum WitgenFunction<T: FieldElement> {
Compiled(CompiledFunction<T>),
Interpreted(EffectsInterpreter<T>),
}

impl<T: FieldElement> WitgenFunction<T> {
/// Call the witgen function to fill the data and "known" tables
/// given a slice of parameters.
/// The `row_offset` is the index inside `data` of the row considered to be "row zero".
/// This function always succeeds (unless it panics).
pub fn call<Q: QueryCallback<T>>(
&self,
fixed_data: &FixedData<'_, T>,
mutable_state: &MutableState<'_, T, Q>,
params: &mut [LookupCell<T>],
data: CompactDataRef<'_, T>,
) {
match self {
WitgenFunction::Compiled(compiled_function) => {
compiled_function.call(fixed_data, mutable_state, params, data);
}
WitgenFunction::Interpreted(interpreter) => {
interpreter.call::<Q>(fixed_data, mutable_state, params, data)
}
}
}
}

pub struct CacheEntry<T: FieldElement> {
pub function: WitgenFunction<T>,
function: WitgenFunction<T>,
pub range_constraints: Vec<RangeConstraint<T>>,
}

Expand Down Expand Up @@ -107,14 +136,11 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
) -> &Option<CacheEntry<T>> {
if !self.witgen_functions.contains_key(cache_key) {
record_start("Auto-witgen code derivation");
let f = match T::known_field() {
// Currently, we only support the Goldilocks fields
Some(KnownField::GoldilocksField) => {
self.compile_witgen_function(can_process, cache_key)
}
_ => None,
};

let interpreted = !matches!(T::known_field(), Some(KnownField::GoldilocksField));
let f = self.compile_witgen_function(can_process, cache_key, interpreted);
assert!(self.witgen_functions.insert(cache_key.clone(), f).is_none());

record_end("Auto-witgen code derivation");
}
self.witgen_functions.get(cache_key).unwrap()
Expand All @@ -124,6 +150,7 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
&self,
can_process: impl CanProcessCall<T>,
cache_key: &CacheKey<T>,
interpreted: bool,
) -> Option<CacheEntry<T>> {
log::debug!(
"Compiling JIT function for\n Machine: {}\n Connection: {}\n Inputs: {:?}{}",
Expand Down Expand Up @@ -187,15 +214,22 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
.filter_map(|(i, b)| if b { Some(Variable::Param(i)) } else { None })
.collect::<Vec<_>>();

log::trace!("Compiling effects...");
let function = compile_effects(
self.fixed_data.analyzed,
self.column_layout.clone(),
&known_inputs,
&code,
prover_functions,
)
.unwrap();
let function = if interpreted {
log::trace!("Building effects interpreter...");
WitgenFunction::Interpreted(EffectsInterpreter::try_new(&known_inputs, &code)?)
} else {
log::trace!("Compiling effects...");
WitgenFunction::Compiled(
compile_effects(
self.fixed_data.analyzed,
self.column_layout.clone(),
&known_inputs,
&code,
prover_functions,
)
.unwrap(),
)
};
log::trace!("Compilation done.");

Some(CacheEntry {
Expand Down
Loading
Loading