diff --git a/crates/geo_filters/src/diff_count.rs b/crates/geo_filters/src/diff_count.rs
index 886cfc8..8fc3db2 100644
--- a/crates/geo_filters/src/diff_count.rs
+++ b/crates/geo_filters/src/diff_count.rs
@@ -4,6 +4,7 @@ use std::borrow::Cow;
 use std::cmp::Ordering;
 use std::hash::BuildHasher as _;
 use std::mem::{size_of, size_of_val};
+use std::ops::Deref as _;
 
 use crate::config::{
     count_ones_from_bitchunks, count_ones_from_msb_and_lsb, iter_bit_chunks, iter_ones,
@@ -77,7 +78,7 @@ impl<C: GeoConfig> std::fmt::Debug for GeoDiffCount<'_, C> {
     }
 }
 
-impl<C: GeoConfig> GeoDiffCount<'_, C> {
+impl<'a, C: GeoConfig> GeoDiffCount<'a, C> {
     pub fn new(config: C) -> Self {
         Self {
             config,
@@ -204,6 +205,8 @@ impl<C: GeoConfig> GeoDiffCount<'_, C> {
     /// that makes the cost of the else case negligible.
     fn xor_bit(&mut self, bucket: C::BucketType) {
         if bucket.into_usize() < self.lsb.num_bits() {
+            // The bit being toggled is within our LSB bit vector
+            // so toggle it directly.
            self.lsb.toggle(bucket.into_usize());
         } else {
             let msb = self.msb.to_mut();
@@ -224,15 +227,17 @@ impl<C: GeoConfig> GeoDiffCount<'_, C> {
             Err(idx) => {
                 msb.insert(idx, bucket);
                 if msb.len() > self.config.max_msb_len() {
+                    // We have too many values in the MSB sparse index vector,
+                    // let's move the smallest MSB value into the LSB bit vector
                     let smallest = msb
                         .pop()
                         .expect("we should have at least one element!")
                         .into_usize();
-                    // ensure vector covers smallest
                     let new_smallest = msb
                         .last()
                         .expect("should have at least one element")
                         .into_usize();
+                    // ensure LSB bit vector has the space for `smallest`
                     self.lsb.resize(new_smallest);
                     self.lsb.toggle(smallest);
                 } else if msb.len() == self.config.max_msb_len() {
@@ -282,6 +287,57 @@ impl<C: GeoConfig> GeoDiffCount<'_, C> {
             self.lsb.num_bits(),
         );
     }
+
+    // Serialization:
+    //
+    // Since most of our target platforms are little endian there are more optimised approaches
+    // for little endian platforms, just splatting the bytes into the writer. This is contrary
+    // to the usual "network endian" approach where big endian is the default, but most of our
+    // consumers are little endian so it makes sense for this to be the optimal approach.
+    //
+    // For now we do not support big endian platforms. In the future we might add a big endian
+    // platform specific implementation which is able to read the little endian serialized
+    // representation. For now, these methods are simply absent on big endian platforms, so
+    // attempting to serialize a filter there is a compile error.
+
+    /// Create a new [`GeoDiffCount`] from a slice of bytes
+    #[cfg(target_endian = "little")]
+    pub fn from_bytes(c: C, buf: &'a [u8]) -> Self {
+        if buf.is_empty() {
+            return Self::new(c);
+        }
+        // The number of most significant bits stored in the MSB sparse repr
+        let msb_len = (buf.len() / size_of::<C::BucketType>()).min(c.max_msb_len());
+        let msb = unsafe {
+            std::mem::transmute::<&[u8], &[C::BucketType]>(std::slice::from_raw_parts(
+                buf.as_ptr(),
+                msb_len,
+            ))
+        };
+        // The number of bytes representing the MSB - this is how many bytes we need to
+        // skip over to reach the LSB
+        let msb_bytes_len = msb_len * size_of::<C::BucketType>();
+        Self {
+            config: c,
+            msb: Cow::Borrowed(msb),
+            lsb: BitVec::from_bytes(&buf[msb_bytes_len..]),
+        }
+    }
+
+    #[cfg(target_endian = "little")]
+    pub fn write<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<usize> {
+        if self.msb.is_empty() {
+            return Ok(0);
+        }
+        let msb_buckets = self.msb.deref();
+        let msb_bytes = unsafe {
+            std::slice::from_raw_parts(msb_buckets.as_ptr() as *const u8, size_of_val(msb_buckets))
+        };
+        writer.write_all(msb_bytes)?;
+        let mut bytes_written = msb_bytes.len();
+        bytes_written += self.lsb.write(writer)?;
+        Ok(bytes_written)
+    }
 }
 
 /// Applies a repeated bit mask to the underlying filter.
@@ -360,8 +416,10 @@ impl<C: GeoConfig> Count for GeoDiffCount<'_, C> {
 
 #[cfg(test)]
 mod tests {
+    use std::io::Write;
+
     use itertools::Itertools;
-    use rand::RngCore;
+    use rand::{rngs::StdRng, seq::IteratorRandom, RngCore};
 
     use crate::{
         build_hasher::UnstableDefaultBuildHasher,
@@ -581,4 +639,62 @@ mod tests {
         iter_ones(self.bit_chunks().peekable()).map(C::BucketType::from_usize)
     }
 }
+
+    #[test]
+    fn test_serialization_empty() {
+        let before = GeoDiffCount7::default();
+
+        let mut writer = vec![];
+        before.write(&mut writer).unwrap();
+
+        assert_eq!(writer.len(), 0);
+
+        let after = GeoDiffCount7::from_bytes(before.config.clone(), &writer);
+
+        assert_eq!(before, after);
+    }
+
+    // This helper exists in order to easily test serializing types with different
+    // bucket types in the MSB sparse bit field representation. See tests below.
+    #[cfg(target_endian = "little")]
+    fn serialization_round_trip<C: GeoConfig + Default>(rnd: &mut StdRng) {
+        // Run 100 simulations of random values being put into
+        // a diff counter. "Serializing" to a vector to emulate
+        // writing to a disk, and then deserializing and asserting
+        // the filters are equal.
+        let mut before = GeoDiffCount::<'_, C>::default();
+        // Select a random number of items to insert.
+        let items = (1..1000).choose(rnd).unwrap();
+        for _ in 0..items {
+            before.push_hash(rnd.next_u64());
+        }
+        let mut writer = vec![];
+        // Insert some padding to emulate alignment issues with the slices.
+        // A previous version of this test never panicked even though we were
+        // violating the alignment preconditions for the `from_raw_parts` function.
+        let padding = [0_u8; 8];
+        let pad_amount = (0..8).choose(rnd).unwrap();
+        writer.write_all(&padding[..pad_amount]).unwrap();
+        before.write(&mut writer).unwrap();
+        let after = GeoDiffCount::<'_, C>::from_bytes(before.config.clone(), &writer[pad_amount..]);
+        assert_eq!(before, after);
+    }
+
+    #[test]
+    #[cfg(target_endian = "little")]
+    fn test_serialization_round_trip_7() {
+        prng_test_harness(100, |rnd| {
+            // Uses a u16 for MSB buckets.
+            // NOTE(review): turbofish type was lost in the mangled patch; reconstructed
+            // from the GeoDiffCount7 alias used above — confirm against the crate.
+            serialization_round_trip::<GeoDiffConfig7<UnstableDefaultBuildHasher>>(rnd);
+        });
+    }
+
+    #[test]
+    #[cfg(target_endian = "little")]
+    fn test_serialization_round_trip_13() {
+        prng_test_harness(100, |rnd| {
+            // Uses a u32 for MSB buckets.
+            // NOTE(review): turbofish type reconstructed — confirm against the crate.
+            serialization_round_trip::<GeoDiffConfig13<UnstableDefaultBuildHasher>>(rnd);
+        });
+    }
 }
diff --git a/crates/geo_filters/src/diff_count/bitvec.rs b/crates/geo_filters/src/diff_count/bitvec.rs
index 707b1f0..f77323c 100644
--- a/crates/geo_filters/src/diff_count/bitvec.rs
+++ b/crates/geo_filters/src/diff_count/bitvec.rs
@@ -1,11 +1,11 @@
 use std::borrow::Cow;
 use std::cmp::Ordering;
 use std::mem::{size_of, size_of_val};
-use std::ops::{Index, Range};
+use std::ops::{Deref as _, Index, Range};
 
-use crate::config::BitChunk;
 use crate::config::IsBucketType;
 use crate::config::BITS_PER_BLOCK;
+use crate::config::{BitChunk, BYTES_PER_BLOCK};
 
 /// A bit vector where every bit occupies exactly one bit (in contrast to `Vec<bool>` where each
 /// bit consumes 1 byte). It only implements the minimum number of operations that we need for our
@@ -72,6 +72,10 @@ impl BitVec<'_> {
         self.num_bits
     }
 
+    pub fn is_empty(&self) -> bool {
+        self.num_bits() == 0
+    }
+
     /// Tests the bit specified by the provided zero-based bit position.
     pub fn test_bit(&self, index: usize) -> bool {
         assert!(index < self.num_bits);
@@ -133,6 +137,53 @@ impl BitVec<'_> {
         let Self { num_bits, blocks } = self;
         size_of_val(num_bits) + blocks.len() * size_of::<u64>()
     }
+
+    #[cfg(target_endian = "little")]
+    pub fn from_bytes(mut buf: &[u8]) -> Self {
+        if buf.is_empty() {
+            return Self::default();
+        }
+        // The first byte of the serialized BitVec is used to indicate how many
+        // of the bits in the last u64 block are *unoccupied*.
+        // See [`BitVec::write`] implementation for how this is done.
+        assert!(
+            buf[0] < 64,
+            "Number of unoccupied bits should be <64, got {}",
+            buf[0]
+        );
+        let num_bits = (buf.len() - 1) * 8 - buf[0] as usize;
+        buf = &buf[1..];
+        assert_eq!(
+            buf.len() % BYTES_PER_BLOCK,
+            0,
+            "buffer should be a multiple of 8 bytes, got {}",
+            buf.len()
+        );
+        let blocks = unsafe {
+            std::mem::transmute::<&[u8], &[u64]>(std::slice::from_raw_parts(
+                buf.as_ptr(),
+                buf.len() / BYTES_PER_BLOCK,
+            ))
+        };
+        let blocks = Cow::Borrowed(blocks);
+        Self { num_bits, blocks }
+    }
+
+    #[cfg(target_endian = "little")]
+    pub fn write<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<usize> {
+        if self.is_empty() {
+            return Ok(0);
+        }
+        // First serialize the number of unoccupied bits in the last block as one byte.
+        let unoccupied_bits = 63 - ((self.num_bits - 1) % 64) as u8;
+        writer.write_all(&[unoccupied_bits])?;
+        let blocks = self.blocks.deref();
+        let block_bytes = unsafe {
+            std::slice::from_raw_parts(blocks.as_ptr() as *const u8, blocks.len() * BYTES_PER_BLOCK)
+        };
+        writer.write_all(block_bytes)?;
+        Ok(block_bytes.len() + 1)
+    }
 }
 
 impl Index<usize> for BitVec<'_> {