As I iterate on my models, their feature sets tend to vary. While testing, I like to serve my models in a microservice for easy integration with my other services. Keeping the model server up to date with the ever-changing model has proven quite time consuming due to the changes in features. I found myself spending more and more time tinkering with the model server than with the actual model. Something needed to change.
FastAPI is a web framework for building APIs in Python. It has become my go-to framework for serving my models. It requires a minimal amount of code, which in turn decreases the time required for development and testing and also reduces the risk of errors.
Using FastAPI on uvicorn, I wrote a model-serving service that loads my XGBoost model, creates a class for a simple REST endpoint that validates input data against the feature set defined by the model itself, performs the prediction, and returns the result.
All this with automatic OpenAPI documentation and argument validation (courtesy of FastAPI) in less than 40 lines of Python, while also staying PEP 8 compliant.
```python
import pickle
from pydoc import locate
from typing import List

import numpy as np
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel

model = pickle.load(open("model.dat", "rb"))


def create_type_instance(type_name: str):
    return locate(type_name).__call__()


def get_features_dict(model):
    feature_names = model.get_booster().feature_names
    feature_types = list(map(create_type_instance, model.get_booster().feature_types))
    return dict(zip(feature_names, feature_types))


def create_input_features_class(model):
    return type("InputFeatures", (BaseModel,), get_features_dict(model))


InputFeatures = create_input_features_class(model)
app = FastAPI()


@app.post("/predict", response_model=List)
async def predict_post(datas: List[InputFeatures]):
    return model.predict(np.asarray([list(data.__dict__.values()) for data in datas])).tolist()


if __name__ == "__main__":
    print(get_features_dict(model))
    uvicorn.run(app, host="0.0.0.0", port=8080)
```
To allow FastAPI to generate the OpenAPI documentation and model validation, we must supply a type hint for the input argument of the function. On line 32 I declare that the input argument must be a List of InputFeatures.
InputFeatures is a class which I create dynamically at runtime based on my model's feature set. In this example I use an XGBoost model, but any model can be used, as long as the get_features_dict function is adjusted accordingly.
To create the new class I use the built-in type function, which, depending on the arguments passed, either returns the type of an object or returns a brand new type object.
On line 24 I pass in three arguments. First, the name of the class, in this case "InputFeatures". Second, a tuple containing the base classes for the new class; here I supply the pydantic BaseModel, which will be used by FastAPI for argument validation and OpenAPI docs. Lastly, a dict containing the body definition of the class, created by the get_features_dict function.
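As a minimal sketch of what this three-argument type() call does, here is the same pattern with a hypothetical plain base class standing in for pydantic's BaseModel, so the snippet has no dependencies:

```python
# Hypothetical stand-in for pydantic's BaseModel, for illustration only.
class Base:
    pass


# type(name, bases, namespace) builds a new class at runtime:
# the namespace dict becomes the class body, so each key turns
# into a class attribute with the given default value.
InputFeatures = type("InputFeatures", (Base,), {"feat_0": 0.0, "feat_1": 0.0})

instance = InputFeatures()
print(instance.feat_0)  # 0.0
```

With BaseModel in the bases tuple, pydantic's metaclass additionally turns those attributes into validated fields, which is what FastAPI relies on.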
get_features_dict
This function will vary depending on your model. In this example I'm using an XGBoost model, which I load into memory using pickle on line 10.
XGBoost has functions that allow me to extract both the feature names and a string representation of each feature's type.
I can take this string value, for example "float", and pass it to the locate function from the pydoc module to get the class matching the string. I then invoke __call__() to create a default instance of that type. Finally, I zip the two lists together to create a dict.
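For example, locate resolves the string "float" to the built-in float class, and calling it yields a default instance, which is all the create_type_instance helper above does:

```python
from pydoc import locate

# Resolve the type name string to the actual class object...
float_class = locate("float")
print(float_class)  # <class 'float'>

# ...and call it to get a default instance of that type.
default_value = float_class()
print(default_value)  # 0.0
```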
For a model with 8 numerical features, the resulting dict should look something like this:
```python
{
    'feat_0': 0.0,
    'feat_1': 0.0,
    'feat_2': 0.0,
    'feat_3': 0.0,
    'feat_4': 0.0,
    'feat_5': 0.0,
    'feat_6': 0.0,
    'feat_7': 0.0
}
```
This allows the type function to correctly type the fields in our new class, which in turn allows FastAPI to generate argument validation and documentation!
Before the model can make a prediction, it must first get the data in a format it can handle. On line 33 I use the built-in __dict__ attribute to convert each InputFeatures instance to a dict. I then call values() to get the dict_values, which are turned into a plain list. The resulting list of lists can then be passed to NumPy's asarray function to convert our List of InputFeatures into a numpy ndarray which the model can handle.
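The conversion can be sketched in isolation. Here a hypothetical plain class stands in for an InputFeatures instance, since any object whose __dict__ holds the feature values works the same way:

```python
import numpy as np


# Hypothetical stand-in for a pydantic InputFeatures instance.
class Row:
    def __init__(self, feat_0, feat_1):
        self.feat_0 = feat_0
        self.feat_1 = feat_1


datas = [Row(1.0, 2.0), Row(3.0, 4.0)]

# __dict__ preserves attribute insertion order (Python 3.7+), so the
# values line up with the order in which the features were defined.
batch = np.asarray([list(data.__dict__.values()) for data in datas])
print(batch.shape)  # (2, 2)
```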
To expose the model I define a REST endpoint on lines 31-32 with the function predict_post. Using FastAPI decorators, I can define it as a POST endpoint with a path and a response type. The async keyword lets the event loop keep processing other requests while this one is being handled, allowing multiple requests to the service to be served concurrently.
Lastly, on line 36, I define the service's entry point, which simply runs the FastAPI app on uvicorn on port 8080.
Run the application and navigate to localhost:8080/docs to view the OpenAPI UI and test your brand new model server!
You can view the schema for InputFeatures and even try out the predict endpoint. If you supply broken or invalid data, you will get a pretty 422 response with a detailed description of what went wrong.
Now imagine you do some more EDA and make some new discoveries, so you change the input feature set. Simply swap the model.dat file and restart the model server!
And just like that your model server is up to date with the new version of your model!
Summary
FastAPI enables you to easily and rapidly expose any machine learning model for prediction in a scalable manner, in less than 40 lines of code. This reduces development time, cost, and the risk of errors, allowing you to focus on model development rather than model serving.
You are still stuck with the issue of changing input features for the caller of the predict endpoint, but I’ll leave that problem for the next article!