Skip to content

Commit cefc6b2

Browse files
authored
introduce simple resource_pool (#123)
- for visited_list, give an implement of visited_list_pool - we will have some other object like aio_context... - currently hgraph use visitlist from hnswlib, now make it global Signed-off-by: LHT129 <tianlan.lht@antgroup.com>
1 parent 206fa94 commit cefc6b2

5 files changed

Lines changed: 368 additions & 2 deletions

File tree

src/CMakeLists.txt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,12 @@ list (FILTER CPP_INDEX_SRCS EXCLUDE REGEX "_test.cpp")
1717
list (FILTER CPP_DATA_CELL_SRCS EXCLUDE REGEX "_test.cpp")
1818
list (FILTER CPP_ALGORITHM_SRCS EXCLUDE REGEX "_test.cpp")
1919

20-
set (VSAG_SRCS ${CPP_SRCS} ${CPP_FACTORY_SRCS} ${CPP_INDEX_SRCS} ${CPP_CONJUGATE_GRAPH_SRCS}
21-
${CPP_HNSWLIB_SRCS} ${CPP_DATA_CELL_SRCS} ${CPP_ALGORITHM_SRCS})
20+
file (GLOB CPP_UTILS_SRCS "*.cpp")
21+
list (FILTER CPP_UTILS_SRCS EXCLUDE REGEX "_test.cpp")
22+
23+
set (VSAG_SRCS ${CPP_SRCS} ${CPP_FACTORY_SRCS} ${CPP_INDEX_SRCS} ${CPP_CONJUGATE_GRAPH_SRCS} ${CPP_HNSWLIB_SRCS}
24+
${CPP_HNSWLIB_SRCS} ${CPP_DATA_CELL_SRCS} ${CPP_ALGORITHM_SRCS} ${CPP_UTILS_SRCS})
25+
2226
add_library (vsag SHARED ${VSAG_SRCS})
2327
add_library (vsag_static STATIC ${VSAG_SRCS})
2428

src/utils/resource_object.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
2+
// Copyright 2024-present the vsag project
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
#pragma once
17+
18+
namespace vsag {
19+
20+
class ResourceObject {
21+
public:
22+
ResourceObject() = default;
23+
24+
virtual ~ResourceObject() = default;
25+
26+
/**
27+
* @brief Reset the resource to its initial state.
28+
*
29+
* This pure virtual function forces derived classes to provide an
30+
* implementation for resetting their specific resources. The reset
31+
* operation should revert the resource to a known, initial state,
32+
* freeing and reallocating memory if necessary, and ensuring that resources
33+
* are ready for reuse.
34+
*/
35+
virtual void
36+
Reset() = 0;
37+
};
38+
39+
} // namespace vsag

src/utils/resource_object_pool.h

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
2+
// Copyright 2024-present the vsag project
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
#pragma once
17+
18+
#include <atomic>
19+
#include <cstdint>
20+
#include <functional>
21+
#include <memory>
22+
#include <mutex>
23+
#include <type_traits>
24+
25+
#include "resource_object.h"
26+
#include "typing.h"
27+
namespace vsag {
28+
29+
template <typename T,
30+
typename = typename std::enable_if<std::is_base_of<ResourceObject, T>::value>::type>
31+
class ResourceObjectPool {
32+
public:
33+
using ConstructFuncType = std::function<std::shared_ptr<T>()>;
34+
35+
public:
36+
template <typename... Args>
37+
explicit ResourceObjectPool(uint64_t init_size, Allocator* allocator, Args... args)
38+
: allocator_(allocator), pool_(allocator), pool_size_(init_size) {
39+
this->constructor_ = [=]() -> std::shared_ptr<T> { return std::make_shared<T>(args...); };
40+
this->resize(pool_size_);
41+
}
42+
43+
void
44+
SetConstructor(ConstructFuncType func) {
45+
this->constructor_ = func;
46+
{
47+
std::lock_guard<std::mutex> lock(mutex_);
48+
while (not pool_.empty()) {
49+
pool_.pop_front();
50+
}
51+
}
52+
this->resize(pool_size_);
53+
}
54+
55+
std::shared_ptr<T>
56+
TakeOne() {
57+
std::unique_lock<std::mutex> lock(mutex_);
58+
if (pool_.empty()) {
59+
lock.unlock();
60+
return this->constructor_();
61+
}
62+
std::shared_ptr<T> obj = pool_.front();
63+
pool_.pop_front();
64+
pool_size_--;
65+
lock.unlock();
66+
obj->Reset();
67+
return obj;
68+
}
69+
70+
void
71+
ReturnOne(std::shared_ptr<T>& obj) {
72+
std::lock_guard<std::mutex> lock(mutex_);
73+
pool_.emplace_back(obj);
74+
pool_size_++;
75+
}
76+
77+
[[nodiscard]] inline uint64_t
78+
GetSize() const {
79+
return this->pool_size_;
80+
}
81+
82+
private:
83+
inline void
84+
resize(uint64_t size) {
85+
std::lock_guard<std::mutex> lock(mutex_);
86+
int count = size - pool_.size();
87+
while (count > 0) {
88+
pool_.emplace_back(this->constructor_());
89+
--count;
90+
}
91+
while (count < 0) {
92+
pool_.pop_front();
93+
++count;
94+
}
95+
}
96+
97+
Deque<std::shared_ptr<T>> pool_;
98+
std::atomic<uint64_t> pool_size_;
99+
100+
ConstructFuncType constructor_{nullptr};
101+
std::mutex mutex_;
102+
Allocator* allocator_{nullptr};
103+
};
104+
105+
} // namespace vsag

src/utils/visited_list.h

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
2+
// Copyright 2024-present the vsag project
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
#pragma once
17+
#include <cstring>
18+
#include <limits>
19+
20+
#include "resource_object.h"
21+
#include "resource_object_pool.h"
22+
#include "typing.h"
23+
#include "vsag/allocator.h"
24+
25+
namespace vsag {
26+
27+
class VisitedList : public ResourceObject {
28+
public:
29+
using VisitedListType = uint16_t;
30+
31+
public:
32+
explicit VisitedList(InnerIdType max_size, Allocator* allocator)
33+
: max_size_(max_size), allocator_(allocator) {
34+
this->list_ = reinterpret_cast<VisitedListType*>(
35+
allocator_->Allocate((uint64_t)max_size * sizeof(VisitedListType)));
36+
memset(list_, 0, max_size_ * sizeof(VisitedListType));
37+
tag_ = 1;
38+
}
39+
40+
~VisitedList() override {
41+
allocator_->Deallocate(list_);
42+
}
43+
44+
inline void
45+
Set(const InnerIdType& id) {
46+
this->list_[id] = this->tag_;
47+
}
48+
49+
inline bool
50+
Get(const InnerIdType& id) {
51+
return this->list_[id] == this->tag_;
52+
}
53+
54+
inline void
55+
Prefetch(const InnerIdType& id) {
56+
return; // TODO(LHT) implement
57+
}
58+
59+
void
60+
Reset() override {
61+
if (tag_ == std::numeric_limits<VisitedListType>::max()) {
62+
memset(list_, 0, max_size_ * sizeof(VisitedListType));
63+
tag_ = 0;
64+
}
65+
++tag_;
66+
}
67+
68+
private:
69+
Allocator* const allocator_{nullptr};
70+
71+
VisitedListType* list_{nullptr};
72+
73+
VisitedListType tag_{1};
74+
75+
const InnerIdType max_size_{0};
76+
};
77+
78+
using VisitedListPool = ResourceObjectPool<VisitedList>;
79+
80+
} // namespace vsag

src/utils/visited_list_test.cpp

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
2+
// Copyright 2024-present the vsag project
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
#include "visited_list.h"
17+
18+
#include <thread>
19+
20+
#include "catch2/catch_test_macros.hpp"
21+
#include "default_allocator.h"
22+
using namespace vsag;
23+
24+
TEST_CASE("test visited_list basic", "[ut][visited_list]") {
25+
auto allocator = std::make_shared<DefaultAllocator>();
26+
auto size = 10000;
27+
auto vl_ptr = std::make_shared<VisitedList>(size, allocator.get());
28+
29+
SECTION("test set & get normal") {
30+
int count = 500;
31+
std::unordered_set<InnerIdType> ids;
32+
for (int i = 0; i < count; ++i) {
33+
auto id = random() % size;
34+
ids.insert(id);
35+
vl_ptr->Set(id);
36+
}
37+
for (auto& id : ids) {
38+
REQUIRE(vl_ptr->Get(id));
39+
}
40+
41+
for (int i = 0; i < size; ++i) {
42+
if (ids.count(i) == 0) {
43+
REQUIRE(vl_ptr->Get(i) == false);
44+
}
45+
}
46+
}
47+
48+
SECTION("test reset") {
49+
int count = 500;
50+
std::unordered_set<InnerIdType> ids;
51+
for (int i = 0; i < count; ++i) {
52+
auto id = random() % size;
53+
ids.insert(id);
54+
vl_ptr->Set(id);
55+
}
56+
vl_ptr->Reset();
57+
for (auto& id : ids) {
58+
REQUIRE(vl_ptr->Get(id) == false);
59+
}
60+
}
61+
}
62+
63+
TEST_CASE("test visited_list_pool basic", "[ut][visited_list_pool]") {
64+
auto allocator = std::make_shared<DefaultAllocator>();
65+
auto init_size = 10;
66+
auto vl_size = 1000;
67+
auto pool =
68+
std::make_shared<VisitedListPool>(init_size, allocator.get(), vl_size, allocator.get());
69+
70+
auto TestVL = [&](std::shared_ptr<VisitedList>& vl_ptr) {
71+
int count = 500;
72+
std::unordered_set<InnerIdType> ids;
73+
for (int i = 0; i < count; ++i) {
74+
auto id = random() % vl_size;
75+
ids.insert(id);
76+
vl_ptr->Set(id);
77+
}
78+
for (auto& id : ids) {
79+
REQUIRE(vl_ptr->Get(id) == true);
80+
}
81+
82+
for (InnerIdType i = 0; i < vl_size; ++i) {
83+
if (ids.count(i) == 0) {
84+
REQUIRE(vl_ptr->Get(i) == false);
85+
}
86+
}
87+
};
88+
89+
SECTION("test basic") {
90+
std::vector<std::shared_ptr<VisitedList>> lists;
91+
REQUIRE(pool->GetSize() == init_size);
92+
lists.reserve(init_size * 2);
93+
for (auto i = 0; i < init_size * 2; ++i) {
94+
lists.emplace_back(pool->TakeOne());
95+
}
96+
REQUIRE(pool->GetSize() == 0);
97+
for (auto& ptr : lists) {
98+
pool->ReturnOne(ptr);
99+
}
100+
REQUIRE(pool->GetSize() == init_size * 2);
101+
102+
auto ptr = pool->TakeOne();
103+
REQUIRE(pool->GetSize() == init_size * 2 - 1);
104+
TestVL(ptr);
105+
}
106+
107+
SECTION("test concurrency") {
108+
auto func = [&]() {
109+
int count = 10;
110+
int max_operators = 20;
111+
std::vector<std::shared_ptr<VisitedList>> results;
112+
for (int i = 0; i < count; ++i) {
113+
auto opt = random() % max_operators + 1;
114+
for (auto j = 0; j < opt; ++j) {
115+
results.emplace_back(pool->TakeOne());
116+
}
117+
for (auto& result : results) {
118+
pool->ReturnOne(result);
119+
}
120+
results.clear();
121+
}
122+
};
123+
std::vector<std::shared_ptr<std::thread>> ths;
124+
auto thread_count = 5;
125+
ths.reserve(thread_count);
126+
for (auto i = 0; i < thread_count; ++i) {
127+
ths.emplace_back((std::make_shared<std::thread>(func)));
128+
}
129+
for (auto& thread : ths) {
130+
thread->join();
131+
}
132+
for (int i = 0; i < 10; ++i) {
133+
auto vl = pool->TakeOne();
134+
TestVL(vl);
135+
pool->ReturnOne(vl);
136+
}
137+
}
138+
}

0 commit comments

Comments
 (0)