Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions vortex-array/benches/aggregate_grouped.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,36 @@ fn i32_clustered_nulls_input() -> ArrayRef {
)
}

/// Benchmark input: `f64` elements with no nulls, grouped by `random_group_sizes()`.
///
/// Element values are drawn from `-1000.0..1000.0` using an RNG seeded with
/// `GROUP_SIZE_SEED`, so the generated data is reproducible across runs.
fn f64_all_valid_input() -> ArrayRef {
    let sizes = random_group_sizes();
    let len = total_element_count(&sizes);
    let mut rng = StdRng::seed_from_u64(GROUP_SIZE_SEED);

    // One random draw per element; the index itself is irrelevant.
    let samples: Buffer<f64> = (0..len)
        .map(|_| rng.random_range(-1000.0..1000.0))
        .collect();

    let elements = PrimitiveArray::new(samples, Validity::NonNullable).into_array();
    contiguous_list_view(elements, &sizes)
}

/// Benchmark input: nullable `f64` elements whose nulls come in clusters, grouped by
/// `random_group_sizes()`.
///
/// Elements are considered in consecutive runs of 16; every eighth run (starting with
/// the very first one) is entirely null. All other elements are drawn from
/// `-1000.0..1000.0` using an RNG seeded with `GROUP_SIZE_SEED`.
fn f64_clustered_nulls_input() -> ArrayRef {
    let sizes = random_group_sizes();
    let len = total_element_count(&sizes);
    let mut rng = StdRng::seed_from_u64(GROUP_SIZE_SEED);

    // The RNG is only advanced for valid slots, never for null ones.
    let elements = (0..len).map(|idx| match (idx / 16) % 8 {
        0 => None,
        _ => Some(rng.random_range(-1000.0f64..1000.0)),
    });

    let array = PrimitiveArray::from_option_iter(elements).into_array();
    contiguous_list_view(array, &sizes)
}

