Skip to content
Merged
19 changes: 9 additions & 10 deletions src/query/expression/src/aggregate/aggregate_hashtable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::sync::Arc;
use bumpalo::Bump;
use databend_common_exception::Result;

use super::group_hash_columns;
use super::group_hash_entries;
use super::hash_index::AdapterImpl;
use super::hash_index::HashIndex;
use super::partitioned_payload::PartitionedPayload;
Expand All @@ -29,6 +29,7 @@ use super::probe_state::ProbeState;
use super::Entry;
use super::HashTableConfig;
use super::Payload;
use super::BATCH_SIZE;
use super::LOAD_FACTOR;
use super::MAX_PAGE_SIZE;
use crate::types::DataType;
Expand All @@ -37,8 +38,6 @@ use crate::BlockEntry;
use crate::ColumnBuilder;
use crate::ProjectedBlock;

const BATCH_ADD_SIZE: usize = 2048;

pub struct AggregateHashTable {
pub payload: PartitionedPayload,
// use for append rows directly during deserialize
Expand Down Expand Up @@ -127,12 +126,12 @@ impl AggregateHashTable {
agg_states: ProjectedBlock,
row_count: usize,
) -> Result<usize> {
if row_count <= BATCH_ADD_SIZE {
if row_count <= BATCH_SIZE {
self.add_groups_inner(state, group_columns, params, agg_states, row_count)
} else {
let mut new_count = 0;
for start in (0..row_count).step_by(BATCH_ADD_SIZE) {
let end = (start + BATCH_ADD_SIZE).min(row_count);
for start in (0..row_count).step_by(BATCH_SIZE) {
let end = (start + BATCH_SIZE).min(row_count);
let step_group_columns = group_columns
.iter()
.map(|entry| entry.slice(start..end))
Expand Down Expand Up @@ -186,11 +185,11 @@ impl AggregateHashTable {
}

state.row_count = row_count;
group_hash_columns(group_columns, &mut state.group_hashes);
group_hash_entries(group_columns, &mut state.group_hashes[..row_count]);

let new_group_count = if self.direct_append {
for idx in 0..row_count {
state.empty_vector[idx] = idx;
for i in 0..row_count {
state.empty_vector[i] = i.into();
}
self.payload.append_rows(state, row_count, group_columns);
row_count
Expand Down Expand Up @@ -230,7 +229,7 @@ impl AggregateHashTable {

if self.config.partial_agg {
// check size
if self.hash_index.count + BATCH_ADD_SIZE > self.hash_index.resize_threshold()
if self.hash_index.count + BATCH_SIZE > self.hash_index.resize_threshold()
&& self.hash_index.capacity >= self.config.max_partial_capacity
{
self.clear_ht();
Expand Down
127 changes: 90 additions & 37 deletions src/query/expression/src/aggregate/group_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,38 +19,13 @@ use databend_common_column::types::Index;
use databend_common_exception::ErrorCode;
use databend_common_exception::Result;

use crate::types::i256;
use crate::types::number::Number;
use crate::types::AccessType;
use crate::types::AnyType;
use crate::types::BinaryColumn;
use crate::types::BinaryType;
use crate::types::BitmapType;
use crate::types::BooleanType;
use crate::types::DataType;
use crate::types::DateType;
use crate::types::DecimalColumn;
use crate::types::DecimalDataKind;
use crate::types::DecimalScalar;
use crate::types::DecimalView;
use crate::types::GeographyColumn;
use crate::types::GeographyType;
use crate::types::GeometryType;
use crate::types::NullableColumn;
use crate::types::NumberColumn;
use crate::types::NumberDataType;
use crate::types::NumberScalar;
use crate::types::NumberType;
use crate::types::OpaqueScalarRef;
use crate::types::StringColumn;
use crate::types::StringType;
use crate::types::TimestampType;
use crate::types::ValueType;
use crate::types::VariantType;
use crate::types::decimal::Decimal;
use crate::types::*;
use crate::visitor::ValueVisitor;
use crate::with_decimal_mapped_type;
use crate::with_number_mapped_type;
use crate::with_number_type;
use crate::BlockEntry;
use crate::Column;
use crate::ProjectedBlock;
use crate::Scalar;
Expand All @@ -59,23 +34,101 @@ use crate::Value;

const NULL_HASH_VAL: u64 = 0xd1cefa08eb382d69;

pub fn group_hash_columns(cols: ProjectedBlock, values: &mut [u64]) {
debug_assert!(!cols.is_empty());
for (i, entry) in cols.iter().enumerate() {
if i == 0 {
combine_group_hash_column::<true>(&entry.to_column(), values);
} else {
combine_group_hash_column::<false>(&entry.to_column(), values);
pub fn group_hash_entries(entries: ProjectedBlock, values: &mut [u64]) {
debug_assert!(!entries.is_empty());
for (i, entry) in entries.iter().enumerate() {
debug_assert_eq!(entry.len(), values.len());
match entry {
BlockEntry::Const(scalar, data_type, _) => {
if i == 0 {
combine_group_hash_const::<true>(scalar, data_type, values);
} else {
combine_group_hash_const::<false>(scalar, data_type, values);
}
}
BlockEntry::Column(column) => {
if i == 0 {
combine_group_hash_column::<true>(column, values);
} else {
combine_group_hash_column::<false>(column, values);
}
}
}
}
}

pub fn combine_group_hash_column<const IS_FIRST: bool>(c: &Column, values: &mut [u64]) {
fn combine_group_hash_column<const IS_FIRST: bool>(c: &Column, values: &mut [u64]) {
HashVisitor::<IS_FIRST> { values }
.visit_column(c.clone())
.unwrap()
}

fn combine_group_hash_const<const IS_FIRST: bool>(
scalar: &Scalar,
data_type: &DataType,
values: &mut [u64],
) {
match data_type {
DataType::Null | DataType::EmptyArray | DataType::EmptyMap => {}
DataType::Nullable(inner) => {
if scalar.is_null() {
apply_const_hash::<IS_FIRST>(values, NULL_HASH_VAL);
} else {
combine_group_hash_const_nonnull::<IS_FIRST>(scalar, inner, values);
}
}
_ => combine_group_hash_const_nonnull::<IS_FIRST>(scalar, data_type, values),
}
}

fn combine_group_hash_const_nonnull<const IS_FIRST: bool>(
scalar: &Scalar,
_data_type: &DataType,
values: &mut [u64],
) {
let hash = match scalar {
Scalar::Null => unreachable!(),
Scalar::EmptyArray | Scalar::EmptyMap => return,
Scalar::Number(v) => with_number_type!(|NUM_TYPE| match v {
NumberScalar::NUM_TYPE(value) => value.agg_hash(),
}),
Scalar::Decimal(v) => {
with_decimal_mapped_type!(|F| match v {
DecimalScalar::F(v, size) => {
with_decimal_mapped_type!(|T| match size.data_kind() {
DecimalDataKind::T => {
v.as_decimal::<T>().agg_hash()
}
})
}
})
}
Scalar::Timestamp(value) => value.agg_hash(),
Scalar::Date(value) => value.agg_hash(),
Scalar::Boolean(value) => value.agg_hash(),
Scalar::String(value) => value.as_bytes().agg_hash(),
Scalar::Binary(value)
| Scalar::Bitmap(value)
| Scalar::Variant(value)
| Scalar::Geometry(value) => value.agg_hash(),
Scalar::Geography(value) => value.0.agg_hash(),
_ => scalar.as_ref().agg_hash(),
};
apply_const_hash::<IS_FIRST>(values, hash);
}

fn apply_const_hash<const IS_FIRST: bool>(values: &mut [u64], hash: u64) {
if IS_FIRST {
for val in values.iter_mut() {
*val = hash;
}
} else {
for val in values.iter_mut() {
*val = merge_hash(*val, hash);
}
}
}

struct HashVisitor<'a, const IS_FIRST: bool> {
values: &'a mut [u64],
}
Expand All @@ -101,7 +154,7 @@ impl<const IS_FIRST: bool> ValueVisitor for HashVisitor<'_, IS_FIRST> {
fn visit_any_number(&mut self, column: NumberColumn) -> Result<()> {
with_number_mapped_type!(|NUM_TYPE| match column.data_type() {
NumberDataType::NUM_TYPE => {
let c = NUM_TYPE::try_downcast_column(&column).unwrap();
let c = <NUM_TYPE as Number>::try_downcast_column(&column).unwrap();
self.combine_group_hash_type_column::<NumberType<NUM_TYPE>>(&c)
}
});
Expand Down
63 changes: 37 additions & 26 deletions src/query/expression/src/aggregate/hash_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::fmt::Debug;

use super::payload_row::CompareState;
use super::PartitionedPayload;
use super::ProbeState;
use super::RowPtr;
Expand Down Expand Up @@ -94,7 +97,7 @@ const SALT_MASK: u64 = 0xFFFF000000000000;
const POINTER_MASK: u64 = 0x0000FFFFFFFFFFFF;

// The high 16 bits are the salt, the low 48 bits are the pointer address
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[derive(Clone, Copy, PartialEq, Eq, Default)]
pub(super) struct Entry(pub(super) u64);

impl Entry {
Expand Down Expand Up @@ -133,6 +136,15 @@ impl Entry {
}
}

impl Debug for Entry {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("Entry")
.field(&self.get_salt())
.field(&self.get_pointer())
.finish()
}
}

pub(super) trait TableAdapter {
fn append_rows(&mut self, state: &mut ProbeState, new_entry_count: usize);

Expand All @@ -152,16 +164,10 @@ impl HashIndex {
mut adapter: impl TableAdapter,
) -> usize {
for (i, row) in state.no_match_vector[..row_count].iter_mut().enumerate() {
*row = i;
*row = i.into();
state.slots[i] = self.init_slot(state.group_hashes[i]);
}

let mut slots = state.get_temp();
slots.extend(
state.group_hashes[..row_count]
.iter()
.map(|hash| self.init_slot(*hash)),
);

let mut new_group_count = 0;
let mut remaining_entries = row_count;

Expand All @@ -172,11 +178,11 @@ impl HashIndex {

// 1. inject new_group_count, new_entry_count, need_compare_count, no_match_count
for row in state.no_match_vector[..remaining_entries].iter().copied() {
let slot = &mut slots[row];
let hash = state.group_hashes[row];

let slot = &mut state.slots[row];
let is_new;
(*slot, is_new) = self.find_or_insert(*slot, Entry::hash_to_salt(hash));

let salt = Entry::hash_to_salt(state.group_hashes[row]);
(*slot, is_new) = self.find_or_insert(*slot, salt);

if is_new {
state.empty_vector[new_entry_count] = row;
Expand All @@ -194,7 +200,7 @@ impl HashIndex {
adapter.append_rows(state, new_entry_count);

for row in state.empty_vector[..new_entry_count].iter().copied() {
let entry = self.mut_entry(slots[row]);
let entry = self.mut_entry(state.slots[row]);
entry.set_pointer(state.addresses[row]);
debug_assert_eq!(entry.get_pointer(), state.addresses[row]);
}
Expand All @@ -206,7 +212,7 @@ impl HashIndex {
.iter()
.copied()
{
let entry = self.mut_entry(slots[row]);
let entry = self.mut_entry(state.slots[row]);

debug_assert!(entry.is_occupied());
debug_assert_eq!(entry.get_salt(), (state.group_hashes[row] >> 48) as u16);
Expand All @@ -219,7 +225,7 @@ impl HashIndex {

// 5. Linear probing, just increase iter_times
for row in state.no_match_vector[..no_match_count].iter().copied() {
let slot = &mut slots[row];
let slot = &mut state.slots[row];
*slot += 1;
if *slot >= self.capacity {
*slot = 0;
Expand All @@ -228,7 +234,6 @@ impl HashIndex {
remaining_entries = no_match_count;
}

state.save_temp(slots);
self.count += new_group_count;

new_group_count
Expand All @@ -251,7 +256,13 @@ impl<'a> TableAdapter for AdapterImpl<'a> {
need_compare_count: usize,
no_match_count: usize,
) -> usize {
state.row_match_columns(
// todo: compare hash first if NECESSARY
CompareState {
address: &state.addresses,
compare: &mut state.group_compare_vector,
no_matched: &mut state.no_match_vector,
}
.row_match_entries(
self.group_columns,
&self.payload.row_layout,
(need_compare_count, no_match_count),
Expand Down Expand Up @@ -284,8 +295,10 @@ mod tests {
}

fn init_state(&self) -> ProbeState {
let mut state = ProbeState::default();
state.row_count = self.incoming.len();
let mut state = ProbeState {
row_count: self.incoming.len(),
..Default::default()
};

for (i, (_, hash)) in self.incoming.iter().enumerate() {
state.group_hashes[i] = *hash
Expand Down Expand Up @@ -323,12 +336,12 @@ mod tests {

impl TableAdapter for &mut TestTableAdapter {
fn append_rows(&mut self, state: &mut ProbeState, new_entry_count: usize) {
for row in state.empty_vector[..new_entry_count].iter().copied() {
let (key, hash) = self.incoming[row];
for row in state.empty_vector[..new_entry_count].iter() {
let (key, hash) = self.incoming[*row];
let value = key + 20;

self.payload.push((key, hash, value));
state.addresses[row] = self.get_row_ptr(true, row);
state.addresses[*row] = self.get_row_ptr(true, row.to_usize());
}
}

Expand All @@ -344,9 +357,7 @@ mod tests {
{
let incoming = self.incoming[row];

let row_ptr = state.addresses[row];

let (key, hash, _) = self.get_payload(row_ptr);
let (key, hash, _) = self.get_payload(state.addresses[row]);

const POINTER_MASK: u64 = 0x0000FFFFFFFFFFFF;
assert_eq!(incoming.1 | POINTER_MASK, hash | POINTER_MASK);
Expand Down
5 changes: 3 additions & 2 deletions src/query/expression/src/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ use hash_index::Entry;
pub use partitioned_payload::*;
pub use payload::*;
pub use payload_flush::*;
pub use probe_state::*;
use row_ptr::RowPtr;
pub use probe_state::ProbeState;
use probe_state::*;
use row_ptr::*;

// A batch size to probe, flush, repartition, etc.
pub(crate) const BATCH_SIZE: usize = 2048;
Expand Down
Loading
Loading