Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement dynamic bus (runtime witgen) #2539

Open
wants to merge 4 commits into
base: remove-connection
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions executor/src/witgen/data_structures/identity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,16 +295,15 @@ fn convert_phantom_bus_interaction<T: FieldElement>(
},
_ => (false, bus_interaction.multiplicity.clone()),
};
let bus_id = match bus_interaction.bus_id {
AlgebraicExpression::Number(id) => id,
// TODO: Relax this for sends when implementing dynamic sends
_ => panic!("Expected first payload entry to be a static ID"),
};
let selected_payload = SelectedExpressions {
selector: bus_interaction.latch.clone(),
expressions: bus_interaction.payload.0.clone(),
};
if is_receive {
let bus_id = match bus_interaction.bus_id {
AlgebraicExpression::Number(id) => id,
_ => panic!("Expected first payload entry of a receive to be a static ID"),
};
IdentityOrReceive::Receive(BusReceive {
bus_id,
multiplicity: Some(multiplicity),
Expand All @@ -314,7 +313,7 @@ fn convert_phantom_bus_interaction<T: FieldElement>(
assert_eq!(multiplicity, bus_interaction.latch);
IdentityOrReceive::Identity(Identity::BusSend(BusSend {
identity_id: bus_interaction.id,
bus_id: AlgebraicExpression::Number(bus_id),
bus_id: bus_interaction.bus_id.clone(),
selected_payload,
}))
}
Expand Down
2 changes: 2 additions & 0 deletions executor/src/witgen/eval_result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ pub enum IncompleteCause<K = usize> {
NonConstantQueryMatchScrutinee,
/// Query element is not constant.
NonConstantQueryElement,
/// Bus ID is not constant.
NonConstantBusID,
/// A required argument was not provided
NonConstantRequiredArgument(&'static str),
/// The left selector in a lookup is not constant. Example: `x * {1} in [{1}]` where `x` is not constant.
Expand Down
6 changes: 5 additions & 1 deletion executor/src/witgen/global_constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,11 @@ fn propagate_constraints<T: FieldElement>(
}
}
Identity::BusSend(send) => {
let receive = send.try_match_static(bus_receives).unwrap();
let receive = match send.try_match_static(bus_receives) {
Some(r) => r,
// For dynamic sends, we can only propagate constraints at runtime.
None => return false,
};
if !send.selected_payload.selector.is_one() {
return false;
}
Expand Down
26 changes: 15 additions & 11 deletions executor/src/witgen/identity_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use powdr_number::FieldElement;
use crate::witgen::data_structures::mutable_state::MutableState;
use crate::witgen::{global_constraints::CombinedRangeConstraintSet, EvalError};

use super::data_structures::identity::Identity;
use super::data_structures::identity::{BusSend, Identity};
use super::{
affine_expression::AlgebraicVariable, processor::OuterQuery, rows::RowPair, EvalResult,
EvalValue, IncompleteCause, QueryCallback,
Expand Down Expand Up @@ -37,11 +37,7 @@ impl<'a, 'c, T: FieldElement, Q: QueryCallback<T>> IdentityProcessor<'a, 'c, T,
) -> EvalResult<'a, T> {
let result = match identity {
Identity::Polynomial(identity) => self.process_polynomial_identity(identity, rows),
Identity::BusSend(bus_interaction) => self.process_machine_call(
bus_interaction.bus_id().unwrap(),
&bus_interaction.selected_payload,
rows,
),
Identity::BusSend(bus_send) => self.process_machine_call(bus_send, rows),
Identity::Connect(..) => {
// TODO this is not the right cause.
Ok(EvalValue::incomplete(IncompleteCause::SolvingFailed))
Expand All @@ -67,15 +63,15 @@ impl<'a, 'c, T: FieldElement, Q: QueryCallback<T>> IdentityProcessor<'a, 'c, T,

fn process_machine_call(
&mut self,
bus_id: T,
left: &'a powdr_ast::analyzed::SelectedExpressions<T>,
bus_send: &'a BusSend<T>,
rows: &RowPair<'_, 'a, T>,
) -> EvalResult<'a, T> {
if let Some(status) = self.handle_left_selector(&left.selector, rows) {
if let Some(status) = self.handle_left_selector(&bus_send.selected_payload.selector, rows) {
return Ok(status);
}

let left = match left
let arguments = match bus_send
.selected_payload
.expressions
.iter()
.map(|e| rows.evaluate(e))
Expand All @@ -85,7 +81,15 @@ impl<'a, 'c, T: FieldElement, Q: QueryCallback<T>> IdentityProcessor<'a, 'c, T,
Err(incomplete_cause) => return Ok(EvalValue::incomplete(incomplete_cause)),
};

self.mutable_state.call(bus_id, &left, rows)
let bus_id = match rows.evaluate(&bus_send.bus_id) {
Ok(bus_id) => match bus_id.constant_value() {
Some(bus_id) => bus_id,
None => return Ok(EvalValue::incomplete(IncompleteCause::NonConstantBusID)),
},
Err(incomplete_cause) => return Ok(EvalValue::incomplete(incomplete_cause)),
};

self.mutable_state.call(bus_id, &arguments, rows)
}

/// Handles the lookup that connects the current machine to the calling machine.
Expand Down
103 changes: 73 additions & 30 deletions executor/src/witgen/multiplicity_column_generator.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
use std::collections::{BTreeMap, HashMap};
use std::{
collections::{BTreeMap, HashMap},
iter::once,
};

use itertools::Itertools;
use powdr_ast::{
analyzed::{AlgebraicExpression, PolyID, PolynomialType, SelectedExpressions},
parsed::visitor::AllChildren,
Expand Down Expand Up @@ -35,9 +39,6 @@ impl<'a, T: FieldElement> MultiplicityColumnGenerator<'a, T> {
) -> HashMap<String, Vec<T>> {
record_start(MULTIPLICITY_WITGEN_NAME);

// A map from multiplicity column ID to the vector of multiplicities.
let mut multiplicity_columns = BTreeMap::new();

let (identities, _) = convert_identities(self.fixed.analyzed);

let all_columns = witness_columns
Expand Down Expand Up @@ -74,8 +75,15 @@ impl<'a, T: FieldElement> MultiplicityColumnGenerator<'a, T> {
bus_receive.has_arbitrary_multiplicity() && bus_receive.multiplicity.is_some()
})
.map(|(bus_id, bus_receive)| {
let (size, rhs_tuples) =
self.get_tuples(&terminal_values, &bus_receive.selected_payload);
let SelectedExpressions {
selector,
expressions,
} = &bus_receive.selected_payload;
let (size, rhs_tuples) = self.get_tuples(
&terminal_values,
selector,
&expressions.iter().collect::<Vec<_>>(),
);

let index = rhs_tuples
.into_iter()
Expand Down Expand Up @@ -104,28 +112,55 @@ impl<'a, T: FieldElement> MultiplicityColumnGenerator<'a, T> {
})
.collect::<BTreeMap<_, _>>();

// A map from multiplicity column ID to the vector of multiplicities.
let mut multiplicity_columns = receive_infos
.values()
.map(|info| (info.multiplicity_column, vec![0; info.size]))
.collect::<BTreeMap<_, _>>();

// Increment multiplicities for all bus sends.
for (bus_send, bus_receive) in identities.iter().filter_map(|i| match i {
Identity::BusSend(bus_send) => receive_infos
.get(&bus_send.bus_id().unwrap())
.map(|bus_receive| (bus_send, bus_receive)),
let bus_sends = identities.iter().filter_map(|i| match i {
Identity::BusSend(bus_send) => match bus_send.bus_id() {
// As a performance optimization, already filter out sends with a static
// bus ID for which we know we don't need to track multiplicities.
Some(bus_id) => receive_infos.get(&bus_id).map(|_| bus_send),
Comment on lines +124 to +125
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On Keccak-via-RISC-V, this reduces multiplicity witgen from ~520ms to ~450ms (2.6% of witgen time). Before this PR, it's ~320ms. I suppose looking up the receive for each row below causes this regression. We could do more shenanigans to speed this up again in the static case, but I don't think it's worth it.

// For dynamic sends, this optimization is not possible.
None => Some(bus_send),
},
_ => None,
}) {
let (_, lhs_tuples) = self.get_tuples(&terminal_values, &bus_send.selected_payload);
});

let multiplicities = multiplicity_columns
.entry(bus_receive.multiplicity_column)
.or_insert_with(|| vec![0; bus_receive.size]);
assert_eq!(multiplicities.len(), bus_receive.size);
for bus_send in bus_sends {
let SelectedExpressions {
selector,
expressions,
} = &bus_send.selected_payload;

// We need to evaluate both the bus_id (to know the run-time receive) and the expressions
let bus_id_and_expressions = once(&bus_send.bus_id)
.chain(expressions.iter())
.collect::<Vec<_>>();

let (_, bus_id_and_expressions) =
self.get_tuples(&terminal_values, selector, &bus_id_and_expressions);

// Looking up the index is slow, so we do it in parallel.
let indices = lhs_tuples
let columns_and_indices = bus_id_and_expressions
.into_par_iter()
.map(|(_, tuple)| bus_receive.index[&tuple])
.filter_map(|(_, bus_id_and_expressions)| {
receive_infos
.get(&bus_id_and_expressions[0])
.map(|receive_info| {
(
receive_info.multiplicity_column,
receive_info.index[&bus_id_and_expressions[1..]],
)
})
})
.collect::<Vec<_>>();

for index in indices {
multiplicities[index] += 1;
for (multiplicity_column, index) in columns_and_indices {
multiplicity_columns.get_mut(&multiplicity_column).unwrap()[index] += 1;
}
}

Expand All @@ -150,12 +185,13 @@ impl<'a, T: FieldElement> MultiplicityColumnGenerator<'a, T> {
fn get_tuples(
&self,
terminal_values: &OwnedTerminalValues<T>,
selected_expressions: &SelectedExpressions<T>,
selector: &AlgebraicExpression<T>,
expressions: &[&AlgebraicExpression<T>],
) -> (usize, Vec<(usize, Vec<T>)>) {
let machine_size = selected_expressions
.expressions
let machine_size = expressions
.iter()
.flat_map(|expr| expr.all_children())
.flat_map(|e| e.all_children())
.chain(selector.all_children())
.filter_map(|expr| match expr {
AlgebraicExpression::Reference(ref r) => match r.poly_id.ptype {
PolynomialType::Committed | PolynomialType::Constant => {
Expand All @@ -169,7 +205,12 @@ impl<'a, T: FieldElement> MultiplicityColumnGenerator<'a, T> {
// But in practice, either the machine has a (smaller) witness column, or
// it's a fixed lookup, so there is only one size.
.min()
.unwrap_or_else(|| panic!("No column references found: {selected_expressions}"));
.unwrap_or_else(|| {
panic!(
"No column references found: {selector} $ [{}]",
expressions.iter().map(ToString::to_string).join(", ")
)
});

let tuples = (0..machine_size)
.into_par_iter()
Expand All @@ -178,14 +219,16 @@ impl<'a, T: FieldElement> MultiplicityColumnGenerator<'a, T> {
terminal_values.row(row),
&self.fixed.intermediate_definitions,
);
let result = evaluator.evaluate(&selected_expressions.selector);
let selector = evaluator.evaluate(selector);

assert!(result.is_zero() || result.is_one(), "Non-binary selector");
result.is_one().then(|| {
assert!(
selector.is_zero() || selector.is_one(),
"Non-binary selector"
);
selector.is_one().then(|| {
(
row,
selected_expressions
.expressions
expressions
.iter()
.map(|expression| evaluator.evaluate(expression))
.collect::<Vec<_>>(),
Expand Down
9 changes: 6 additions & 3 deletions executor/src/witgen/vm_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,9 +364,12 @@ impl<'a, 'c, T: FieldElement, Q: QueryCallback<T>> VmProcessor<'a, 'c, T, Q> {
.filter_map(|(index, (ident, _))| match ident {
Identity::BusSend(send) => send
.try_match_static(&self.fixed_data.bus_receives)
.unwrap()
.has_arbitrary_multiplicity()
.then_some((index, &send.selected_payload)),
// We assume that the PC lookup is static.
.and_then(|receive| {
receive
.has_arbitrary_multiplicity()
.then_some((index, &send.selected_payload))
}),
_ => None,
})
.max_by_key(|(_, left)| left.expressions.len())
Expand Down
3 changes: 1 addition & 2 deletions pipeline/tests/asm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ fn empty() {

#[test]
// TODO: https://github.com/powdr-labs/powdr/issues/2292
#[should_panic = "No column references found: []"]
#[should_panic = "No column references found: 1 $ []"]
fn single_operation() {
let f = "asm/single_operation.asm";
regular_test_all_fields(f, &[]);
Expand Down Expand Up @@ -251,7 +251,6 @@ fn static_bus_multi() {
}

#[test]
#[should_panic = "Expected first payload entry to be a static ID"]
fn dynamic_bus() {
// Witgen does not currently support this.
let f = "asm/dynamic_bus.asm";
Expand Down
Loading