fn varbinview_input() -> ArrayRef {
let group_sizes = random_group_sizes();
let element_count = total_element_count(&group_sizes);
Expand Down Expand Up @@ -144,6 +174,22 @@ fn sum_i32_clustered_nulls(bencher: Bencher) {
.bench_refs(|input| grouped_accumulator(input, Sum));
}

/// Benchmarks the grouped `Sum` accumulator over all-valid `f64` groups.
///
/// The input is built once, outside the measured closure, and handed to every
/// iteration by reference.
#[divan::bench]
fn sum_f64_all_valid(bencher: Bencher) {
    let groups = f64_all_valid_input();
    bencher
        .with_inputs(|| &groups)
        .bench_refs(|array| grouped_accumulator(array, Sum));
}

/// Benchmarks the grouped `Sum` accumulator over `f64` groups containing clustered nulls.
///
/// The input is built once, outside the measured closure, and handed to every
/// iteration by reference.
#[divan::bench]
fn sum_f64_clustered_nulls(bencher: Bencher) {
    let groups = f64_clustered_nulls_input();
    bencher
        .with_inputs(|| &groups)
        .bench_refs(|array| grouped_accumulator(array, Sum));
}

#[divan::bench]
fn count_i32_clustered_nulls(bencher: Bencher) {
let input = i32_clustered_nulls_input();
Expand Down
36 changes: 36 additions & 0 deletions vortex-array/benches/aggregate_sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,42 @@ fn sum_i64(bencher: Bencher) {
.bench_refs(|(a, ctx)| a.statistics().compute_as::<i64>(Stat::Sum, ctx));
}

/// Benchmarks computing `Stat::Sum` as `f64` over `N` non-null random values.
///
/// The raw samples are generated once from a fixed seed; each iteration receives a
/// freshly built array and execution context from the input factory.
#[divan::bench]
fn sum_f64(bencher: Bencher) {
    let mut rng = StdRng::seed_from_u64(6);
    let samples: Vec<f64> = (0..N).map(|_| rng.random_range(-1000.0..1000.0)).collect();

    // Input factory: a new array (and context) per iteration.
    let make_input = || {
        (
            PrimitiveArray::from_iter(samples.iter().copied()).into_array(),
            SESSION.create_execution_ctx(),
        )
    };

    bencher
        .with_inputs(make_input)
        .bench_refs(|(array, ctx)| array.statistics().compute_as::<f64>(Stat::Sum, ctx));
}

/// Benchmarks computing `Stat::Sum` as `f64` over `N` values with clustered nulls.
///
/// Values are laid out in consecutive runs of 64; every tenth run (starting with the
/// first) is entirely null, and the rest are drawn from `-1000.0..1000.0` with a
/// fixed seed. Each iteration receives a freshly built array and execution context.
#[divan::bench]
fn sum_f64_nulls_clustered(bencher: Bencher) {
    let mut rng = StdRng::seed_from_u64(7);

    // The RNG is only advanced for valid slots, never for null ones.
    let samples: Vec<Option<f64>> = (0..N)
        .map(|idx| match (idx / 64) % 10 {
            0 => None,
            _ => Some(rng.random_range(-1000.0..1000.0)),
        })
        .collect();

    // Input factory: a new array (and context) per iteration.
    let make_input = || {
        (
            PrimitiveArray::from_option_iter(samples.iter().copied()).into_array(),
            SESSION.create_execution_ctx(),
        )
    };

    bencher
        .with_inputs(make_input)
        .bench_refs(|(array, ctx)| array.statistics().compute_as::<f64>(Stat::Sum, ctx));
}

// Clustered nulls: long runs of valid values broken up by occasional null blocks. This is the
// case the run-based valid path is expected to accelerate.
#[divan::bench]
Expand Down
85 changes: 56 additions & 29 deletions vortex-array/src/aggregate_fn/accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,21 +120,17 @@ impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
);

// 0. Legacy stats bridge: if this aggregate is still cached under a legacy Stat slot,
// consume that exact stat before kernel dispatch or decode.
// consume that exact stat before kernel dispatch or decode. When the stat dtype is
// incompatible with the partial dtype (e.g. float Sum partials carry a compensation
// term the f64 stat lacks), fall through to regular dispatch, where the vtable may
// still consume the stat itself (see `Sum::try_accumulate`).
if let Some(stat) = Stat::from_aggregate_fn(&self.aggregate_fn)
&& let Precision::Exact(partial) = batch.statistics().get(stat)
&& partial.dtype().eq_ignore_nullability(&self.partial_dtype)
{
let partial = if partial.dtype() == &self.partial_dtype {
partial
} else {
vortex_ensure!(
partial.dtype().eq_ignore_nullability(&self.partial_dtype),
"Aggregate {} read legacy stat {} with dtype {}, expected {}",
self.aggregate_fn,
stat,
partial.dtype(),
self.partial_dtype,
);
partial.cast(&self.partial_dtype)?
};
self.vtable.combine_partials(&mut self.partial, partial)?;
Expand Down Expand Up @@ -332,10 +328,38 @@ mod tests {
_batch: &ArrayRef,
_ctx: &mut ExecutionCtx,
) -> VortexResult<Option<Scalar>> {
Ok(Some(Scalar::primitive(42.0f64, Nullability::Nullable)))
Ok(Some(sum_partial(42.0)))
}
}

/// Construct the struct-shaped partial that `Sum` uses for `f64` input:
/// `{sum: value, compensation: 0.0}`.
///
/// The struct dtype is obtained from `Sum::partial_dtype` rather than spelled out by
/// hand, so this helper follows whatever dtype the aggregate reports.
fn sum_partial(value: f64) -> Scalar {
    let input_dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
    let partial_dtype = Sum
        .partial_dtype(&EmptyOptions, &input_dtype)
        .expect("sum supports f64");

    let running_sum = Scalar::primitive(value, Nullability::Nullable);
    let compensation = Scalar::primitive(0.0f64, Nullability::Nullable);
    Scalar::struct_(partial_dtype, vec![running_sum, compensation])
}

/// Extract the innermost `sum` value from a mean partial shaped like
/// `{sum: {sum, compensation}, count}`.
///
/// Returns `None` when either `sum` field lookup fails, or when the final primitive
/// conversion to `f64` yields `None`.
fn partial_sum_value(partial: &Scalar) -> Option<f64> {
    // Outer struct: the mean partial itself.
    let mean_fields = partial.as_struct();
    let sum_partial = mean_fields.field("sum")?;

    // Inner struct: the Sum partial nested under the mean's `sum` field.
    let sum_fields = sum_partial.as_struct();
    let running_sum = sum_fields.field("sum")?;

    running_sum.as_primitive().as_::<f64>()
}

/// Create a session for a test: start from `VortexSession::empty()` and attach
/// `ArraySession` to it.
fn fresh_session() -> VortexSession {
    let session = VortexSession::empty();
    session.with::<ArraySession>()
}
Expand All @@ -357,7 +381,7 @@ mod tests {

/// Build the mean partial the tests expect when the sentinel Sum kernel fires:
/// `{sum: {sum: 42.0, compensation: 0.0}, count: 1}`.
///
/// The struct dtype is taken from a freshly built mean accumulator so that it always
/// matches the accumulator under test.
fn sentinel_partial() -> Scalar {
    let acc = mean_f64_accumulator().expect("build accumulator");
    // The `sum` field is itself a Sum partial struct, so it is built with
    // `sum_partial` rather than as a bare f64 scalar. The earlier bare-scalar binding
    // was dead code, immediately shadowed by this one.
    let sum = sum_partial(42.0);
    let count = Scalar::primitive(1u64, Nullability::NonNullable);
    Scalar::struct_(acc.partial_dtype, vec![sum, count])
}
Expand All @@ -377,13 +401,14 @@ mod tests {
acc.accumulate(&dict_of_seven(), &mut ctx)?;
let partial = acc.flush()?;

let s = partial.as_struct();
assert_eq!(
s.field("sum").unwrap().as_primitive().as_::<f64>(),
Some(42.0)
);
assert_eq!(partial_sum_value(&partial), Some(42.0));
assert_eq!(
s.field("count").unwrap().as_primitive().as_::<u64>(),
partial
.as_struct()
.field("count")
.unwrap()
.as_primitive()
.as_::<u64>(),
Some(1)
);
Ok(())
Expand All @@ -404,13 +429,14 @@ mod tests {
acc.accumulate(&dict_of_seven(), &mut ctx)?;
let partial = acc.flush()?;

let s = partial.as_struct();
assert_eq!(partial_sum_value(&partial), Some(7.0));
assert_eq!(
s.field("sum").unwrap().as_primitive().as_::<f64>(),
Some(7.0)
);
assert_eq!(
s.field("count").unwrap().as_primitive().as_::<u64>(),
partial
.as_struct()
.field("count")
.unwrap()
.as_primitive()
.as_::<u64>(),
Some(1)
);
Ok(())
Expand All @@ -432,16 +458,17 @@ mod tests {
acc.accumulate(&dict_of_seven(), &mut ctx)?;
let partial = acc.flush()?;

let s = partial.as_struct();
// `Sum` child returned the sentinel 42.0 — proves the (Dict, Sum) kernel fired
// via `Combined<Mean>`'s fan-out. `Count`'s native `try_accumulate` reads the
// batch's valid_count, so count is the real 1.
assert_eq!(partial_sum_value(&partial), Some(42.0));
assert_eq!(
s.field("sum").unwrap().as_primitive().as_::<f64>(),
Some(42.0)
);
assert_eq!(
s.field("count").unwrap().as_primitive().as_::<u64>(),
partial
.as_struct()
.field("count")
.unwrap()
.as_primitive()
.as_::<u64>(),
Some(1)
);
Ok(())
Expand Down
Loading
Loading