Integrating ML Models
This document explains how to integrate a machine learning model into a FastAPI application, using a small end-to-end example. The goal is an API that receives input data, runs it through a pre-trained model, and returns predictions.
👉 New to App-Generator? Sign IN with GitHub or Generate Web Apps in no time (free service).
Scenario: We want to predict whether a person has diabetes based on a few medical features. To keep things simple, we first train a small model with scikit-learn and then serve the saved, pre-trained model from FastAPI.
Steps:
Train a simple machine learning model.
Save the trained model to a file.
Create a FastAPI app to serve predictions from the model.
Prerequisites
Python 3.7+
FastAPI
Uvicorn (for running FastAPI)
Scikit-learn
Joblib (for saving and loading the model)
Matplotlib (for plotting the target distribution)
Install dependencies:
pip install "fastapi[all]" scikit-learn joblib uvicorn matplotlib
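To confirm the installation before moving on, a small check script can help. This is a sketch that assumes the distribution names from the pip command above and uses the standard library's `importlib.metadata` (available from Python 3.8); the file name `check_env.py` and the helper `installed_versions` are illustrative, not part of the tutorial.

```python
# check_env.py: report which of the tutorial's packages are installed (illustrative helper)
from importlib.metadata import PackageNotFoundError, version

# Distribution names as used in the pip install command
PACKAGES = ["fastapi", "uvicorn", "scikit-learn", "joblib", "matplotlib"]


def installed_versions(packages):
    """Map each distribution name to its installed version string, or None if missing."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions(PACKAGES).items():
        print(f"{name:15} {ver or 'NOT INSTALLED'}")
```

Any package reported as `NOT INSTALLED` should be installed before running the scripts below.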
Step 1: Train a Simple Model
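We will start by training a simple classifier on scikit-learn's built-in diabetes dataset. That dataset was designed for regression: all 442 patients in it have diabetes, its target is a quantitative measure of disease progression one year after baseline, and its ten features are mean-centered and scaled rather than raw measurements. To obtain a yes/no label, the script treats a progression score above 100 as "diabetic", so the label is illustrative rather than a clinical diagnosis.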
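Before training, it can help to look at what `load_diabetes()` returns and how its columns line up with the request fields used later in `main.py`. The sketch below assumes that `bp` and `s1` through `s6` correspond to `map`, `tc`, `ldl`, `hdl`, `tch`, `ltg` and `glu`, following the descriptions in scikit-learn's dataset documentation; the mapping dict `API_FIELD_FOR` and the small 10-tree forest are illustrative only.

```python
# inspect_data.py: check the dataset layout and the label encoding (illustrative)
import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.ensemble import RandomForestClassifier

data = load_diabetes()
X, y = data.data, data.target > 100  # same binarization as train_model.py

# Assumed mapping from scikit-learn's short feature names to the API field names
API_FIELD_FOR = {
    "age": "age", "sex": "sex", "bmi": "bmi", "bp": "map",
    "s1": "tc", "s2": "ldl", "s3": "hdl", "s4": "tch", "s5": "ltg", "s6": "glu",
}

print(X.shape)  # (442, 10)
print(data.feature_names)
print([API_FIELD_FOR[name] for name in data.feature_names])

# Each column is mean-centered and scaled so its sum of squares is (close to) 1
print(np.abs(X.mean(axis=0)).max(), (X ** 2).sum(axis=0))

# predict_proba columns follow model.classes_, which scikit-learn sorts
model = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)
print(model.classes_)
```

The last line shows why the API reads column 1 of `predict_proba`: the classes are sorted, so `True` ("diabetic" in this tutorial) comes second.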
```python
# train_model.py
import joblib
import matplotlib.pyplot as plt
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier

# Load the diabetes dataset (442 patients, 10 mean-centered and scaled features)
diabetes_data = load_diabetes()

# Visualize the distribution of the target variable (disease progression score)
plt.hist(diabetes_data.target, bins=50)
plt.xlabel('Target Value')
plt.ylabel('Frequency')
plt.title('Distribution of Target Variable')
plt.show()  # blocks until the plot window is closed

# Convert the regression target into a binary label
# (True = progression score above 100, treated as "diabetic" in this example)
y = diabetes_data.target > 100

# Split into training and testing sets
X = diabetes_data.data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train a RandomForestClassifier (fixed seed for reproducible results)
model = RandomForestClassifier(random_state=42)
model.fit(X_train, y_train)

# Report accuracy on the held-out test set
print(f"Test accuracy: {model.score(X_test, y_test):.2f}")

# Save the trained model to a file
joblib.dump(model, 'diabetes_model.joblib')
print("Model trained and saved.")
```
Run this script to train and save the model:
python train_model.py
Step 2: Create FastAPI Application
Now, we will create a FastAPI application that loads the saved model and exposes an endpoint to make predictions.
```python
# main.py
import joblib
import numpy as np
from fastapi import FastAPI
from pydantic import BaseModel

# Load the saved model once, at startup
model = joblib.load('diabetes_model.joblib')

# Create FastAPI app
app = FastAPI()


# Define the input data structure (same feature order as in training)
class DiabetesFeatures(BaseModel):
    age: float
    sex: float
    bmi: float
    map: float
    tc: float
    ldl: float
    hdl: float
    tch: float
    ltg: float
    glu: float


# Define a prediction endpoint
@app.post("/predict/")
def predict(features: DiabetesFeatures):
    # Convert the input features to a 2-D numpy array with one row (one sample)
    input_data = np.array([[
        features.age, features.sex, features.bmi, features.map, features.tc,
        features.ldl, features.hdl, features.tch, features.ltg, features.glu,
    ]])

    # Get class probabilities; column 1 is the probability of class True ("diabetic")
    prediction_prob = model.predict_proba(input_data)

    # Predict 'Diabetic' only if that probability exceeds the threshold
    threshold = 0.6
    is_diabetic = prediction_prob[0][1] > threshold

    # Return the result based on the probability threshold
    return {"prediction": "Diabetic" if is_diabetic else "Not Diabetic"}
```
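The endpoint depends on two details that are easy to break: the order of the values in `input_data` must match the column order used during training, and the 0.6 threshold is compared with a strict `>`. The plain-Python sketch below isolates both; `FEATURE_ORDER`, `to_row` and `label_from_proba` are hypothetical names, not part of `main.py`.

```python
# Hypothetical helpers mirroring the logic inside predict(); not part of main.py
FEATURE_ORDER = ("age", "sex", "bmi", "map", "tc", "ldl", "hdl", "tch", "ltg", "glu")


def to_row(payload: dict) -> list:
    """Order a JSON-style payload into the column order used during training."""
    return [float(payload[name]) for name in FEATURE_ORDER]


def label_from_proba(p_diabetic: float, threshold: float = 0.6) -> str:
    """Strict '>' comparison, as in main.py: a probability of exactly 0.6 is 'Not Diabetic'."""
    return "Diabetic" if p_diabetic > threshold else "Not Diabetic"


# Keys deliberately out of order: to_row still puts "age" first
payload = {"glu": 0.0, "age": 0.01, "sex": 0.05, "bmi": 0.06, "map": 0.02,
           "tc": -0.04, "ldl": -0.03, "hdl": -0.04, "tch": 0.0, "ltg": 0.02}
print(to_row(payload)[0])                      # 0.01
print(label_from_proba(0.6))                   # Not Diabetic
print(label_from_proba(0.61))                  # Diabetic
print(label_from_proba(0.55, threshold=0.5))   # Diabetic
```

Raising the threshold makes the API flag fewer people as diabetic (fewer false positives, more missed cases); lowering it does the opposite.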
Step 3: Run the FastAPI Server
To start the FastAPI server, use uvicorn:
uvicorn main:app --reload
Step 4: Test the API
Once the server is running, you can test the API with curl or through the interactive Swagger UI that FastAPI serves at http://127.0.0.1:8000/docs.
Example curl request:
```bash
curl -X 'POST' \
  'http://127.0.0.1:8000/predict/' \
  -H 'Content-Type: application/json' \
  -d '{
  "age": 50.0,
  "sex": 1.0,
  "bmi": 25.0,
  "map": 92.0,
  "tc": 220.0,
  "ldl": 120.0,
  "hdl": 50.0,
  "tch": 70.0,
  "ltg": 3.5,
  "glu": 90.0
}'
```
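Example response (the label you get depends on your trained model; the values above are not on the scaled range the model was trained on, so treat the result as a demonstration of the response format):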
{ "prediction": "Not Diabetic" }
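The same request can be sent from Python with only the standard library. This sketch assumes the server from Step 3 is listening on uvicorn's default address; `build_predict_request`, `predict` and `SAMPLE` are illustrative names.

```python
# client.py: Python equivalent of the curl request, standard library only (sketch)
import json
import urllib.request

API_URL = "http://127.0.0.1:8000/predict/"  # uvicorn's default host and port

# Same payload as the curl example
SAMPLE = {"age": 50.0, "sex": 1.0, "bmi": 25.0, "map": 92.0, "tc": 220.0,
          "ldl": 120.0, "hdl": 50.0, "tch": 70.0, "ltg": 3.5, "glu": 90.0}


def build_predict_request(features: dict, url: str = API_URL) -> urllib.request.Request:
    """Build (but do not send) a POST request with a JSON body."""
    body = json.dumps(features).encode("utf-8")
    return urllib.request.Request(
        url, data=body, headers={"Content-Type": "application/json"}, method="POST"
    )


def predict(features: dict) -> dict:
    """Send the request and decode the JSON response (the server must be running)."""
    with urllib.request.urlopen(build_predict_request(features)) as resp:
        return json.load(resp)
```

With the server running, `predict(SAMPLE)` returns the decoded JSON as a dict whose `"prediction"` key holds `"Diabetic"` or `"Not Diabetic"`.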
Conclusion
This guide demonstrated how to integrate a machine learning model with FastAPI. The key steps were:
Train and save a machine learning model.
Load the model in a FastAPI application.
Expose an endpoint to make predictions based on input data.