Skip to content
58 changes: 32 additions & 26 deletions include/cuco/detail/bloom_filter/arrow_filter_policy.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -83,10 +83,10 @@ namespace cuco::detail {
template <class Key, template <typename> class XXHash64>
class arrow_filter_policy {
public:
using hasher = XXHash64<Key>; ///< 64-bit XXHash hasher for Arrow bloom filter policy
using word_type = std::uint32_t; ///< uint32_t for Arrow bloom filter policy
using key_type = Key; ///< Hash function input type
using hash_value_type = std::uint64_t; ///< hash function output type
using hasher = XXHash64<Key>; ///< 64-bit XXHash hasher for Arrow bloom filter policy
using word_type = std::uint32_t; ///< uint32_t for Arrow bloom filter policy
using key_type = Key; ///< Hash function input type
using hash_result_type = std::uint64_t; ///< hash function output type

static constexpr uint32_t bits_set_per_block = 8; ///< hardcoded bits set per Arrow filter block
static constexpr uint32_t words_per_block = 8; ///< hardcoded words per Arrow filter block
Expand All @@ -99,21 +99,6 @@ class arrow_filter_policy {
(max_arrow_filter_bytes /
bytes_per_filter_block); ///< Max sub-filter blocks allowed in Arrow bloom filter

private:
// Salt multipliers required by Arrow's block-based bloom filter algorithm: eight odd constants,
// one per 32-bit (uint32_t) word of a block, each used to derive the index of the single bit
// that is set in its word.
__device__ static constexpr cuda::std::array<std::uint32_t, 8> SALT()
{
  constexpr cuda::std::array<std::uint32_t, 8> salts{
    0x47b6137bU,  // word 0
    0x44974d91U,  // word 1
    0x8824ad5bU,  // word 2
    0xa2b7289dU,  // word 3
    0x705495c7U,  // word 4
    0x2df1424bU,  // word 5
    0x9efc4947U,  // word 6
    0x5c6bfb31U   // word 7
  };
  return salts;
}

public:
/**
* @brief Constructs the `arrow_filter_policy` object.
Expand All @@ -133,7 +118,7 @@ class arrow_filter_policy {
*
* @return The hash value of the key
*/
__device__ constexpr hash_result_type hash(key_type const& key) const
{
  // Single definition using the `hash_result_type` alias; the work is delegated to the
  // `hash_` callable.
  return hash_(key);
}

/**
* @brief Determines the filter block a key is added into.
Expand All @@ -150,7 +135,7 @@ class arrow_filter_policy {
* @return The block index for the given key's hash value
*/
template <class Extent>
__device__ constexpr auto block_index(hash_value_type hash, Extent num_blocks) const
__device__ constexpr auto block_index(hash_result_type hash, Extent num_blocks) const
{
constexpr auto hash_bits = cuda::std::numeric_limits<word_type>::digits;
// TODO: assert if num_blocks > max_filter_blocks
Expand All @@ -168,12 +153,33 @@ class arrow_filter_policy {
*
* @return The bit pattern for the word/segment in the filter block
*/
__device__ constexpr word_type word_pattern(hash_result_type hash, std::uint32_t word_index) const
{
  // Only the low 32 bits of the 64-bit hash take part in the bit selection.
  word_type const key = static_cast<word_type>(hash);

  // Per-word salt multiplier, chosen with a balanced tree of comparisons on `word_index`
  // (equivalent to `switch (word_index) { case 0..7 }`). The value is produced by a single
  // initializing expression so that `salt` is never left uninitialized, which a constexpr
  // function does not permit before C++20. Any index outside 0..6 selects the last salt.
  std::uint32_t const salt =
    (word_index < 4)
      ? ((word_index < 2) ? ((word_index == 0) ? 0x47b6137bU : 0x44974d91U)    // 0 : 1
                          : ((word_index == 2) ? 0x8824ad5bU : 0xa2b7289dU))   // 2 : 3
      : ((word_index < 6) ? ((word_index == 4) ? 0x705495c7U : 0x2df1424bU)    // 4 : 5
                          : ((word_index == 6) ? 0x9efc4947U : 0x5c6bfb31U));  // 6 : 7

  // The 32-bit product wraps around; its top 5 bits give a bit position in [0, 31].
  return word_type{1} << ((key * salt) >> 27);
}

private:
Expand Down
Loading