diff --git a/encodings/datetime-parts/src/compute/rules.rs b/encodings/datetime-parts/src/compute/rules.rs index c18d959cd7e..c986086825a 100644 --- a/encodings/datetime-parts/src/compute/rules.rs +++ b/encodings/datetime-parts/src/compute/rules.rs @@ -54,7 +54,9 @@ impl ArrayParentReduceRule for DTPFilterPushDownRule { ) -> VortexResult> { debug_assert_eq!(child_idx, 0); - if !child.seconds().is::() || !child.subseconds().is::() { + if *child.seconds().encoding_id() != Constant::ID + || *child.subseconds().encoding_id() != Constant::ID + { return Ok(None); } diff --git a/encodings/zstd/src/zstd_buffers.rs b/encodings/zstd/src/zstd_buffers.rs index cf7acd8f1b0..1c3758665de 100644 --- a/encodings/zstd/src/zstd_buffers.rs +++ b/encodings/zstd/src/zstd_buffers.rs @@ -53,7 +53,6 @@ impl ZstdBuffers { } pub fn compress(array: &ArrayRef, level: i32) -> VortexResult { - let encoding_id = array.encoding_id(); let metadata = array .metadata()? .ok_or_else(|| vortex_err!("Array does not support serialization"))?; @@ -74,6 +73,7 @@ impl ZstdBuffers { compressed_buffers.push(BufferHandle::new_host(ByteBuffer::from(compressed))); } + let encoding_id = array.encoding_id().clone(); let data = ZstdBuffersData { inner_encoding_id: encoding_id, inner_metadata: metadata, diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 0e198a1cdc0..e0306075424 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -21504,7 +21504,7 @@ pub fn vortex_array::Array::data(&self) -> &::Arra pub fn vortex_array::Array::dtype(&self) -> &vortex_array::dtype::DType -pub fn vortex_array::Array::encoding_id(&self) -> vortex_array::ArrayId +pub fn vortex_array::Array::encoding_id(&self) -> &vortex_array::ArrayId pub fn vortex_array::Array::into_data(self) -> ::ArrayData @@ -21680,7 +21680,7 @@ pub fn vortex_array::ArrayRef::downcast(self) -> vortex pub fn vortex_array::ArrayRef::dtype(&self) -> &vortex_array::dtype::DType -pub fn vortex_array::ArrayRef::encoding_id(&self) -> vortex_array::ArrayId +pub fn vortex_array::ArrayRef::encoding_id(&self) -> &vortex_array::ArrayId pub fn vortex_array::ArrayRef::execute_parent(&self, parent: &vortex_array::ArrayRef, child_idx: usize, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> @@ -22128,7 +22128,7 @@ pub fn vortex_array::ArrayView<'a, V>::data(&self) -> &'a ::dtype(&self) -> &vortex_array::dtype::DType -pub fn vortex_array::ArrayView<'a, V>::encoding_id(&self) -> vortex_array::ArrayId +pub fn vortex_array::ArrayView<'a, V>::encoding_id(&self) -> &vortex_array::ArrayId pub fn vortex_array::ArrayView<'a, V>::into_owned(self) -> vortex_array::Array diff --git a/vortex-array/src/aggregate_fn/accumulator.rs b/vortex-array/src/aggregate_fn/accumulator.rs index 00c80b38221..5b53efa029b 100644 --- a/vortex-array/src/aggregate_fn/accumulator.rs +++ b/vortex-array/src/aggregate_fn/accumulator.rs @@ -114,7 +114,7 @@ impl DynAccumulator for Accumulator { } let kernels_r = kernels.read(); - let batch_id = batch.encoding_id(); + let batch_id = batch.encoding_id().clone(); if let Some(result) = kernels_r .get(&(batch_id.clone(), Some(self.aggregate_fn.id()))) .or_else(|| kernels_r.get(&(batch_id, None))) diff --git a/vortex-array/src/aggregate_fn/accumulator_grouped.rs b/vortex-array/src/aggregate_fn/accumulator_grouped.rs index 24f8f0157ec..1fb6e909d4b 100644 --- a/vortex-array/src/aggregate_fn/accumulator_grouped.rs +++ b/vortex-array/src/aggregate_fn/accumulator_grouped.rs @@ -171,9 +171,10 @@ impl GroupedAccumulator { } let kernels_r = kernels.read(); + let elements_id = elements.encoding_id().clone(); if let Some(result) = kernels_r - .get(&(elements.encoding_id(), Some(self.aggregate_fn.id()))) - .or_else(|| kernels_r.get(&(elements.encoding_id(), None))) + .get(&(elements_id.clone(), Some(self.aggregate_fn.id()))) + .or_else(|| kernels_r.get(&(elements_id, None))) .and_then(|kernel| { // SAFETY: we assume that elements execution is safe let groups = unsafe { @@ -263,9 +264,10 @@ impl GroupedAccumulator { } let kernels_r = kernels.read(); + let elements_id = elements.encoding_id().clone(); if let Some(result) = kernels_r - .get(&(elements.encoding_id(), Some(self.aggregate_fn.id()))) - .or_else(|| kernels_r.get(&(elements.encoding_id(), None))) + .get(&(elements_id.clone(), Some(self.aggregate_fn.id()))) + .or_else(|| kernels_r.get(&(elements_id, None))) .and_then(|kernel| { // SAFETY: we assume that elements execution is safe let groups = unsafe { diff --git a/vortex-array/src/aggregate_fn/fns/is_constant/mod.rs b/vortex-array/src/aggregate_fn/fns/is_constant/mod.rs index c6ae93fce5a..98999c372cd 100644 --- a/vortex-array/src/aggregate_fn/fns/is_constant/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/is_constant/mod.rs @@ -107,7 +107,8 @@ pub fn is_constant(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult() || array.is::() { + let id = array.encoding_id(); + if *id == Constant::ID || *id == Null::ID { array .statistics() .set(Stat::IsConstant, Precision::Exact(true.into())); diff --git a/vortex-array/src/aggregate_fn/fns/is_sorted/mod.rs b/vortex-array/src/aggregate_fn/fns/is_sorted/mod.rs index 8c973b940c9..fa36f8a357e 100644 --- a/vortex-array/src/aggregate_fn/fns/is_sorted/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/is_sorted/mod.rs @@ -88,7 +88,8 @@ fn is_sorted_impl(array: &ArrayRef, strict: bool, ctx: &mut ExecutionCtx) -> Vor } // Constant and null arrays are always sorted, but not strict sorted. - if array.is::() || array.is::() { + let id = array.encoding_id(); + if *id == Constant::ID || *id == Null::ID { let result = !strict; cache_is_sorted(array, strict, result); return Ok(result); diff --git a/vortex-array/src/array/erased.rs b/vortex-array/src/array/erased.rs index cb5d2e1bb86..c2429b9b7a8 100644 --- a/vortex-array/src/array/erased.rs +++ b/vortex-array/src/array/erased.rs @@ -147,8 +147,8 @@ impl ArrayRef { /// Returns the encoding ID of the array. #[inline] - pub fn encoding_id(&self) -> ArrayId { - self.0.encoding_id() + pub fn encoding_id(&self) -> &ArrayId { + self.inner().encoding_id() } /// Performs a constant-time slice of the array. diff --git a/vortex-array/src/array/mod.rs b/vortex-array/src/array/mod.rs index f6ed110e172..7d4b79e43e5 100644 --- a/vortex-array/src/array/mod.rs +++ b/vortex-array/src/array/mod.rs @@ -68,7 +68,7 @@ pub(crate) trait DynArray: 'static + private::Sealed + Send + Sync + Debug { fn slots(&self) -> &[Option]; /// Returns the encoding ID of the array. - fn encoding_id(&self) -> ArrayId; + fn encoding_id(&self) -> &ArrayId; /// Fetch the scalar at the given index. /// @@ -210,8 +210,8 @@ impl DynArray for ArrayInner { &self.slots } - fn encoding_id(&self) -> ArrayId { - self.vtable.id() + fn encoding_id(&self) -> &ArrayId { + &self.encoding_id } fn scalar_at(&self, this: &ArrayRef, index: usize) -> VortexResult { @@ -367,7 +367,7 @@ impl DynArray for ArrayInner { .is_some_and(|other_inner| { self.len == other.len() && self.dtype == *other.dtype() - && self.vtable.id() == other.encoding_id() + && &self.vtable.id() == other.encoding_id() && self.slots.len() == other_inner.slots.len() && self .slots @@ -442,7 +442,7 @@ impl DynArray for ArrayInner { let stats = this.statistics().to_owned(); let typed = Array::::try_from_array_ref(this) - .map_err(|_| vortex_err!("Failed to downcast array for execute")) + .map_err(|_| vortex_err!("")) .vortex_expect("Failed to downcast array for execute"); let result = V::execute(typed, ctx)?; diff --git a/vortex-array/src/array/typed.rs b/vortex-array/src/array/typed.rs index 0d8239af589..9ecb9cbae1e 100644 --- a/vortex-array/src/array/typed.rs +++ b/vortex-array/src/array/typed.rs @@ -78,6 +78,7 @@ impl TypedArrayRef for ArrayView<'_, V> {} #[doc(hidden)] pub(crate) struct ArrayInner { pub(crate) vtable: V, + pub(crate) encoding_id: ArrayId, pub(crate) dtype: DType, pub(crate) len: usize, pub(crate) data: V::ArrayData, @@ -115,8 +116,10 @@ impl ArrayInner { slots: Vec>, stats: ArrayStats, ) -> Self { + let encoding_id = vtable.id(); Self { vtable, + encoding_id, dtype, len, data, @@ -143,6 +146,7 @@ impl Clone for ArrayInner { fn clone(&self) -> Self { Self { vtable: self.vtable.clone(), + encoding_id: self.encoding_id.clone(), dtype: self.dtype.clone(), len: self.len, data: self.data.clone(), @@ -155,7 +159,7 @@ impl Clone for ArrayInner { impl Debug for ArrayInner { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("ArrayInner") - .field("encoding", &self.vtable.id()) + .field("encoding", &self.encoding_id) .field("dtype", &self.dtype) .field("len", &self.len) .field("inner", &self.data) @@ -265,8 +269,8 @@ impl Array { } /// Returns the encoding ID. - pub fn encoding_id(&self) -> ArrayId { - self.inner.encoding_id() + pub fn encoding_id(&self) -> &ArrayId { + &self.downcast_inner().encoding_id } /// Returns the statistics. diff --git a/vortex-array/src/array/view.rs b/vortex-array/src/array/view.rs index c751429abf8..cb3021549cb 100644 --- a/vortex-array/src/array/view.rs +++ b/vortex-array/src/array/view.rs @@ -61,7 +61,7 @@ impl<'a, V: VTable> ArrayView<'a, V> { self.array.len() == 0 } - pub fn encoding_id(&self) -> ArrayId { + pub fn encoding_id(&self) -> &ArrayId { self.array.encoding_id() } diff --git a/vortex-array/src/canonical.rs b/vortex-array/src/canonical.rs index cdf569a075c..8b24132d307 100644 --- a/vortex-array/src/canonical.rs +++ b/vortex-array/src/canonical.rs @@ -995,17 +995,17 @@ impl Matcher for AnyCanonical { type Match<'a> = CanonicalView<'a>; fn matches(array: &ArrayRef) -> bool { - array.is::() - || array.is::() - || array.is::() - || array.is::() - || array.is::() - || array.is::() - || array.is::() - || array.is::() - || array.is::() - || array.is::() - || array.is::() + let id = array.encoding_id(); + id == &Null::ID + || id == &Bool::ID + || id == &Primitive::ID + || id == &Decimal::ID + || id == &Struct::ID + || id == &ListView::ID + || id == &FixedSizeList::ID + || id == &VarBinView::ID + || id == &Variant::ID + || id == &Extension::ID } fn try_match<'a>(array: &'a ArrayRef) -> Option> { diff --git a/vortex-array/src/normalize.rs b/vortex-array/src/normalize.rs index 9a796e5485c..ee42f6556e5 100644 --- a/vortex-array/src/normalize.rs +++ b/vortex-array/src/normalize.rs @@ -37,7 +37,7 @@ impl ArrayRef { } fn normalize_with_error(&self, allowed: &[Id]) -> VortexResult<()> { - if !allowed.contains(&self.encoding_id()) { + if !allowed.contains(self.encoding_id()) { vortex_bail!(AssertionFailed: "normalize forbids encoding ({})", self.encoding_id()) } diff --git a/vortex-array/src/serde.rs b/vortex-array/src/serde.rs index 1dda5d01a48..4788cdd1f63 100644 --- a/vortex-array/src/serde.rs +++ b/vortex-array/src/serde.rs @@ -192,7 +192,7 @@ impl<'a> ArrayNodeFlatBuffer<'a> { ) -> VortexResult>> { let encoding_idx = self .ctx - .intern(&self.array.encoding_id()) + .intern(self.array.encoding_id()) // TODO(ngates): write_flatbuffer should return a result if this can fail. .ok_or_else(|| { vortex_err!( @@ -358,7 +358,7 @@ impl SerializedArray { ); assert_eq!( decoded.encoding_id(), - encoding_id, + &encoding_id, "Array decoded from {} has incorrect encoding {}", encoding_id, decoded.encoding_id(), diff --git a/vortex-cuda/src/dynamic_dispatch/plan_builder.rs b/vortex-cuda/src/dynamic_dispatch/plan_builder.rs index 1ea06e522ba..8d34471f122 100644 --- a/vortex-cuda/src/dynamic_dispatch/plan_builder.rs +++ b/vortex-cuda/src/dynamic_dispatch/plan_builder.rs @@ -53,14 +53,14 @@ pub struct MaterializedPlan { /// Checks whether the encoding of an array can be fused into a dynamic-dispatch plan. fn is_dyn_dispatch_compatible(array: &ArrayRef) -> bool { let id = array.encoding_id(); - if id == ALP::ID { + if *id == ALP::ID { let arr = array.as_::(); return arr.patches().is_none() && arr.dtype().as_ptype() == PType::F32; } - if id == BitPacked::ID { + if *id == BitPacked::ID { return array.as_::().patches().is_none(); } - if id == Dict::ID { + if *id == Dict::ID { let arr = array.as_::(); // As of now the dict dyn dispatch kernel requires // codes and values to have the same byte width. @@ -72,7 +72,7 @@ fn is_dyn_dispatch_compatible(array: &ArrayRef) -> bool { _ => false, }; } - if id == RunEnd::ID { + if *id == RunEnd::ID { let arr = array.as_::(); // As of now the run-end dyn dispatch kernel requires // ends and values to have the same byte width. @@ -84,11 +84,11 @@ fn is_dyn_dispatch_compatible(array: &ArrayRef) -> bool { _ => false, }; } - id == FoR::ID - || id == ZigZag::ID - || id == Primitive::ID - || id == Slice::ID - || id == Sequence::ID + *id == FoR::ID + || *id == ZigZag::ID + || *id == Primitive::ID + || *id == Slice::ID + || *id == Sequence::ID } /// An unmaterialized stage: a source op, scalar ops, and optional source buffer reference. @@ -361,23 +361,23 @@ impl FusedPlan { let id = array.encoding_id(); - if id == BitPacked::ID { + if *id == BitPacked::ID { self.walk_bitpacked(array) - } else if id == FoR::ID { + } else if *id == FoR::ID { self.walk_for(array, pending_subtrees) - } else if id == ZigZag::ID { + } else if *id == ZigZag::ID { self.walk_zigzag(array, pending_subtrees) - } else if id == ALP::ID { + } else if *id == ALP::ID { self.walk_alp(array, pending_subtrees) - } else if id == Dict::ID { + } else if *id == Dict::ID { self.walk_dict(array, pending_subtrees) - } else if id == RunEnd::ID { + } else if *id == RunEnd::ID { self.walk_runend(array, pending_subtrees) - } else if id == Primitive::ID { + } else if *id == Primitive::ID { self.walk_primitive(array) - } else if id == Slice::ID { + } else if *id == Slice::ID { self.walk_slice(array, pending_subtrees) - } else if id == Sequence::ID { + } else if *id == Sequence::ID { self.walk_sequence(array) } else { vortex_bail!( diff --git a/vortex-cuda/src/executor.rs b/vortex-cuda/src/executor.rs index 1d6d1db76c3..371e1a6ce6a 100644 --- a/vortex-cuda/src/executor.rs +++ b/vortex-cuda/src/executor.rs @@ -352,7 +352,7 @@ pub trait CudaArrayExt { impl CudaArrayExt for ArrayRef { #[allow(clippy::unwrap_in_result, clippy::unwrap_used)] async fn execute_cuda(self, ctx: &mut CudaExecutionCtx) -> VortexResult { - if self.encoding_id() == Struct::ID { + if *self.encoding_id() == Struct::ID { let len = self.len(); let StructDataParts { fields, diff --git a/vortex-cuda/src/hybrid_dispatch/mod.rs b/vortex-cuda/src/hybrid_dispatch/mod.rs index 61708c8248a..14f191dd7e2 100644 --- a/vortex-cuda/src/hybrid_dispatch/mod.rs +++ b/vortex-cuda/src/hybrid_dispatch/mod.rs @@ -99,7 +99,7 @@ pub async fn try_gpu_dispatch( DispatchPlan::Unfused => { // Unfused kernel dispatch fallback. ctx.cuda_session() - .kernel(&array.encoding_id()) + .kernel(array.encoding_id()) .ok_or_else(|| { vortex_err!("No CUDA kernel for encoding {:?}", array.encoding_id()) })? diff --git a/vortex-cuda/src/layout.rs b/vortex-cuda/src/layout.rs index 0a8efcaac5f..e5c7be1117d 100644 --- a/vortex-cuda/src/layout.rs +++ b/vortex-cuda/src/layout.rs @@ -549,7 +549,7 @@ fn extract_constant_buffers(chunk: &ArrayRef) -> Vec { let mut buffer_idx = 0u32; for array in chunk.depth_first_traversal() { let n = array.nbuffers(); - if array.encoding_id() == Constant::ID { + if *array.encoding_id() == Constant::ID { for buf in array.buffers() { result.push(InlinedBuffer { buffer_index: buffer_idx, diff --git a/vortex-test/compat-gen/src/fixtures/mod.rs b/vortex-test/compat-gen/src/fixtures/mod.rs index edf0dd47e37..6ae8fa73346 100644 --- a/vortex-test/compat-gen/src/fixtures/mod.rs +++ b/vortex-test/compat-gen/src/fixtures/mod.rs @@ -168,8 +168,8 @@ pub fn check_expected_encodings( let mut found: Vec = Vec::new(); for node in array.depth_first_traversal() { let id = node.encoding_id(); - if !found.contains(&id) { - found.push(id); + if !found.contains(id) { + found.push(id.clone()); } }