|
| 1 | +from __future__ import print_function, with_statement, absolute_import |
| 2 | +import shutil |
| 3 | +import tensorflow as tf |
| 4 | +import logging |
| 5 | +import re |
| 6 | +import os |
| 7 | +import json |
| 8 | + |
| 9 | +from ..version import __version__ |
| 10 | +from ..clipper_admin import ClipperException |
| 11 | +from .deployer_utils import save_python_function |
| 12 | + |
| 13 | +logger = logging.getLogger(__name__) |
| 14 | + |
| 15 | + |
def create_endpoint(clipper_conn,
                    name,
                    input_type,
                    func,
                    tf_sess,
                    default_output="None",
                    version=1,
                    slo_micros=3000000,
                    labels=None,
                    registry=None,
                    base_image="clipper/tf-container:{}".format(__version__),
                    num_replicas=1):
    """Register an application, deploy ``func`` with its TensorFlow session
    as a Clipper model, and link the model to the application.

    This is a convenience wrapper that performs the three standard steps
    (register app, deploy model, link) under a single shared ``name``.

    Parameters
    ----------
    clipper_conn : :py:meth:`clipper_admin.ClipperConnection`
        A ``ClipperConnection`` object connected to a running Clipper cluster.
    name : str
        The name assigned to both the registered application and the
        deployed model.
    input_type : str
        The input type for the app and model. One of "integers", "floats",
        "doubles", "bytes", or "strings".
    func : function
        The prediction function. Any state it closes over is captured and
        pickled with Cloudpickle.
    tf_sess : The Tensorflow Session to save.
    default_output : str, optional
        Returned to callers whenever the application cannot get a response
        from the model within the latency SLO. The prediction response
        object always reports when the default was used. Defaults to "None".
    version : str, optional
        The version assigned to this model. Versions must be unique per
        model but may be reused across different models.
    slo_micros : int, optional
        Query latency objective in microseconds, measured from Clipper
        receiving a request to sending the response (network time on either
        side is not counted). Queries that miss the objective get the
        default output, so avoid setting this aggressively low; 100000
        (100ms) is a reasonable starting point.
    labels : list(str), optional
        Free-form annotation strings stored with the model; Clipper ignores
        them.
    registry : str, optional
        Docker registry to push the built model image to. When running
        Clipper on Kubernetes, the cluster must be able to pull from this
        registry.
    base_image : str, optional
        Base Docker image for the model image; it must contain everything
        needed to run a Clipper model container RPC client.
    num_replicas : int, optional
        Number of model replicas to create. Can be changed later via
        :py:meth:`clipper.ClipperConnection.set_num_replicas`.
    """

    # Register the application first so the model can be linked to it below.
    clipper_conn.register_application(name, input_type, default_output,
                                      slo_micros)

    # Build and deploy the container wrapping ``func`` and ``tf_sess``.
    deploy_tensorflow_model(clipper_conn, name, version, input_type, func,
                            tf_sess, base_image, labels, registry,
                            num_replicas)

    # Route the application's queries to the freshly deployed model.
    clipper_conn.link_model_to_app(name, name)
| 86 | + |
| 87 | + |
def deploy_tensorflow_model(
        clipper_conn,
        name,
        version,
        input_type,
        func,
        tf_sess,
        base_image="clipper/tf-container:{}".format(__version__),
        labels=None,
        registry=None,
        num_replicas=1):
    """Deploy a Python prediction function with a Tensorflow model.

    Serializes ``func`` and a checkpoint of ``tf_sess`` into a temporary
    directory, builds a Docker model image from it, deploys the image to the
    Clipper cluster, and removes the temporary directory afterwards.

    Parameters
    ----------
    clipper_conn : :py:meth:`clipper_admin.ClipperConnection`
        A ``ClipperConnection`` object connected to a running Clipper cluster.
    name : str
        The name to be assigned to both the registered application and deployed model.
    version : str
        The version to assign this model. Versions must be unique on a per-model
        basis, but may be re-used across different models.
    input_type : str
        The input_type to be associated with the registered app and deployed model.
        One of "integers", "floats", "doubles", "bytes", or "strings".
    func : function
        The prediction function. Any state associated with the function will be
        captured via closure capture and pickled with Cloudpickle.
    tf_sess : tensorflow.python.client.session.Session
        The tensor flow session to save.
    base_image : str, optional
        The base Docker image to build the new model image from. This
        image should contain all code necessary to run a Clipper model
        container RPC client.
    labels : list(str), optional
        A list of strings annotating the model. These are ignored by Clipper
        and used purely for user annotations.
    registry : str, optional
        The Docker container registry to push the freshly built model to. Note
        that if you are running Clipper on Kubernetes, this registry must be accesible
        to the Kubernetes cluster in order to fetch the container from the registry.
    num_replicas : int, optional
        The number of replicas of the model to create. The number of replicas
        for a model can be changed at any time with
        :py:meth:`clipper.ClipperConnection.set_num_replicas`.

    Example
    -------
    from clipper_admin import ClipperConnection, DockerContainerManager
    from clipper_admin.deployers.tensorflow import deploy_tensorflow_model

    clipper_conn = ClipperConnection(DockerContainerManager())

    # Connect to an already-running Clipper cluster
    clipper_conn.connect()

    def predict(sess, inputs):
        preds = sess.run('predict_class:0', feed_dict={'pixels:0': inputs})
        return [str(p) for p in preds]

    deploy_tensorflow_model(
        clipper_conn,
        model_name,
        version,
        input_type,
        predict_fn,
        sess)

    """
    # Serialize the predict function (and its closure state) into a fresh
    # temporary directory.
    serialization_dir = save_python_function(name, func)
    try:
        # Save the TensorFlow session checkpoint next to the function so
        # the model container can restore it at startup. Build the path
        # with os.path.join components rather than a hard-coded "/".
        tf_sess_save_loc = os.path.join(serialization_dir, "tfmodel",
                                        "model.ckpt")
        try:
            saver = tf.train.Saver()
            save_path = saver.save(tf_sess, tf_sess_save_loc)
        except Exception:
            # logger.exception logs the full traceback (logger.warn is
            # deprecated); a bare ``raise`` preserves the original
            # traceback for the caller instead of rebinding it.
            logger.exception("Error saving Tensorflow model")
            raise

        logger.info("TensorFlow model saved at: %s ", save_path)

        # Build the Docker image and deploy it to the Clipper cluster.
        clipper_conn.build_and_deploy_model(name, version, input_type,
                                            serialization_dir, base_image,
                                            labels, registry, num_replicas)
    finally:
        # Always remove the temporary serialization directory, even when
        # saving the session or deploying the image fails — the original
        # code leaked it on any error.
        shutil.rmtree(serialization_dir)
0 commit comments