Support for PyTorch model#346
Conversation
|
Test FAILed. |
| @@ -1,15 +1,20 @@ | |||
| """ | |||
There was a problem hiding this comment.
Pulled the latest cloudpickle. This fixes issues with recursive dependencies that caused issues when serializing pytorch models.
There was a problem hiding this comment.
If we're replacing Cloudpickle, let's just switch to using the cloudpickle pip package (pip install cloudpickle).
| PYTORCH_MODEL_RELATIVE_PATH) | ||
|
|
||
| try: | ||
| torch.save(pytorch_model.state_dict(), torch_weights_save_loc) |
There was a problem hiding this comment.
We now save the model weights and class definition separately. Weights are saved via PyTorch, and the class is pickled with CloudPickle
|
Test PASSed. |
dcrankshaw
left a comment
There was a problem hiding this comment.
Two small changes then this is good to go. @Corey-Zumar can you go ahead and make those?
| @@ -1,15 +1,20 @@ | |||
| """ | |||
There was a problem hiding this comment.
If we're replacing Cloudpickle, let's just switch to using the cloudpickle pip package (pip install cloudpickle).
| && apt-get update --fix-missing \ | ||
| && apt-get install -yqq -t jessie-backports openjdk-8-jdk \ | ||
| && conda install -y --file /lib/python_container_conda_deps.txt \ | ||
| && conda install pytorch torchvision -c soumith |
There was a problem hiding this comment.
They created a PyTorch conda channel recently, so this should be conda install pytorch torchvision -c pytorch
|
Test FAILed. |
|
Test FAILed. |
|
Test FAILed. |
|
Test FAILed. |
|
Test FAILed. |
|
Test FAILed. |
|
jenkins test this please |
|
Test FAILed. |
|
Test PASSed. |
@haofanwang Added some updates to fix the class serialization issue and accidentally closed #322 in the process. Take a look and let me know if you have any questions.
This should be good to go once @dcrankshaw takes a final pass.
Fixes #343, #314