diff --git a/include/cuco/detail/bloom_filter/arrow_filter_policy.cuh b/include/cuco/detail/bloom_filter/arrow_filter_policy.cuh index bfe97cfaf..2f17fa726 100644 --- a/include/cuco/detail/bloom_filter/arrow_filter_policy.cuh +++ b/include/cuco/detail/bloom_filter/arrow_filter_policy.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024, NVIDIA CORPORATION. + * Copyright (c) 2024-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -83,10 +83,10 @@ namespace cuco::detail { template class XXHash64> class arrow_filter_policy { public: - using hasher = XXHash64; ///< 64-bit XXHash hasher for Arrow bloom filter policy - using word_type = std::uint32_t; ///< uint32_t for Arrow bloom filter policy - using key_type = Key; ///< Hash function input type - using hash_value_type = std::uint64_t; ///< hash function output type + using hasher = XXHash64; ///< 64-bit XXHash hasher for Arrow bloom filter policy + using word_type = std::uint32_t; ///< uint32_t for Arrow bloom filter policy + using key_type = Key; ///< Hash function input type + using hash_result_type = std::uint64_t; ///< hash function output type static constexpr uint32_t bits_set_per_block = 8; ///< hardcoded bits set per Arrow filter block static constexpr uint32_t words_per_block = 8; ///< hardcoded words per Arrow filter block @@ -99,21 +99,6 @@ class arrow_filter_policy { (max_arrow_filter_bytes / bytes_per_filter_block); ///< Max sub-filter blocks allowed in Arrow bloom filter - private: - // Arrow's block-based bloom filter algorithm needs these eight odd SALT values to calculate - // eight indexes of bit to set, one bit in each 32-bit (uint32_t) word. 
-  __device__ static constexpr cuda::std::array SALT()
-  {
-    return {0x47b6137bU,
-            0x44974d91U,
-            0x8824ad5bU,
-            0xa2b7289dU,
-            0x705495c7U,
-            0x2df1424bU,
-            0x9efc4947U,
-            0x5c6bfb31U};
-  }
-
  public:
   /**
    * @brief Constructs the `arrow_filter_policy` object.
@@ -133,7 +118,7 @@ class arrow_filter_policy {
    *
    * @return The hash value of the key
    */
-  __device__ constexpr hash_value_type hash(key_type const& key) const { return hash_(key); }
+  __device__ constexpr hash_result_type hash(key_type const& key) const { return hash_(key); }
 
   /**
    * @brief Determines the filter block a key is added into.
@@ -150,7 +135,7 @@ class arrow_filter_policy {
    * @return The block index for the given key's hash value
    */
   template
-  __device__ constexpr auto block_index(hash_value_type hash, Extent num_blocks) const
+  __device__ constexpr auto block_index(hash_result_type hash, Extent num_blocks) const
   {
     constexpr auto hash_bits = cuda::std::numeric_limits::digits;
     // TODO: assert if num_blocks > max_filter_blocks
@@ -168,12 +153,23 @@ class arrow_filter_policy {
    *
    * @return The bit pattern for the word/segment in the filter block
    */
-  __device__ constexpr word_type word_pattern(hash_value_type hash, std::uint32_t word_index) const
+  __device__ constexpr word_type word_pattern(hash_result_type hash, std::uint32_t word_index) const
   {
-    // SALT array to calculate bit indexes for the current word
-    auto constexpr salt = SALT();
     word_type const key = static_cast(hash);
-    return word_type{1} << ((key * salt[word_index]) >> 27);
+    // Arrow's block-based bloom filter uses eight odd salt values, one per 32-bit word of a
+    // block. The salt for `word_index` is chosen by a balanced tree of conditional expressions,
+    // i.e. a hand-written `switch (word_index) { case 0 ... 7 }`; an index above 7 yields the
+    // last salt. `salt` is initialized at its declaration because a constexpr function may not
+    // define an uninitialized variable before C++20.
+    std::uint32_t const salt =
+      (word_index < 4)
+        ? ((word_index < 2) ? ((word_index == 0) ? 0x47b6137bU : 0x44974d91U)
+                            : ((word_index == 2) ? 0x8824ad5bU : 0xa2b7289dU))
+        : ((word_index < 6) ? ((word_index == 4) ? 0x705495c7U : 0x2df1424bU)
+                            : ((word_index == 6) ? 0x9efc4947U : 0x5c6bfb31U));
+    // Only the top five bits of the 32-bit product survive the shift, giving a bit position
+    // in [0, 31] within the word.
+    return word_type{1} << ((key * salt) >> 27);
   }
 
  private: