diff --git a/timely/src/dataflow/operators/aggregation/aggregate.rs b/timely/src/dataflow/operators/aggregation/aggregate.rs index 55f7d50a47..6443a250e7 100644 --- a/timely/src/dataflow/operators/aggregation/aggregate.rs +++ b/timely/src/dataflow/operators/aggregation/aggregate.rs @@ -16,7 +16,7 @@ pub trait Aggregate { /// /// The `aggregate` method is implemented for streams of `(K, V)` data, /// and takes functions `fold`, `emit`, and `hash`; used to combine new `V` - /// data with existing `D` state, to produce `R` output from `D` state, and + /// data with existing `D` state, to produce `I` output from `D` state, and /// to route `K` keys, respectively. /// /// Aggregation happens within each time, and results are produced once the @@ -32,8 +32,9 @@ pub trait Aggregate { /// (0..10).to_stream(scope) /// .map(|x| (x % 2, x)) /// .aggregate( + /// i32::default, /// |_key, val, agg| { *agg += val; }, - /// |key, agg: i32| (key, agg), + /// |key, agg| [(key, agg)], /// |key| *key as u64 /// ) /// .inspect(|x| assert!(*x == (0, 20) || *x == (1, 25))); @@ -52,28 +53,36 @@ pub trait Aggregate { /// /// (0..10).to_stream(scope) /// .map(|x| (x % 2, x)) - /// .aggregate::<_,Vec,_,_,_>( + /// .aggregate( + /// Vec::default, /// |_key, val, agg| { agg.push(val); }, - /// |key, agg| (key, agg.len()), + /// |key, agg| [(key, agg.len())], /// |key| *key as u64 /// ) /// .inspect(|x| assert!(*x == (0, 5) || *x == (1, 5))); /// }); /// ``` - fn aggregateR+'static, H: Fn(&K)->u64+'static>( + fn aggregate D+ 'static, F: Fn(&K, V, &mut D)+'static, E: Fn(K, D)->I+'static, H: Fn(&K)->u64+'static>( &self, + make_default: M, fold: F, emit: E, - hash: H) -> Stream where S::Timestamp: Eq; + hash: H) -> Stream + where S::Timestamp: Eq, + I::Item: Data; } impl Aggregate for Stream { - fn aggregateR+'static, H: Fn(&K)->u64+'static>( + fn aggregate D + 'static, F: Fn(&K, V, &mut D)+'static, E: Fn(K, D)->I+'static, H: Fn(&K)->u64+'static>( &self, + make_default: M, fold: F, emit: E, - hash: H) -> Stream where S::Timestamp: Eq { + hash: H) -> Stream + where S::Timestamp: Eq, + I::Item: Data + { let mut aggregates = HashMap::new(); let mut vector = Vec::new(); @@ -84,7 +93,7 @@ impl Aggregate for data.swap(&mut vector); let agg_time = aggregates.entry(time.time().clone()).or_insert_with(HashMap::new); for (key, val) in vector.drain(..) { - let agg = agg_time.entry(key.clone()).or_insert_with(Default::default); + let agg = agg_time.entry(key.clone()).or_insert_with(&make_default); fold(&key, val, agg); } notificator.notify_at(time.retain()); @@ -95,7 +104,7 @@ impl Aggregate for if let Some(aggs) = aggregates.remove(time.time()) { let mut session = output.session(&time); for (key, agg) in aggs { - session.give(emit(key, agg)); + session.give_iterator(emit(key, agg).into_iter()); } } });