Skip to content

Commit 08316ba

Browse files
authored
Merge pull request #72 from github/sc-20250724-geofilter-deterministic-tests
Deterministic RNG test harness and address bug in `BitVec`
2 parents 6a72711 + a469fde commit 08316ba

File tree

7 files changed

+180
-133
lines changed

7 files changed

+180
-133
lines changed

crates/geo_filters/src/config.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -353,13 +353,15 @@ pub(crate) fn take_ref<I: Iterator>(iter: &mut I, n: usize) -> impl Iterator<Ite
353353

354354
#[cfg(test)]
355355
pub(crate) mod tests {
356-
use rand::{RngCore, SeedableRng};
356+
use rand::{rngs::StdRng, RngCore};
357357

358358
use crate::{Count, Method};
359359

360360
/// Runs estimation trials and returns the average precision and variance.
361-
pub(crate) fn test_estimate<M: Method, C: Count<M>>(f: impl Fn() -> C) -> (f32, f32) {
362-
let mut rnd = rand::rngs::StdRng::from_os_rng();
361+
pub(crate) fn test_estimate<M: Method, C: Count<M>>(
362+
rnd: &mut StdRng,
363+
f: impl Fn() -> C,
364+
) -> (f32, f32) {
363365
let cnt = 10000usize;
364366
let mut avg_precision = 0.0;
365367
let mut avg_var = 0.0;

crates/geo_filters/src/config/lookup.rs

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,29 +45,36 @@ impl HashToBucketLookup {
4545

4646
#[cfg(test)]
4747
mod tests {
48-
use rand::{RngCore, SeedableRng};
48+
use rand::{rngs::StdRng, RngCore};
4949

50-
use crate::config::{hash_to_bucket, phi_f64};
50+
use crate::{
51+
config::{hash_to_bucket, phi_f64},
52+
test_rng::prng_test_harness,
53+
};
5154

5255
use super::HashToBucketLookup;
5356

5457
#[test]
5558
fn test_lookup_7() {
56-
let var = lookup_random_hashes_variance::<7>(1 << 16);
57-
assert!(var < 1e-4, "variance {var} too large");
59+
prng_test_harness(1, |rnd| {
60+
let var = lookup_random_hashes_variance::<7>(rnd, 1 << 16);
61+
assert!(var < 1e-4, "variance {var} too large");
62+
});
5863
}
5964

6065
#[test]
6166
fn test_lookup_13() {
62-
let var = lookup_random_hashes_variance::<13>(1 << 16);
63-
assert!(var < 1e-4, "variance {var} too large");
67+
prng_test_harness(1, |rnd| {
68+
let var = lookup_random_hashes_variance::<13>(rnd, 1 << 16);
69+
assert!(var < 1e-4, "variance {var} too large");
70+
});
6471
}
6572

66-
fn lookup_random_hashes_variance<const B: usize>(n: u64) -> f64 {
73+
fn lookup_random_hashes_variance<const B: usize>(rnd: &mut StdRng, n: u64) -> f64 {
6774
let phi = phi_f64(B);
6875
let buckets = HashToBucketLookup::new(B);
76+
6977
let mut var = 0.0;
70-
let mut rnd = rand::rngs::StdRng::from_os_rng();
7178
for _ in 0..n {
7279
let hash = rnd.next_u64();
7380
let estimate = buckets.lookup(hash) as f64;

crates/geo_filters/src/diff_count.rs

Lines changed: 85 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -95,21 +95,17 @@ impl<C: GeoConfig<Diff>> GeoDiffCount<'_, C> {
9595
/// having to construct another iterator with the remaining `BitChunk`s.
9696
fn from_bit_chunks<I: Iterator<Item = BitChunk>>(config: C, chunks: I) -> Self {
9797
let mut ones = iter_ones::<C::BucketType, _>(chunks.peekable());
98-
9998
let mut msb = Vec::default();
10099
take_ref(&mut ones, config.max_msb_len() - 1).for_each(|bucket| {
101100
msb.push(bucket);
102101
});
103102
let smallest_msb = ones
104103
.next()
105-
.map(|bucket| {
106-
msb.push(bucket);
107-
bucket
104+
.inspect(|bucket| {
105+
msb.push(*bucket);
108106
})
109107
.unwrap_or_default();
110-
111108
let lsb = BitVec::from_bit_chunks(ones.into_bitchunks(), smallest_msb.into_usize());
112-
113109
let result = Self {
114110
config,
115111
msb: Cow::from(msb),
@@ -214,17 +210,16 @@ impl<C: GeoConfig<Diff>> GeoDiffCount<'_, C> {
214210
match msb.binary_search_by(|k| bucket.cmp(k)) {
215211
Ok(idx) => {
216212
msb.remove(idx);
217-
let (first, second) = {
213+
let first = {
218214
let mut lsb = iter_ones(self.lsb.bit_chunks().peekable());
219-
(lsb.next(), lsb.next())
215+
lsb.next()
220216
};
221-
let new_smallest = if let Some(smallest) = first {
217+
if let Some(smallest) = first {
222218
msb.push(C::BucketType::from_usize(smallest));
223-
second.map(|_| smallest).unwrap_or(0)
219+
self.lsb.resize(smallest);
224220
} else {
225-
0
221+
self.lsb.resize(0);
226222
};
227-
self.lsb.resize(new_smallest);
228223
}
229224
Err(idx) => {
230225
msb.insert(idx, bucket);
@@ -240,6 +235,12 @@ impl<C: GeoConfig<Diff>> GeoDiffCount<'_, C> {
240235
.into_usize();
241236
self.lsb.resize(new_smallest);
242237
self.lsb.toggle(smallest);
238+
} else if msb.len() == self.config.max_msb_len() {
239+
let smallest = msb
240+
.last()
241+
.expect("should have at least one element")
242+
.into_usize();
243+
self.lsb.resize(smallest);
243244
}
244245
}
245246
}
@@ -360,11 +361,12 @@ impl<C: GeoConfig<Diff>> Count<Diff> for GeoDiffCount<'_, C> {
360361
#[cfg(test)]
361362
mod tests {
362363
use itertools::Itertools;
363-
use rand::{RngCore, SeedableRng};
364+
use rand::RngCore;
364365

365366
use crate::{
366367
build_hasher::UnstableDefaultBuildHasher,
367368
config::{iter_ones, tests::test_estimate, FixedConfig},
369+
test_rng::prng_test_harness,
368370
};
369371

370372
use super::*;
@@ -435,57 +437,62 @@ mod tests {
435437

436438
#[test]
437439
fn test_estimate_fast() {
438-
let (avg_precision, avg_var) = test_estimate(GeoDiffCount7::default);
439-
println!(
440-
"avg precision: {} with standard deviation: {}",
441-
avg_precision,
442-
avg_var.sqrt(),
443-
);
444-
// Make sure that the estimate converges to the correct value.
445-
assert!(avg_precision.abs() < 0.04);
446-
// We should theoretically achieve a standard deviation of about 0.12
447-
assert!(avg_var.sqrt() < 0.14);
440+
prng_test_harness(1, |rnd| {
441+
let (avg_precision, avg_var) = test_estimate(rnd, GeoDiffCount7::default);
442+
println!(
443+
"avg precision: {} with standard deviation: {}",
444+
avg_precision,
445+
avg_var.sqrt(),
446+
);
447+
// Make sure that the estimate converges to the correct value.
448+
assert!(avg_precision.abs() < 0.04);
449+
// We should theoretically achieve a standard deviation of about 0.12
450+
assert!(avg_var.sqrt() < 0.14);
451+
})
448452
}
449453

450454
#[test]
451455
fn test_estimate_fast_low_precision() {
452-
let (avg_precision, avg_var) = test_estimate(GeoDiffCount7_50::default);
453-
println!(
454-
"avg precision: {} with standard deviation: {}",
455-
avg_precision,
456-
avg_var.sqrt(),
457-
);
458-
// Make sure that the estimate converges to the correct value.
459-
assert!(avg_precision.abs() < 0.15);
460-
// We should theoretically achieve a standard deviation of about 0.25
461-
assert!(avg_var.sqrt() < 0.4);
456+
prng_test_harness(1, |rnd| {
457+
let (avg_precision, avg_var) = test_estimate(rnd, GeoDiffCount7_50::default);
458+
println!(
459+
"avg precision: {} with standard deviation: {}",
460+
avg_precision,
461+
avg_var.sqrt(),
462+
);
463+
// Make sure that the estimate converges to the correct value.
464+
assert!(avg_precision.abs() < 0.15);
465+
// We should theoretically achieve a standard deviation of about 0.25
466+
assert!(avg_var.sqrt() < 0.4);
467+
});
462468
}
463469

464470
#[test]
465471
fn test_estimate_diff_size_fast() {
466-
let mut rnd = rand::rngs::StdRng::from_os_rng();
467-
let mut a_p = GeoDiffCount7_50::default();
468-
let mut a_hp = GeoDiffCount7::default();
469-
let mut b_p = GeoDiffCount7_50::default();
470-
let mut b_hp = GeoDiffCount7::default();
471-
for _ in 0..10000 {
472-
let hash = rnd.next_u64();
473-
a_p.push_hash(hash);
474-
a_hp.push_hash(hash);
475-
}
476-
for _ in 0..1000 {
477-
let hash = rnd.next_u64();
478-
b_p.push_hash(hash);
479-
b_hp.push_hash(hash);
480-
}
481-
let c_p = xor(&a_p, &b_p);
482-
let c_hp = xor(&a_hp, &b_hp);
472+
prng_test_harness(1, |rnd| {
473+
let mut a_p = GeoDiffCount7_50::default();
474+
let mut a_hp = GeoDiffCount7::default();
475+
let mut b_p = GeoDiffCount7_50::default();
476+
let mut b_hp = GeoDiffCount7::default();
477+
for _ in 0..10000 {
478+
let hash = rnd.next_u64();
479+
a_p.push_hash(hash);
480+
a_hp.push_hash(hash);
481+
}
482+
for _ in 0..1000 {
483+
let hash = rnd.next_u64();
484+
b_p.push_hash(hash);
485+
b_hp.push_hash(hash);
486+
}
487+
let c_p = xor(&a_p, &b_p);
488+
let c_hp = xor(&a_hp, &b_hp);
483489

484-
assert_eq!(c_p.size(), a_p.size_with_sketch(&b_p));
485-
assert_eq!(c_p.size(), b_p.size_with_sketch(&a_p));
490+
assert_eq!(c_p.size(), a_p.size_with_sketch(&b_p));
491+
assert_eq!(c_p.size(), b_p.size_with_sketch(&a_p));
486492

487-
assert_eq!(c_hp.size(), a_hp.size_with_sketch(&b_hp));
488-
assert_eq!(c_hp.size(), b_hp.size_with_sketch(&a_hp));
493+
assert_eq!(c_hp.size(), a_hp.size_with_sketch(&b_hp));
494+
assert_eq!(c_hp.size(), b_hp.size_with_sketch(&a_hp));
495+
});
489496
}
490497

491498
#[test]
@@ -517,45 +524,39 @@ mod tests {
517524

518525
#[test]
519526
fn test_xor_plus_mask() {
520-
let mut rnd = rand::rngs::StdRng::from_os_rng();
521-
let mask_size = 12;
522-
let mask = 0b100001100000;
523-
let mut a = GeoDiffCount7::default();
524-
for _ in 0..10000 {
525-
a.xor_bit(a.config.hash_to_bucket(rnd.next_u64()));
526-
}
527-
let mut expected = GeoDiffCount7::default();
528-
let mut b = a.clone();
529-
for _ in 0..1000 {
530-
let hash = rnd.next_u64();
531-
b.xor_bit(b.config.hash_to_bucket(hash));
532-
expected.xor_bit(expected.config.hash_to_bucket(hash));
533-
assert_eq!(expected, xor(&a, &b));
534-
535-
let masked_a = masked(&a, mask, mask_size);
536-
let masked_b = masked(&b, mask, mask_size);
537-
let masked_expected = masked(&expected, mask, mask_size);
538-
// FIXME: test failed once with:
539-
// left: ~12.37563 (msb: [390, 334, 263, 242, 222, 215, 164, 148, 100, 97, 66, 36], |lsb|: 36)
540-
// right: ~12.37563 (msb: [390, 334, 263, 242, 222, 215, 164, 148, 100, 97, 66, 36], |lsb|: 0)
541-
assert_eq!(masked_expected, xor(&masked_a, &masked_b));
542-
}
527+
prng_test_harness(10, |rnd| {
528+
let mask_size = 12;
529+
let mask = 0b100001100000;
530+
let mut a = GeoDiffCount7::default();
531+
for _ in 0..10000 {
532+
a.xor_bit(a.config.hash_to_bucket(rnd.next_u64()));
533+
}
534+
let mut expected = GeoDiffCount7::default();
535+
let mut b = a.clone();
536+
for _ in 0..1000 {
537+
let hash = rnd.next_u64();
538+
b.xor_bit(b.config.hash_to_bucket(hash));
539+
expected.xor_bit(expected.config.hash_to_bucket(hash));
540+
assert_eq!(expected, xor(&a, &b));
541+
let masked_a = masked(&a, mask, mask_size);
542+
let masked_b = masked(&b, mask, mask_size);
543+
let masked_expected = masked(&expected, mask, mask_size);
544+
assert_eq!(masked_expected, xor(&masked_a, &masked_b));
545+
}
546+
});
543547
}
544548

545549
#[test]
546550
fn test_bit_chunks() {
547-
let mut rnd = rand::rngs::StdRng::from_os_rng();
548-
for _ in 0..100 {
551+
prng_test_harness(100, |rnd| {
549552
let mut expected = GeoDiffCount7::default();
550553
for _ in 0..1000 {
551554
expected.push_hash(rnd.next_u64());
552555
}
553-
let actual = GeoDiffCount::from_bit_chunks(
554-
expected.config.clone(),
555-
expected.bit_chunks().peekable(),
556-
);
556+
let actual =
557+
GeoDiffCount::from_bit_chunks(expected.config.clone(), expected.bit_chunks());
557558
assert_eq!(expected, actual);
558-
}
559+
});
559560
}
560561

561562
#[test]

crates/geo_filters/src/diff_count/bitvec.rs

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
use std::borrow::Cow;
22
use std::cmp::Ordering;
3-
use std::iter::Peekable;
43
use std::mem::{size_of, size_of_val};
54
use std::ops::{Index, Range};
65

@@ -12,7 +11,7 @@ use crate::config::BITS_PER_BLOCK;
1211
/// bit consumes 1 byte). It only implements the minimum number of operations that we need for our
1312
/// GeoDiffCount implementation. In particular it supports xor-ing of two bit vectors and
1413
/// iterating through one bits.
15-
#[derive(Clone, Default, Debug, PartialEq, Eq)]
14+
#[derive(Clone, Default, Debug, Eq, PartialEq)]
1615
pub(crate) struct BitVec<'a> {
1716
num_bits: usize,
1817
blocks: Cow<'a, [u64]>,
@@ -37,15 +36,7 @@ impl BitVec<'_> {
3736
/// Takes an iterator of `BitChunk` items as input and returns the corresponding `BitVec`.
3837
/// The order of `BitChunk`s doesn't matter for this function and `BitChunk` may be hitting
3938
/// the same block. In this case, the function will simply xor them together.
40-
///
41-
/// NOTE: If the bitchunks iterator is empty, the result is NOT sized to `num_bits` but will
42-
/// be EMPTY instead.
43-
pub fn from_bit_chunks<I: Iterator<Item = BitChunk>>(
44-
mut chunks: Peekable<I>,
45-
num_bits: usize,
46-
) -> Self {
47-
// if there are no chunks, we keep the size zero
48-
let num_bits = chunks.peek().map(|_| num_bits).unwrap_or_default();
39+
pub fn from_bit_chunks<I: Iterator<Item = BitChunk>>(chunks: I, num_bits: usize) -> Self {
4940
let mut result = Self::default();
5041
result.resize(num_bits);
5142
let blocks = result.blocks.to_mut();

0 commit comments

Comments
 (0)