Model Serving
Repository · Notebook
Intuition
In this lesson, we're going to serve the machine learning models that we have developed so that we can use them to make predictions on unseen data. We want to serve our models in a scalable and robust manner so that our service can deliver high throughput (handle many requests) with low latency (quickly respond to each request). In an effort to be comprehensive, we will implement both batch inference (offline) and online inference (real-time), though we will focus on the latter in the remaining lessons as it's more appropriate for our application.
Frameworks
There are many frameworks to choose from when it comes to model serving, such as Ray Serve, NVIDIA Triton, HuggingFace, BentoML, etc. When choosing between them, we want an option that checks the following boxes:
- Pythonic: we don't want to learn a new framework to be able to serve our models.
- Framework agnostic: we want to be able to serve models from any framework (PyTorch, TensorFlow, etc.).
- Scale: (auto)scaling our service should be as easy as changing a configuration.
- Composition: we can combine multiple models and business logic into our service.
- Integrations: it integrates with popular API frameworks like FastAPI.
To address all of these requirements (and more), we will be using Ray Serve to create our service. While we'll specifically be using its integration with FastAPI, there are many other integrations you might want to explore based on your stack (LangChain, Kubernetes, etc.).
Batch inference
We will first implement batch inference (or offline inference), which is when we make predictions on a large batch of data. This is useful when we don't need to serve a model's predictions as soon as the input data arrives. For example, our service could make predictions once at the end of every day on the batch of content collected throughout the day. This can be more efficient than making predictions on each piece of content individually if we don't need that kind of low latency.
Let's take a look at how we can easily implement batch inference with Ray. We'll start with some setup and load the best checkpoint from our training run.
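The original setup code isn't reproduced in this extract; a minimal sketch might look like the following, assuming an MLflow tracking server from the experiment-tracking lesson and a get_best_checkpoint helper from the training lesson (the experiment and metric names are illustrative):

```python
import mlflow
import ray

# Initialize Ray (skip if we're already connected to a cluster)
if not ray.is_initialized():
    ray.init()

# Pick the best run from our experiment (ordered by validation loss)
sorted_runs = mlflow.search_runs(experiment_names=["llm"], order_by=["metrics.val_loss ASC"])
run_id = sorted_runs.iloc[0].run_id

# Load the checkpoint produced by that run (assumed helper from the training lesson)
best_checkpoint = get_best_checkpoint(run_id=run_id)
```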
Next, we'll define a Predictor class that will load the model from our checkpoint and then define the __call__ method that will be used to make predictions on our input data.
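The original class definition isn't shown here; a sketch of the pattern, assuming the TorchPredictor wrapper and decode helper from the earlier training/prediction lessons, might look like this:

```python
import numpy as np

class Predictor:
    def __init__(self, checkpoint):
        # Restore the fine-tuned model (and its preprocessor) from the checkpoint
        self.predictor = TorchPredictor.from_checkpoint(checkpoint)  # assumed wrapper from earlier lessons

    def __call__(self, batch):
        # Forward pass -> logits -> predicted class index -> tag for each row in the batch
        z = self.predictor.predict(batch)["predictions"]
        y_pred = np.stack(z).argmax(axis=1)
        prediction = decode(y_pred, self.predictor.preprocessor.index_to_class)  # assumed helper
        return {"prediction": prediction}
```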
The __call__ function in Python defines the logic that will be executed when our object is called like a function:

```python
predictor = Predictor()
prediction = predictor(batch)
```
To do batch inference, we'll be using the map_batches functionality. We previously used map_batches to map (or apply) a preprocessing function across batches (chunks) of our data. We're now using the same concept to apply our predictor across batches of our inference data.
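The original call isn't shown; as a sketch, applying our Predictor over a Ray Dataset (test_ds is assumed from the data lesson; the compute argument applies to older Ray versions, newer ones use concurrency) might look like:

```python
from ray.data import ActorPoolStrategy

# Apply the Predictor class across batches of our inference dataset
predictions = test_ds.map_batches(
    Predictor,
    batch_size=128,
    compute=ActorPoolStrategy(min_size=1, max_size=2),  # pool of actors to scale out
    batch_format="pandas",
    fn_constructor_kwargs={"checkpoint": best_checkpoint},
)
```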
Note that we pass best_checkpoint as a keyword argument to our Predictor class so that we can load the model from that checkpoint. We pass it in via the fn_constructor_kwargs argument in our map_batches function.
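To inspect a few of the results (a sketch; predictions is the Ray Dataset produced above):

```python
# Peek at a handful of predictions
predictions.take(3)
```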
[{'prediction': 'computer-vision'}, {'prediction': 'other'}, {'prediction': 'other'}]
Online inference
While we can achieve batch inference at scale, many models will need to be served in a real-time manner where we may need to deliver predictions for many incoming requests (high throughput) with low latency. We want to use online inference for our application over batch inference because we want to quickly categorize content as it is received/submitted to our platform so that the community can discover it quickly.
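The imports aren't shown in this extract; a reasonable set for what follows might be:

```python
import json
import requests
from fastapi import FastAPI
from ray import serve
from starlette.requests import Request
```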
We'll start by defining our FastAPI application, which involves initializing a predictor (and preprocessor) from the best checkpoint for a particular run (specified by run_id). We'll also define a predict function that will be used to make predictions on our input data.
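A sketch of the application definition (the metadata strings are illustrative):

```python
# Define the FastAPI application
app = FastAPI(
    title="Made With ML",
    description="Classify machine learning projects.",
    version="0.1",
)
```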
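The full class isn't reproduced here; a sketch, assuming the get_best_checkpoint, TorchPredictor, and predict_with_proba helpers from earlier lessons (hypothetical names), might look like this:

```python
import pandas as pd

class ModelDeployment:
    def __init__(self, run_id):
        """Initialize the model from the best checkpoint of a run."""
        self.run_id = run_id
        best_checkpoint = get_best_checkpoint(run_id=run_id)  # assumed helper
        self.predictor = TorchPredictor.from_checkpoint(best_checkpoint)  # assumed wrapper

    @app.post("/predict/")
    async def _predict(self, request: Request):
        # Parse the incoming request into a single-row DataFrame
        data = await request.json()
        df = pd.DataFrame([{"title": data.get("title", ""), "description": data.get("description", ""), "tag": ""}])
        # Make a prediction (with class probabilities) and return the results
        results = predict_with_proba(df=df, predictor=self.predictor)  # assumed helper
        return {"results": results}
```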
async def refers to an asynchronous function (when we call the function, we don't have to wait for it to finish executing). The await keyword is used inside an asynchronous function to wait for the completion of the request.json() operation.
We can now combine our FastAPI application with Ray Serve by simply wrapping our application with the serve.ingress decorator. We can further wrap all of this with the serve.deployment decorator to define our deployment configuration (e.g. number of replicas, compute resources, etc.). These configurations allow us to easily scale our service as needed.
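A sketch of the decorated deployment (the replica count and compute resources are illustrative):

```python
@serve.deployment(num_replicas=1, ray_actor_options={"num_cpus": 8, "num_gpus": 0})
@serve.ingress(app)
class ModelDeployment:
    ...  # same class body as above
```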
Now let's run our service and perform some real-time inference.
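A sketch of starting the service:

```python
# Bind the deployment with its constructor arguments and run it
serve.run(ModelDeployment.bind(run_id=run_id))
```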
Started detached Serve instance in namespace "serve".
Deployment 'default_ModelDeployment:IcuFap' is ready at `http://127.0.0.1:8000/`. component=serve deployment=default_ModelDeployment
RayServeSyncHandle(deployment='default_ModelDeployment')
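A sketch of querying the service (the sample title and description are illustrative):

```python
# Query the service with a sample project
title = "Transfer learning with transformers"
description = "Using transformers for transfer learning on text classification tasks."
json_data = json.dumps({"title": title, "description": description})
requests.post("http://127.0.0.1:8000/predict/", data=json_data).json()
```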
{'results': [{'prediction': 'natural-language-processing', 'probabilities': {'computer-vision': 0.00038025027606636286, 'mlops': 0.0003820903366431594, 'natural-language-processing': 0.9987919926643372, 'other': 0.00044562897528521717}}]}
The issue with neural networks (and especially LLMs) is that they are notoriously overconfident: they will always make some prediction for every input. To account for this, we have an other class, but that class only contains projects that fall outside our accepted tags yet are still machine learning related. Here's what happens when we pass complete noise as our input:
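The exact noise string used to produce the output below isn't shown in this extract; as a sketch, any gibberish input works for the demonstration:

```python
# Query the service with a nonsense title
title = "sdfjwkjfsdkjfh"  # placeholder gibberish
json_data = json.dumps({"title": title, "description": ""})
requests.post("http://127.0.0.1:8000/predict/", data=json_data).json()
```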
{'results': [{'prediction': 'natural-language-processing', 'probabilities': {'computer-vision': 0.11885979026556015, 'mlops': 0.09778415411710739, 'natural-language-processing': 0.6735526323318481, 'other': 0.1098034456372261}}]}
Let's shut down our service before we fix this issue.
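```python
# Shutdown the service
serve.shutdown()
```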
Custom logic
To make our service a bit more robust, let's add some custom logic to predict the other class if the probability of the predicted class is below a certain threshold probability.
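The original implementation isn't reproduced here; a sketch of the thresholded deployment, reusing the assumed helpers and FastAPI app from above, might look like:

```python
@serve.deployment(num_replicas=1, ray_actor_options={"num_cpus": 8, "num_gpus": 0})
@serve.ingress(app)
class ModelDeploymentRobust:
    def __init__(self, run_id, threshold=0.9):
        """Initialize the model and the probability threshold."""
        self.run_id = run_id
        self.threshold = threshold
        best_checkpoint = get_best_checkpoint(run_id=run_id)  # assumed helper
        self.predictor = TorchPredictor.from_checkpoint(best_checkpoint)  # assumed wrapper

    @app.post("/predict/")
    async def _predict(self, request: Request):
        # Parse the incoming request into a single-row DataFrame
        data = await request.json()
        df = pd.DataFrame([{"title": data.get("title", ""), "description": data.get("description", ""), "tag": ""}])
        results = predict_with_proba(df=df, predictor=self.predictor)  # assumed helper

        # Custom logic: fall back to the "other" class for low-confidence predictions
        for i, result in enumerate(results):
            prediction = result["prediction"]
            probabilities = result["probabilities"]
            if probabilities[prediction] < self.threshold:
                results[i]["prediction"] = "other"

        return {"results": results}
```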
Tip
It's easier to incorporate custom logic than to alter the model itself. This way, we won't have to collect new data, change the model's architecture, or retrain it. This also makes it really easy to change the custom logic as our product specifications change (clean separation of product and machine learning).
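A sketch of running the more robust service (the threshold value is illustrative):

```python
# Run the service with a probability threshold for falling back to "other"
serve.run(ModelDeploymentRobust.bind(run_id=run_id, threshold=0.9))
```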
Started detached Serve instance in namespace "serve".
Deployment 'default_ModelDeploymentRobust:RTbrNg' is ready at `http://127.0.0.1:8000/`. component=serve deployment=default_ModelDeploymentRobust
RayServeSyncHandle(deployment='default_ModelDeploymentRobust')
Now let's see how we perform on the same random noise with our custom logic incorporated into the service.
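Again as a sketch, reusing the noisy input from above:

```python
# Query the robust service with the same noisy input
json_data = json.dumps({"title": title, "description": ""})
requests.post("http://127.0.0.1:8000/predict/", data=json_data).json()
```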
{'results': [{'prediction': 'other', 'probabilities': {'computer-vision': 0.11885979026556015, 'mlops': 0.09778415411710739, 'natural-language-processing': 0.6735526323318481, 'other': 0.1098034456372261}}]}
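```python
# Shutdown the service
serve.shutdown()
```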
We'll learn how to deploy our service to production in our Jobs and Services lesson a bit later.