Using pipelines for a webserver
Creating an inference engine is a complex topic, and the "best" solution will most likely depend on your problem space. Are you on CPU or GPU? Do you want the lowest latency, the highest throughput, support for many models, or just one highly optimized specific model? There are many ways to tackle this topic, so what we are going to present is a good default to get started, which may not necessarily be the most optimal solution for you.
The key thing to understand is that we can use an iterator, just like you would on a dataset, since a webserver is basically a system that waits for requests and handles them as they come in.
Usually webservers are multiplexed (multithreaded, async, etc.) to handle various requests concurrently. Pipelines, on the other hand (and mostly the underlying models), are not really great for parallelism; they take up a lot of RAM, so it's best to give them all the available resources when they are running, as inference is a compute-intensive job.
We are going to solve that by having the webserver handle the light load of receiving and sending requests, and having a single thread handling the actual work. This example is going to use starlette. The actual framework is not really important, but you might have to tune or change the code if you are using another one to achieve the same effect.
Create server.py:
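What follows is a minimal sketch of what such a file could look like, assuming a pipeline served through starlette with a single asyncio.Queue feeding one inference loop; the model (google-bert/bert-base-uncased, a fill-mask model) and the / route are placeholder choices you can swap for your own:

```py
import asyncio

from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route
from transformers import pipeline


async def homepage(request):
    payload = await request.body()
    string = payload.decode("utf-8")
    response_q = asyncio.Queue()
    # Hand the request to the single inference loop, then wait for its answer.
    await request.app.model_queue.put((string, response_q))
    output = await response_q.get()
    return JSONResponse(output)


async def server_loop(q):
    # The model is loaded exactly once, inside this loop only.
    pipe = pipeline(model="google-bert/bert-base-uncased")
    while True:
        (string, response_q) = await q.get()
        out = pipe(string)
        await response_q.put(out)


app = Starlette(
    routes=[
        Route("/", homepage, methods=["POST"]),
    ],
)


@app.on_event("startup")
async def startup_event():
    q = asyncio.Queue()
    app.model_queue = q
    asyncio.create_task(server_loop(q))
```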
Now you can start it with:
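For example, with uvicorn (any ASGI server would do):

```bash
uvicorn server:app
```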
And you can query it:
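For example with curl, assuming uvicorn's default port 8000 and the placeholder fill-mask model above (which expects a [MASK] token in the input):

```bash
curl -X POST -d "test [MASK]" http://localhost:8000/
# The response is the pipeline output serialized as JSON,
# here a list of fill-mask predictions.
```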
And there you go: now you have a good idea of how to create a webserver!
What is really important is that we load the model only once, so there are no copies of the model on the webserver. This way, no unnecessary RAM is being used. Then the queuing mechanism allows you to do fancy stuff like maybe accumulating a few items before inferring to use dynamic batching:
The code sample below is intentionally written like pseudo-code for readability. Do not run this without checking if it makes sense for your system resources!
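A sketch of what that could look like, replacing the body of server_loop from above; it deliberately keeps the flaws discussed right after (no batch size limit, and a 1ms window that restarts on every fetch):

```py
async def server_loop(q):
    pipe = pipeline(model="google-bert/bert-base-uncased")  # placeholder model
    while True:
        # Block until at least one request arrives.
        (string, response_q) = await q.get()
        strings = [string]
        queues = [response_q]
        # Keep accumulating requests until 1ms passes without a new one.
        while True:
            try:
                (string, response_q) = await asyncio.wait_for(q.get(), timeout=0.001)  # 1ms
            except asyncio.TimeoutError:
                break
            strings.append(string)
            queues.append(response_q)
        # Run one batched forward pass and dispatch the results.
        outs = pipe(strings, batch_size=len(strings))
        for response_q, out in zip(queues, outs):
            await response_q.put(out)
```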
Again, the proposed code is optimized for readability, not for being the best code. First of all, there's no batch size limit, which is usually not a great idea. Next, the timeout is reset on every queue fetch, meaning you could wait much more than 1ms before running the inference (delaying the first request by that much). It would be better to have a single 1ms deadline.
This will also always wait for 1ms even if the queue is empty, which might not be ideal, since you probably want to start doing inference as soon as there is nothing left in the queue. But maybe it does make sense if batching is really crucial for your use case. Again, there is really no single best solution.
A few things you might want to consider
Error checking
There's a lot that can go wrong in production: out of memory, out of space, loading the model might fail, the query might be wrong, the query might be correct but still fail to run because of a model misconfiguration, and so on.
Generally, it's good if the server outputs the errors to the user, so adding a lot of try...except statements to show those errors is a good idea. But keep in mind it may also be a security risk to reveal all those errors, depending on your security context.
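As a sketch, assuming the server_loop from above: catching failures per request keeps the loop alive and gives the client something to act on (how much detail you actually expose is your call):

```py
async def server_loop(q):
    pipe = pipeline(model="google-bert/bert-base-uncased")  # placeholder model
    while True:
        (string, response_q) = await q.get()
        try:
            out = pipe(string)
        except Exception as e:
            # Keep serving other requests; report the failure for this one.
            out = {"error": str(e)}
        await response_q.put(out)
```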
Circuit breaking
Webservers usually look better when they do circuit breaking: they return proper errors when they're overloaded instead of just hanging on the query indefinitely. Return a 503 right away rather than making the client wait for a very long time, or get a 504 after a long time.
This is relatively easy to implement in the proposed code since there is a single queue. Looking at the queue size is a basic way to start returning errors before your webserver fails under load.
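A minimal sketch, assuming the homepage handler from above; the threshold is an arbitrary number to tune for your own load:

```py
QUEUE_LIMIT = 32  # arbitrary, tune for your hardware and latency budget


async def homepage(request):
    # Fail fast instead of letting requests pile up behind the model.
    if request.app.model_queue.qsize() >= QUEUE_LIMIT:
        return JSONResponse({"error": "server overloaded"}, status_code=503)
    payload = await request.body()
    string = payload.decode("utf-8")
    response_q = asyncio.Queue()
    await request.app.model_queue.put((string, response_q))
    output = await response_q.get()
    return JSONResponse(output)
```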
Blocking the main thread
Currently PyTorch is not async aware, and computation will block the main thread while running. That means it would be better if PyTorch were forced to run on its own thread/process. This wasn't done here because the code is a lot more complex (mostly because threads, async, and queues don't play nice together). But ultimately it does the same thing.
This would be important if the inference of single items were long (> 1s), because in that case every query issued during inference would have to wait for 1s before even receiving an error.
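If you do want the event loop to stay responsive during long forward passes, one option (a sketch, not what the code above does) is to push the blocking call into a worker thread:

```py
async def server_loop(q):
    pipe = pipeline(model="google-bert/bert-base-uncased")  # placeholder model
    loop = asyncio.get_running_loop()
    while True:
        (string, response_q) = await q.get()
        # Run the blocking forward pass in a thread so the event loop can
        # keep accepting requests (and returning errors) in the meantime.
        out = await loop.run_in_executor(None, pipe, string)
        await response_q.put(out)
```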
Dynamic batching
In general, batching is not necessarily an improvement over passing one item at a time (see batching details for more information). But it can be very effective when used in the correct setting. In the API, there is no dynamic batching by default (too much opportunity for a slowdown). But for BLOOM inference, which is a very large model, dynamic batching is essential to provide a decent experience for everyone.