AWS recently released TorchServe, an open-source model serving library for PyTorch. Production-readiness has long been one of TensorFlow's competitive advantages, and TorchServe is the PyTorch community's response: it is meant to be the PyTorch counterpart of TensorFlow Serving. So far, it's off to a very strong start.
This post from the AWS Machine Learning Blog and the documentation of TorchServe should be more than enough to get you started. But for advanced usage, the documentation is a bit chaotic, and the example code sometimes suggests conflicting ways to do things.
This post is not meant to be a tutorial for beginners. Instead, it uses a case study to show what a slightly more complicated deployment looks like, and saves you time by referencing the relevant documentation and example code.
In this post, we will deploy an EfficientNet model from the rwightman/gen-efficientnet-pytorch repo. The server accepts images as arrays in NumPy binary format and returns the corresponding class probabilities. (We use the NumPy binary format because in this use case the images are already read into memory on the client side, network bandwidth is cheap, and we don't have strict latency requirements, so re-encoding them into JPEG or PNG doesn't make sense.)
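As an aside, this encoding is just NumPy's `.npy` serialization round-tripped through an in-memory buffer. A minimal sketch (the helper names `encode_image` and `decode_image` are illustrative, not part of any library):

```python
import io

import numpy as np

def encode_image(arr):
    # Serialize an in-memory array to NumPy's .npy binary format.
    buf = io.BytesIO()
    np.save(buf, arr)
    return buf.getvalue()

def decode_image(payload):
    # Reverse the encoding on the server side.
    return np.load(io.BytesIO(payload))

img = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
restored = decode_image(encode_image(img))
assert np.array_equal(img, restored)  # lossless round-trip
```

The round-trip is lossless and preserves the dtype, which is exactly why it beats re-encoding to JPEG here.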
Preparing the EfficientNet Model
TorchServe can load models from PyTorch checkpoints (`.state_dict()`) or exported TorchScript programs. I'd recommend using TorchScript when possible, as it doesn't require you to install extra libraries (e.g., gen-efficientnet-pytorch) or provide a model definition file.
rwightman/gen-efficientnet-pytorch already provides an easy API to create TorchScript-compatible models. In this case, the model has already been trained and saved via `torch.save(model)`. We need to load it using `torch.load`, create an untrained TorchScript-compatible model, and transfer the weights:

```python
geffnet.config.set_scriptable(True)

model_old = torch.load("cache/b4-checkpoint.pth")["model"].cpu()
model, _ = get_model(arch="b4", n_classes=6)
model.load_state_dict(model_old.state_dict())
del model_old

with torch.jit.optimized_execution(True):
    model = torch.jit.script(model)
model.save("cache/b4.pt")
```
(The above code was inspired by this script in the gen-efficientnet-pytorch repo.)
Please note that the `geffnet.config.set_scriptable(True)` line is essential. Without it, the model cannot be compiled with TorchScript.
The Custom Handler
TorchServe comes with four default handlers that define the input and output of the deployed service. We are deploying an image classification model in this example, and the corresponding default handler is `image_classifier`. If you read its source code, you'll find that it accepts a binary image input, resizes, center-crops, and normalizes it, and returns the top 5 predicted classes. Most of this doesn't fit our use case, so we'll have to write our own handler. You can refer to this documentation on how to create non-standard services.
In this example, we’ll have thousands of images per minute to predict, so batch processing is essential. For more information on batch inference with TorchServe, please refer to this documentation.
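To make the batch flow concrete, here is a stripped-down sketch of the preprocess/inference/postprocess pipeline a batch handler performs. All names below (`preprocess`, `postprocess`, `handle_batch`, `model_fn`) are this sketch's own, not TorchServe API; `model_fn` stands in for the loaded TorchScript model. TorchServe delivers a batch as a list of request dicts, each carrying its payload under a "data" or "body" key:

```python
import io

import numpy as np

def preprocess(batch):
    # Each request in the batch carries the raw .npy bytes of one image.
    arrays = []
    for request in batch:
        payload = request.get("data") or request.get("body")
        arrays.append(np.load(io.BytesIO(payload)))
    # Stack into a single (N, ...) batch for one forward pass.
    return np.stack(arrays)

def postprocess(outputs):
    # Return one list of class probabilities per request.
    return [row.tolist() for row in outputs]

def handle_batch(batch, model_fn):
    return postprocess(model_fn(preprocess(batch)))
```

The key point is that the handler must return exactly one response per request in the batch, in the same order, so TorchServe can route each answer back to the right client.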
The code is based on the resnet_152_batch example, with some simplifications (e.g., we don't need to handle PyTorch checkpoints). By the way, the MNIST example uses a confusing way to load the model and model file; the one in resnet_152_batch makes much more sense (by using the
- Load the model from the TorchScript program (Line 30).
- Load the image from an array in NumPy binary format: `input_image = Image.fromarray(np.load(io.BytesIO(image)))` (Line 59).
- Test-Time Augmentation (TTA) via horizontal flip, in Lines 48 to 53, 60 to 61, and 84 to 87.
This handler cannot handle malformed inputs (like all the example handlers I've seen). If malformed inputs are inevitable in your use case, you'll probably need to find some way to identify them, ignore them in the `inference` method, and return proper error messages in the `postprocess` method.
Create the Model Archive
TorchServe requires you to package all model artifacts into a single model archive file. It's fairly straightforward in our case. Please refer to the documentation if in doubt.
```shell
mkdir model-store
torch-model-archiver --model-name b4 --version 1.0 \
    --serialized-file cache/b4.pt --handler handler.py \
    --export-path model-store
```
Start the TorchServe Service
I use a shell script to start the TorchServe server and register the model:
```shell
torchserve --start --model-store model-store \
    --ts-config config.properties > /dev/null
sleep 3
curl -X DELETE http://localhost:8081/models/b4
curl -X POST "localhost:8081/models?model_name=b4&url=b4.mar&batch_size=4&max_batch_delay=1000&initial_workers=1&synchronous=true"
```
`batch_size=4`: because of the TTA, the effective batch size is actually 8. I've noticed that the maximum feasible batch size is smaller in TorchServe than when running inference directly in a Python script; I'm not sure why this is the case.
`max_batch_delay=1000`: wait at most one second for the batch to fill. You can adjust this according to your latency requirements.
The config.properties file contains:
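As a rough illustration only (every key and value below is an assumption, not the author's actual file), a minimal config.properties for a setup like this could look like:

```properties
inference_address=http://0.0.0.0:8080
management_address=http://0.0.0.0:8081
number_of_netty_threads=4
job_queue_size=100
```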
log4j.properties is an exact copy of the one used by default.
At this point, your TorchServe service should be up and running.
If you update your model, create the new model archive at the same path and rerun the shell script; the server will automatically reload the model.
Stop the server by running `torchserve --stop`.
Client Requests Example
Here’s an example of making multiple requests to the server via asyncio:
```python
import asyncio
import concurrent.futures
import io

import numpy as np
import requests

# INFERENCE_ENDPOINT (the model's prediction URL) is defined elsewhere.

async def predict_batch(cache):
    buffer = []
    for img in cache["images"]:
        # Serialize each image to .npy bytes in an in-memory buffer.
        output = io.BytesIO()
        np.save(output, img)
        output.seek(0, 0)
        buffer.append(output)
    # requests is blocking, so run the POSTs in a thread pool
    # and gather the results concurrently.
    loop = asyncio.get_event_loop()
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=8)
    responses = await asyncio.gather(*[
        loop.run_in_executor(
            executor, requests.post, INFERENCE_ENDPOINT, pickled_image
        )
        for pickled_image in buffer
    ])
    probs = [res.json() for res in responses]
    return probs
```
Thanks for reading! I hope this post makes it easier for you to understand and use TorchServe. TorchServe creates an API for your model and does most of the heavy lifting involved in handling HTTP requests. It shows great promise for PyTorch's production environment support.
If you have any suggestions, please feel free to leave a comment.