Skip to content

Commit e71d594

Browse files
nishadsingh1dcrankshaw
authored andcommitted
Allow model versions to be strings (#197)
* Versions can be strings * Reformatted * Fixed some tests * Minor changes in management library, now check strings to be grouped for invalid characters * VersionedModelId as a class * Temporary commit to show hash problem * Partial fix * Functional * Formatted * Fixup * Fix failing tests * Addressed comments
1 parent 5a7d172 commit e71d594

25 files changed

Lines changed: 497 additions & 296 deletions

clipper_admin/clipper_manager.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ def deploy_model(self,
431431
----------
432432
name : str
433433
The name to assign this model.
434-
version : int
434+
version : Any object with a string representation (with __str__ implementation)
435435
The version to assign this model.
436436
model_data : str or BaseEstimator
437437
The trained model to add to Clipper. This can either be a
@@ -470,6 +470,7 @@ def deploy_model(self,
470470
warn("%s is invalid model format" % str(type(model_data)))
471471
return False
472472

473+
version = str(version)
473474
vol = "{model_repo}/{name}/{version}".format(
474475
model_repo=MODEL_REPO, name=name, version=version)
475476
# publish model to Clipper and verify success before copying model
@@ -509,13 +510,14 @@ def register_external_model(self,
509510
----------
510511
name : str
511512
The name to assign this model.
512-
version : int
513+
version : Any object with a string representation (with __str__ implementation)
513514
The version to assign this model.
514515
input_type : str
515516
One of "integers", "floats", "doubles", "bytes", or "strings".
516517
labels : list of str, optional
517518
A list of strings annotating the model.
518519
"""
520+
version = str(version)
519521
return self._publish_new_model(name, version, labels, input_type,
520522
EXTERNALLY_MANAGED_MODEL,
521523
EXTERNALLY_MANAGED_MODEL)
@@ -586,7 +588,7 @@ def deploy_pyspark_model(self,
586588
----------
587589
name : str
588590
The name to assign this model.
589-
version : int
591+
version : Any object with a string representation (with __str__ implementation)
590592
The version to assign this model.
591593
predict_function : function
592594
A function that takes three arguments, a SparkContext, the ``model`` parameter and
@@ -679,7 +681,7 @@ def deploy_predict_function(self,
679681
----------
680682
name : str
681683
The name to assign this model.
682-
version : int
684+
version : Any object with a string representation (with __str__ implementation)
683685
The version to assign this model.
684686
predict_function : function
685687
The prediction function. Any state associated with the function should be
@@ -766,7 +768,7 @@ def get_model_info(self, model_name, model_version):
766768
----------
767769
model_name : str
768770
The name of the model to look up
769-
model_version : int
771+
model_version : Any object with a string representation (with __str__ implementation)
770772
The version of the model to look up
771773
772774
Returns
@@ -776,6 +778,7 @@ def get_model_info(self, model_name, model_version):
776778
If no model with name `model_name@model_version` is
777779
registered with Clipper, None is returned.
778780
"""
781+
model_version = str(model_version)
779782
url = "http://%s:1338/admin/get_model" % self.host
780783
req_json = json.dumps({
781784
"model_name": model_name,
@@ -826,7 +829,7 @@ def get_container_info(self, model_name, model_version, replica_id):
826829
----------
827830
model_name : str
828831
The name of the container to look up
829-
model_version : int
832+
model_version : Any object with a string representation (with __str__ implementation)
830833
The version of the container to look up
831834
replica_id : int
832835
The container replica to look up
@@ -837,6 +840,7 @@ def get_container_info(self, model_name, model_version, replica_id):
837840
A dictionary with the specified container's info.
838841
If no corresponding container is registered with Clipper, None is returned.
839842
"""
843+
model_version = str(model_version)
840844
url = "http://%s:1338/admin/get_container" % self.host
841845
req_json = json.dumps({
842846
"model_name": model_name,
@@ -970,7 +974,7 @@ def add_container(self, model_name, model_version):
970974
----------
971975
model_name : str
972976
The name of the model
973-
model_version : int
977+
model_version : Any object with a string representation (with __str__ implementation)
974978
The version of the model
975979
976980
Returns
@@ -979,6 +983,7 @@ def add_container(self, model_name, model_version):
979983
True if the container was added successfully and False
980984
if the container could not be added.
981985
"""
986+
model_version = str(model_version)
982987
with hide("warnings", "output", "running"):
983988
# Look up model info in Redis
984989
if self.redis_ip == DEFAULT_REDIS_IP:
@@ -1024,7 +1029,7 @@ def add_container(self, model_name, model_version):
10241029
mv=model_version,
10251030
mip=model_input_type,
10261031
clipper_label=CLIPPER_DOCKER_LABEL,
1027-
mv_label="%s=%s:%d" % (CLIPPER_MODEL_CONTAINER_LABEL,
1032+
mv_label="%s=%s:%s" % (CLIPPER_MODEL_CONTAINER_LABEL,
10281033
model_name, model_version),
10291034
restart_policy=restart_policy))
10301035
result = self._execute_root(add_container_cmd)
@@ -1101,14 +1106,16 @@ def set_model_version(self, model_name, model_version, num_containers=0):
11011106
----------
11021107
model_name : str
11031108
The name of the model
1104-
model_version : int
1109+
model_version : Any object with a string representation (with __str__ implementation)
11051110
The version of the model. Note that `model_version`
11061111
must be a model version that has already been deployed.
11071112
num_containers : int
11081113
The number of new containers to start with the newly
11091114
selected model version.
11101115
11111116
"""
1117+
model_version = str(model_version)
1118+
11121119
url = "http://%s:%d/admin/set_model_version" % (
11131120
self.host, CLIPPER_MANAGEMENT_PORT)
11141121
req_json = json.dumps({

integration-tests/clipper_manager_tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def test_model_version_sets_correctly(self):
105105
models_list_contains_correct_version = False
106106
for model_info in all_models:
107107
version = model_info["model_version"]
108-
if version == self.model_version_1:
108+
if version == str(self.model_version_1):
109109
models_list_contains_correct_version = True
110110
self.assertTrue(model_info["is_current_version"])
111111

src/benchmarks/src/end_to_end_bench.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ void send_predictions(
9898
cifar_input,
9999
100000,
100100
clipper::DefaultOutputSelectionPolicy::get_name(),
101-
{std::make_pair(SKLEARN_MODEL_NAME, 1)}});
101+
{VersionedModelId(SKLEARN_MODEL_NAME, "1")}});
102102
futures.push_back(std::move(future));
103103
}
104104

src/frontends/src/query_frontend.hpp

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -178,15 +178,16 @@ class RequestHandler {
178178
event_type);
179179
if (event_type == "set") {
180180
std::string model_name = key;
181-
int new_version = clipper::redis::get_current_model_version(
182-
redis_connection_, key);
183-
if (new_version >= 0) {
181+
boost::optional<std::string> new_version =
182+
clipper::redis::get_current_model_version(redis_connection_,
183+
key);
184+
if (new_version) {
184185
std::unique_lock<std::mutex> l(current_model_versions_mutex_);
185-
current_model_versions_[key] = new_version;
186+
current_model_versions_[key] = *new_version;
186187
} else {
187188
clipper::log_error_formatted(
188189
LOGGING_TAG_QUERY_FRONTEND,
189-
"Model version change for model {} was invalid (-1).", key);
190+
"Model version change for model {} was invalid.", key);
190191
}
191192
}
192193
});
@@ -223,18 +224,16 @@ class RequestHandler {
223224
for (std::string model_name : model_names) {
224225
auto model_version = clipper::redis::get_current_model_version(
225226
redis_connection_, model_name);
226-
if (model_version >= 0) {
227+
if (model_version) {
227228
std::unique_lock<std::mutex> l(current_model_versions_mutex_);
228-
current_model_versions_[model_name] = model_version;
229+
current_model_versions_[model_name] = *model_version;
230+
model_names_with_version.push_back(model_name + "@" + *model_version);
229231
} else {
230232
clipper::log_error_formatted(
231233
LOGGING_TAG_QUERY_FRONTEND,
232-
"Found model {} with invalid version number {}.", model_name,
233-
model_version);
234-
throw std::runtime_error("Invalid model version number");
234+
"Found model {} with missing current version.", model_name);
235+
throw std::runtime_error("Invalid model version");
235236
}
236-
model_names_with_version.push_back(model_name + "@v" +
237-
std::to_string(model_version));
238237
}
239238
if (model_names.size() > 0) {
240239
clipper::log_info_formatted(LOGGING_TAG_QUERY_FRONTEND,
@@ -496,7 +495,7 @@ class RequestHandler {
496495
/**
497496
* Returns a copy of the map containing current model names and versions.
498497
*/
499-
std::unordered_map<std::string, int> get_current_model_versions() {
498+
std::unordered_map<std::string, std::string> get_current_model_versions() {
500499
return current_model_versions_;
501500
}
502501

@@ -506,7 +505,7 @@ class RequestHandler {
506505
redox::Redox redis_connection_;
507506
redox::Subscriber redis_subscriber_;
508507
std::mutex current_model_versions_mutex_;
509-
std::unordered_map<std::string, int> current_model_versions_;
508+
std::unordered_map<std::string, std::string> current_model_versions_;
510509
};
511510

512511
} // namespace query_frontend

src/frontends/src/query_frontend_tests.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class MockQueryProcessor {
1919
public:
2020
MockQueryProcessor() = default;
2121
boost::future<Response> predict(Query query) {
22-
Response response(query, 3, 5, Output("-1.0", {std::make_pair("m", 1)}),
22+
Response response(query, 3, 5, Output("-1.0", {VersionedModelId("m", "1")}),
2323
false, boost::optional<std::string>{});
2424
return boost::make_ready_future(response);
2525
}
@@ -280,24 +280,25 @@ TEST_F(QueryFrontendTest, TestReadModelsAtStartup) {
280280
// Add multiple models (some with multiple versions)
281281
std::vector<std::string> labels{"ads", "images", "experimental", "other",
282282
"labels"};
283-
VersionedModelId model1 = std::make_pair("m", 1);
283+
VersionedModelId model1 = VersionedModelId("m", "1");
284284
std::string container_name = "clipper/test_container";
285285
std::string model_path = "/tmp/models/m/1";
286286
ASSERT_TRUE(add_model(*redis_, model1, InputType::Ints, labels,
287287
container_name, model_path));
288-
VersionedModelId model2 = std::make_pair("m", 2);
288+
VersionedModelId model2 = VersionedModelId("m", "2");
289289
std::string model_path2 = "/tmp/models/m/2";
290290
ASSERT_TRUE(add_model(*redis_, model2, InputType::Ints, labels,
291291
container_name, model_path2));
292-
VersionedModelId model3 = std::make_pair("n", 3);
292+
VersionedModelId model3 = VersionedModelId("n", "3");
293293
std::string model_path3 = "/tmp/models/n/3";
294294
ASSERT_TRUE(add_model(*redis_, model3, InputType::Ints, labels,
295295
container_name, model_path3));
296296

297297
// Set m@v2 and n@v3 as current model versions
298-
set_current_model_version(*redis_, "m", 2);
299-
set_current_model_version(*redis_, "n", 3);
300-
std::unordered_map<std::string, int> expected_models = {{"m", 2}, {"n", 3}};
298+
set_current_model_version(*redis_, "m", "2");
299+
set_current_model_version(*redis_, "n", "3");
300+
std::unordered_map<std::string, std::string> expected_models = {{"m", "2"},
301+
{"n", "3"}};
301302

302303
RequestHandler<MockQueryProcessor> rh2_("127.0.0.1", 1337, 8);
303304
EXPECT_EQ(rh2_.get_current_model_versions(), expected_models);
@@ -306,7 +307,7 @@ TEST_F(QueryFrontendTest, TestReadModelsAtStartup) {
306307
TEST_F(QueryFrontendTest, TestReadInvalidModelVersionAtStartup) {
307308
std::vector<std::string> labels{"ads", "images", "experimental", "other",
308309
"labels"};
309-
VersionedModelId model1 = std::make_pair("m", 1);
310+
VersionedModelId model1 = VersionedModelId("m", "1");
310311
std::string container_name = "clipper/test_container";
311312
std::string model_path = "/tmp/models/m/1";
312313
ASSERT_TRUE(add_model(*redis_, model1, InputType::Ints, labels,

src/libclipper/include/clipper/containers.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,7 @@ class ActiveContainers {
8484
// A mapping of models to their replicas. The replicas
8585
// for each model are represented as a map keyed on replica id.
8686
std::unordered_map<VersionedModelId,
87-
std::map<int, std::shared_ptr<ModelContainer>>,
88-
decltype(&versioned_model_hash)>
87+
std::map<int, std::shared_ptr<ModelContainer>>>
8988
containers_;
9089
};
9190
}

src/libclipper/include/clipper/datatypes.hpp

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
#include <string>
77
#include <vector>
88

9+
#include <boost/functional/hash.hpp>
910
#include <boost/optional.hpp>
11+
#include <boost/thread.hpp>
1012

1113
namespace clipper {
1214

1315
using ByteBuffer = std::vector<uint8_t>;
14-
using VersionedModelId = std::pair<std::string, int>;
1516
using QueryId = long;
1617
using FeedbackAck = bool;
1718

@@ -28,11 +29,32 @@ enum class RequestType {
2829
FeedbackRequest = 1,
2930
};
3031

31-
size_t versioned_model_hash(const VersionedModelId &key);
32-
std::string versioned_model_to_str(const VersionedModelId &model);
3332
std::string get_readable_input_type(InputType type);
3433
InputType parse_input_type(std::string type_string);
3534

35+
class VersionedModelId {
36+
public:
37+
VersionedModelId(const std::string name, const std::string id);
38+
39+
std::string get_name() const;
40+
std::string get_id() const;
41+
std::string serialize() const;
42+
static VersionedModelId deserialize(std::string);
43+
44+
VersionedModelId(const VersionedModelId &) = default;
45+
VersionedModelId &operator=(const VersionedModelId &) = default;
46+
47+
VersionedModelId(VersionedModelId &&) = default;
48+
VersionedModelId &operator=(VersionedModelId &&) = default;
49+
50+
bool operator==(const VersionedModelId &rhs) const;
51+
bool operator!=(const VersionedModelId &rhs) const;
52+
53+
private:
54+
std::string name_;
55+
std::string id_;
56+
};
57+
3658
class Output {
3759
public:
3860
Output(const std::string y_hat,
@@ -384,5 +406,16 @@ class PredictionResponse {
384406
} // namespace rpc
385407

386408
} // namespace clipper
387-
409+
namespace std {
410+
template <>
411+
struct hash<clipper::VersionedModelId> {
412+
typedef std::size_t result_type;
413+
std::size_t operator()(const clipper::VersionedModelId &vm) const {
414+
std::size_t seed = 0;
415+
boost::hash_combine(seed, vm.get_name());
416+
boost::hash_combine(seed, vm.get_id());
417+
return seed;
418+
}
419+
};
420+
}
388421
#endif // CLIPPER_LIB_DATATYPES_H

src/libclipper/include/clipper/redis.hpp

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,23 @@ namespace redis {
2121

2222
const std::string LOGGING_TAG_REDIS = "REDIS";
2323

24+
/**
25+
* Elements of this vector should not appear as substrings of any input Clipper
26+
* object value that will be grouped into one entry.
27+
* This list should be updated to reflect all delimiters and characters added in
28+
* `labels_to_str` or `models_to_str`.
29+
*/
30+
const std::vector<std::string> prohibited_group_strings = {
31+
ITEM_DELIMITER, ITEM_PART_CONCATENATOR};
32+
33+
/**
34+
* Use this function to validate inputs that will be grouped before submitting
35+
* them to functions in this library.
36+
* @return Whether or not `value` contains any elements of `probhited_strings`
37+
* as substrings.
38+
*/
39+
bool contains_prohibited_chars_for_group(std::string value);
40+
2441
/**
2542
* Issues a command to Redis and checks return code.
2643
* \return Returns true if the command was successful.
@@ -91,10 +108,11 @@ std::string models_to_str(const std::vector<VersionedModelId>& models);
91108
std::vector<VersionedModelId> str_to_models(const std::string& model_str);
92109

93110
bool set_current_model_version(redox::Redox& redis,
94-
const std::string& model_name, int version);
111+
const std::string& model_name,
112+
const std::string& version);
95113

96-
int get_current_model_version(redox::Redox& redis,
97-
const std::string& model_name);
114+
boost::optional<std::string> get_current_model_version(
115+
redox::Redox& redis, const std::string& model_name);
98116

99117
/**
100118
* Adds a model into the model table. This will
@@ -142,8 +160,8 @@ std::unordered_map<std::string, std::string> get_model(
142160
* \return Returns a list of model versions. If the
143161
* model was not found, an empty list will be returned.
144162
*/
145-
std::vector<int> get_model_versions(redox::Redox& redis,
146-
const std::string& model_name);
163+
std::vector<std::string> get_model_versions(redox::Redox& redis,
164+
const std::string& model_name);
147165

148166
/**
149167
* Looks up model names listed in the model table. Since a call to KEYS may

src/libclipper/include/clipper/rpc_service.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,7 @@ class RPCService {
102102
std::atomic_bool active_;
103103
// The next available message id
104104
int message_id_ = 0;
105-
std::unordered_map<VersionedModelId, int, decltype(&versioned_model_hash)>
106-
replica_ids_;
105+
std::unordered_map<VersionedModelId, int> replica_ids_;
107106
std::shared_ptr<metrics::Histogram> msg_queueing_hist_;
108107

109108
std::function<void(VersionedModelId, int)> container_ready_callback_;

0 commit comments

Comments
 (0)