Skip to content

Commit ad1e547

Browse files
authored
Merge pull request #71 from github/sc-20250723-serialize
Basic serialization and deserialization of diff filters
2 parents 08316ba + 14c5b08 commit ad1e547

File tree

2 files changed

+172
-5
lines changed

2 files changed

+172
-5
lines changed

crates/geo_filters/src/diff_count.rs

Lines changed: 119 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use std::borrow::Cow;
44
use std::cmp::Ordering;
55
use std::hash::BuildHasher as _;
66
use std::mem::{size_of, size_of_val};
7+
use std::ops::Deref as _;
78

89
use crate::config::{
910
count_ones_from_bitchunks, count_ones_from_msb_and_lsb, iter_bit_chunks, iter_ones,
@@ -77,7 +78,7 @@ impl<C: GeoConfig<Diff>> std::fmt::Debug for GeoDiffCount<'_, C> {
7778
}
7879
}
7980

80-
impl<C: GeoConfig<Diff>> GeoDiffCount<'_, C> {
81+
impl<'a, C: GeoConfig<Diff>> GeoDiffCount<'a, C> {
8182
pub fn new(config: C) -> Self {
8283
Self {
8384
config,
@@ -204,6 +205,8 @@ impl<C: GeoConfig<Diff>> GeoDiffCount<'_, C> {
204205
/// that makes the cost of the else case negligible.
205206
fn xor_bit(&mut self, bucket: C::BucketType) {
206207
if bucket.into_usize() < self.lsb.num_bits() {
208+
// The bit being toggled is within our LSB bit vector
209+
// so toggle it directly.
207210
self.lsb.toggle(bucket.into_usize());
208211
} else {
209212
let msb = self.msb.to_mut();
@@ -224,15 +227,17 @@ impl<C: GeoConfig<Diff>> GeoDiffCount<'_, C> {
224227
Err(idx) => {
225228
msb.insert(idx, bucket);
226229
if msb.len() > self.config.max_msb_len() {
230+
// We have too many values in the MSB sparse index vector,
231+
// let's move the smallest MSB value into the LSB bit vector
227232
let smallest = msb
228233
.pop()
229234
.expect("we should have at least one element!")
230235
.into_usize();
231-
// ensure vector covers smallest
232236
let new_smallest = msb
233237
.last()
234238
.expect("should have at least one element")
235239
.into_usize();
240+
// ensure LSB bit vector has the space for `smallest`
236241
self.lsb.resize(new_smallest);
237242
self.lsb.toggle(smallest);
238243
} else if msb.len() == self.config.max_msb_len() {
@@ -282,6 +287,57 @@ impl<C: GeoConfig<Diff>> GeoDiffCount<'_, C> {
282287
self.lsb.num_bits(),
283288
);
284289
}
290+
291+
// Serialization:
292+
//
293+
// Since most of our target platforms are little endian there are more optimised approaches
294+
// for little endian platforms, just splatting the bytes into the writer. This is contrary
295+
// to the usual "network endian" approach where big endian is the default, but most of our
296+
// consumers are little endian so it makes sense for this to be the optimal approach.
297+
//
298+
// For now we do not support big endian platforms. In the future we might add a big endian
299+
// platform specific implementation which is able to read the little endian serialized
300+
// representation. For now, if you attempt to serialize a filter on a big endian platform
301+
// you get a panic.
302+
303+
/// Create a new [`GeoDiffCount`] from a slice of bytes
304+
#[cfg(target_endian = "little")]
305+
pub fn from_bytes(c: C, buf: &'a [u8]) -> Self {
306+
if buf.is_empty() {
307+
return Self::new(c);
308+
}
309+
// The number of most significant bits stored in the MSB sparse repr
310+
let msb_len = (buf.len() / size_of::<C::BucketType>()).min(c.max_msb_len());
311+
let msb = unsafe {
312+
std::mem::transmute::<&[u8], &[C::BucketType]>(std::slice::from_raw_parts(
313+
buf.as_ptr(),
314+
msb_len,
315+
))
316+
};
317+
// The number of bytes representing the MSB - this is how many bytes we need to
318+
// skip over to reach the LSB
319+
let msb_bytes_len = msb_len * size_of::<C::BucketType>();
320+
Self {
321+
config: c,
322+
msb: Cow::Borrowed(msb),
323+
lsb: BitVec::from_bytes(&buf[msb_bytes_len..]),
324+
}
325+
}
326+
327+
#[cfg(target_endian = "little")]
328+
pub fn write<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<usize> {
329+
if self.msb.is_empty() {
330+
return Ok(0);
331+
}
332+
let msb_buckets = self.msb.deref();
333+
let msb_bytes = unsafe {
334+
std::slice::from_raw_parts(msb_buckets.as_ptr() as *const u8, size_of_val(msb_buckets))
335+
};
336+
writer.write_all(msb_bytes)?;
337+
let mut bytes_written = msb_bytes.len();
338+
bytes_written += self.lsb.write(writer)?;
339+
Ok(bytes_written)
340+
}
285341
}
286342

287343
/// Applies a repeated bit mask to the underlying filter.
@@ -360,8 +416,10 @@ impl<C: GeoConfig<Diff>> Count<Diff> for GeoDiffCount<'_, C> {
360416

361417
#[cfg(test)]
362418
mod tests {
419+
use std::io::Write;
420+
363421
use itertools::Itertools;
364-
use rand::RngCore;
422+
use rand::{rngs::StdRng, seq::IteratorRandom, RngCore};
365423

366424
use crate::{
367425
build_hasher::UnstableDefaultBuildHasher,
@@ -581,4 +639,62 @@ mod tests {
581639
iter_ones(self.bit_chunks().peekable()).map(C::BucketType::from_usize)
582640
}
583641
}
642+
643+
#[test]
644+
fn test_serialization_empty() {
645+
let before = GeoDiffCount7::default();
646+
647+
let mut writer = vec![];
648+
before.write(&mut writer).unwrap();
649+
650+
assert_eq!(writer.len(), 0);
651+
652+
let after = GeoDiffCount7::from_bytes(before.config.clone(), &writer);
653+
654+
assert_eq!(before, after);
655+
}
656+
657+
// This helper exists in order to easily test serializing types with different
658+
// bucket types in the MSB sparse bit field representation. See tests below.
659+
#[cfg(target_endian = "little")]
660+
fn serialization_round_trip<C: GeoConfig<Diff> + Default>(rnd: &mut StdRng) {
661+
// Run 100 simulations of random values being put into
662+
// a diff counter. "Serializing" to a vector to emulate
663+
// writing to a disk, and then deserializing and asserting
664+
// the filters are equal.
665+
let mut before = GeoDiffCount::<'_, C>::default();
666+
// Select a random number of items to insert.
667+
let items = (1..1000).choose(rnd).unwrap();
668+
for _ in 0..items {
669+
before.push_hash(rnd.next_u64());
670+
}
671+
let mut writer = vec![];
672+
// Insert some padding to emulate alignment issues with the slices.
673+
// A previous version of this test never panicked even though we were
674+
// violating the alignment preconditions for the `from_raw_parts` function.
675+
let padding = [0_u8; 8];
676+
let pad_amount = (0..8).choose(rnd).unwrap();
677+
writer.write_all(&padding[..pad_amount]).unwrap();
678+
before.write(&mut writer).unwrap();
679+
let after = GeoDiffCount::<'_, C>::from_bytes(before.config.clone(), &writer[pad_amount..]);
680+
assert_eq!(before, after);
681+
}
682+
683+
#[test]
684+
#[cfg(target_endian = "little")]
685+
fn test_serialization_round_trip_7() {
686+
prng_test_harness(100, |rnd| {
687+
// Uses a u16 for MSB buckets.
688+
serialization_round_trip::<GeoDiffConfig7>(rnd);
689+
});
690+
}
691+
692+
#[test]
693+
#[cfg(target_endian = "little")]
694+
fn test_serialization_round_trip_13() {
695+
prng_test_harness(100, |rnd| {
696+
// Uses a u32 for MSB buckets.
697+
serialization_round_trip::<GeoDiffConfig13>(rnd);
698+
});
699+
}
584700
}

crates/geo_filters/src/diff_count/bitvec.rs

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
use std::borrow::Cow;
22
use std::cmp::Ordering;
33
use std::mem::{size_of, size_of_val};
4-
use std::ops::{Index, Range};
4+
use std::ops::{Deref as _, Index, Range};
55

6-
use crate::config::BitChunk;
76
use crate::config::IsBucketType;
87
use crate::config::BITS_PER_BLOCK;
8+
use crate::config::{BitChunk, BYTES_PER_BLOCK};
99

1010
/// A bit vector where every bit occupies exactly one bit (in contrast to `Vec<bool>` where each
1111
/// bit consumes 1 byte). It only implements the minimum number of operations that we need for our
@@ -72,6 +72,10 @@ impl BitVec<'_> {
7272
self.num_bits
7373
}
7474

75+
pub fn is_empty(&self) -> bool {
76+
self.num_bits() == 0
77+
}
78+
7579
/// Tests the bit specified by the provided zero-based bit position.
7680
pub fn test_bit(&self, index: usize) -> bool {
7781
assert!(index < self.num_bits);
@@ -133,6 +137,53 @@ impl BitVec<'_> {
133137
let Self { num_bits, blocks } = self;
134138
size_of_val(num_bits) + blocks.len() * size_of::<u64>()
135139
}
140+
141+
#[cfg(target_endian = "little")]
142+
pub fn from_bytes(mut buf: &[u8]) -> Self {
143+
if buf.is_empty() {
144+
return Self::default();
145+
}
146+
// The first byte of the serialized BitVec is used to indicate how many
147+
// of the bits in the left-most u64 block are *unoccupied*.
148+
// See [`BitVec::write`] implementation for how this is done.
149+
assert!(
150+
buf[0] < 64,
151+
"Number of unoccupied bits should be <64, got {}",
152+
buf[0]
153+
);
154+
let num_bits = (buf.len() - 1) * 8 - buf[0] as usize;
155+
buf = &buf[1..];
156+
assert_eq!(
157+
buf.len() % BYTES_PER_BLOCK,
158+
0,
159+
"buffer should be a multiple of 8 bytes, got {}",
160+
buf.len()
161+
);
162+
let blocks = unsafe {
163+
std::mem::transmute::<&[u8], &[u64]>(std::slice::from_raw_parts(
164+
buf.as_ptr(),
165+
buf.len() / BYTES_PER_BLOCK,
166+
))
167+
};
168+
let blocks = Cow::Borrowed(blocks);
169+
Self { num_bits, blocks }
170+
}
171+
172+
#[cfg(target_endian = "little")]
173+
pub fn write<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<usize> {
174+
if self.is_empty() {
175+
return Ok(0);
176+
}
177+
// First serialize the number of unoccupied bits in the last block as one byte.
178+
let unoccupied_bits = 63 - ((self.num_bits - 1) % 64) as u8;
179+
writer.write_all(&[unoccupied_bits])?;
180+
let blocks = self.blocks.deref();
181+
let block_bytes = unsafe {
182+
std::slice::from_raw_parts(blocks.as_ptr() as *const u8, blocks.len() * BYTES_PER_BLOCK)
183+
};
184+
writer.write_all(block_bytes)?;
185+
Ok(block_bytes.len() + 1)
186+
}
136187
}
137188

138189
impl Index<usize> for BitVec<'_> {

0 commit comments

Comments
 (0)