This repository contains a Python script that builds and trains an AI model to predict maternal health risk levels (high risk, mid risk, low risk) using an XGBoost classifier. The model utilizes a dataset with features such as Age, SystolicBP, DiastolicBP, BS (Blood Sugar), BodyTemp, and HeartRate, and provides visualizations to understand the model's performance and data distribution.
The following flowchart illustrates the complete machine learning pipeline used in this project:
flowchart TD
A["Load Dataset<br/>Maternal Health Risk Data Set.csv"] --> B["Data Preprocessing<br/>Drop missing values<br/>Encode target variable"]
B --> C["Feature Selection<br/>Features: Age, SystolicBP, DiastolicBP,<br/>BS, BodyTemp, HeartRate<br/>Target: RiskLevel"]
C --> D["Compute Class Weights<br/>Handle class imbalance"]
D --> E["Data Splitting<br/>80% Train+Val / 20% Test<br/>Then 80% Train / 20% Val"]
E --> F["Create DMatrix Objects<br/>For XGBoost training"]
F --> G["Hyperparameter Tuning<br/>GridSearchCV with 5-fold CV<br/>Optimize: max_depth, learning_rate,<br/>n_estimators, regularization"]
G --> H["Train Best Model<br/>XGBoost with early stopping<br/>on validation set"]
H --> I["Model Evaluation<br/>Calculate accuracy per iteration<br/>for training and validation"]
I --> J["Test Set Prediction<br/>Using best iteration"]
J --> K["Performance Metrics<br/>Accuracy Score<br/>Classification Report<br/>Confusion Matrix"]
K --> L["Create Visualizations<br/>Feature Importance Chart<br/>Risk Distribution Pie Chart<br/>Confusion Matrix Heatmap<br/>Accuracy Line Graph<br/>Log Loss Graph"]
L --> M["Cross-Validation<br/>5-fold CV with f1-weighted scoring"]
M --> N["Save Model and Encoder<br/>maternal_risk_model_improved.pkl<br/>label_encoder.pkl"]
N --> O["Create Prediction Function<br/>predict_maternal_risk function"]
O --> P["Example Prediction<br/>Test with sample data"]
style A fill:#000000,stroke:#000000,stroke-width:3px,color:#ffffff
style B fill:#000000,stroke:#000000,stroke-width:3px,color:#ffffff
style C fill:#000000,stroke:#000000,stroke-width:3px,color:#ffffff
style D fill:#000000,stroke:#000000,stroke-width:3px,color:#ffffff
style E fill:#000000,stroke:#000000,stroke-width:3px,color:#ffffff
style F fill:#000000,stroke:#000000,stroke-width:3px,color:#ffffff
style G fill:#000000,stroke:#000000,stroke-width:3px,color:#ffffff
style H fill:#000000,stroke:#000000,stroke-width:3px,color:#ffffff
style I fill:#000000,stroke:#000000,stroke-width:3px,color:#ffffff
style J fill:#000000,stroke:#000000,stroke-width:3px,color:#ffffff
style K fill:#000000,stroke:#000000,stroke-width:3px,color:#ffffff
style L fill:#000000,stroke:#000000,stroke-width:3px,color:#ffffff
style M fill:#000000,stroke:#000000,stroke-width:3px,color:#ffffff
style N fill:#000000,stroke:#000000,stroke-width:3px,color:#ffffff
style O fill:#000000,stroke:#000000,stroke-width:3px,color:#ffffff
style P fill:#000000,stroke:#000000,stroke-width:3px,color:#ffffff
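The class-weight and data-splitting steps in the flowchart can be sketched as follows. This is a minimal illustration on synthetic data, not the script's actual code: the variable names, the random seed, the "balanced" weighting scheme, and the stratified splits are all assumptions.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

# Synthetic stand-in for the six features and the encoded RiskLevel target.
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 6))
y = np.array([0] * 50 + [1] * 30 + [2] * 20)

# "Balanced" weights are n_samples / (n_classes * class_count),
# so rarer classes receive larger weights.
classes = np.unique(y)
class_weights = compute_class_weight(class_weight="balanced", classes=classes, y=y)
weight_by_class = dict(zip(classes.tolist(), class_weights.tolist()))
sample_weights = np.array([weight_by_class[label] for label in y.tolist()])

# First split: 80% train+validation, 20% test.
X_trval, X_test, y_trval, y_test, w_trval, w_test = train_test_split(
    X, y, sample_weights, test_size=0.2, random_state=42, stratify=y
)
# Second split: 80% train, 20% validation, taken from the train+validation part.
X_train, X_val, y_train, y_val, w_train, w_val = train_test_split(
    X_trval, y_trval, w_trval, test_size=0.2, random_state=42, stratify=y_trval
)

print(len(X_train), len(X_val), len(X_test))
```

With two successive 80/20 splits, the final proportions are 64% train, 16% validation, and 20% test.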
The script leverages the following libraries:
- pandas for data manipulation
- scikit-learn for preprocessing, train-test splitting, and evaluation metrics
- xgboost for the classification model
- joblib for saving the trained model and label encoder
- matplotlib and seaborn for generating visualizations
The model is trained on the provided Maternal Health Risk Data Set.csv and includes visualizations to analyze feature importance, risk level distribution, confusion matrix, and training/validation accuracy over iterations.
- Python 3.x
- Required libraries (install via pip):
pip install pandas numpy scikit-learn xgboost joblib matplotlib seaborn
The dataset (Maternal Health Risk Data Set.csv) should be placed in the same directory as the script. It contains the following columns:
- Age
- SystolicBP
- DiastolicBP
- BS (Blood Sugar)
- BodyTemp
- HeartRate
- RiskLevel (target variable: high risk, mid risk, low risk)
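As a rough sketch of how a file with these columns could be loaded and prepared, the example below checks the expected columns, drops incomplete rows, and encodes the target. The helper name `load_maternal_data` and the in-memory sample rows are hypothetical and are not taken from the script.

```python
import io

import pandas as pd
from sklearn.preprocessing import LabelEncoder

FEATURES = ["Age", "SystolicBP", "DiastolicBP", "BS", "BodyTemp", "HeartRate"]
TARGET = "RiskLevel"


def load_maternal_data(source):
    """Read the CSV, check columns, drop incomplete rows, encode the target."""
    df = pd.read_csv(source)
    missing = [col for col in FEATURES + [TARGET] if col not in df.columns]
    if missing:
        raise ValueError(f"Missing expected columns: {missing}")
    df = df.dropna(subset=FEATURES + [TARGET]).reset_index(drop=True)
    encoder = LabelEncoder()
    y = encoder.fit_transform(df[TARGET])  # classes are sorted alphabetically
    return df[FEATURES], y, encoder


# Tiny in-memory sample with made-up values, used here instead of the real file.
sample_csv = io.StringIO(
    "Age,SystolicBP,DiastolicBP,BS,BodyTemp,HeartRate,RiskLevel\n"
    "40,140,90,13,98,80,high risk\n"
    "35,120,60,6.1,98,76,low risk\n"
    "29,90,70,,100,80,mid risk\n"
    "30,140,85,7,98,70,mid risk\n"
)
X, y, encoder = load_maternal_data(sample_csv)
print(X.shape, list(encoder.classes_), y.tolist())
```

The third sample row has an empty BS value, so it is dropped before encoding. For the real dataset, the file path would be passed in place of `sample_csv`.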
- Clone the repository or copy the script to your local machine
- Install the required dependencies using the command above
- Ensure the dataset file is in the working directory
- Run the script: python maternal_risk_prediction_with_line_graph.py
- The script will:
  - Train the XGBoost model on the dataset
  - Generate predictions and evaluate the model
  - Save the trained model (maternal_risk_model.pkl) and label encoder (label_encoder.pkl)
  - Create and save visualizations as PNG files
  - Provide an example prediction for sample input data
- Output:
  - Console output includes accuracy, classification report, and an example prediction
  - Visualizations are saved as feature_importance.png, risk_distribution.png, confusion_matrix.png, and accuracy_line_graph.png
The script generates the following visualizations to aid in understanding the model and data:
- Feature importance chart (feature_importance.png)
  - Displays the relative importance of each feature in the XGBoost model
  - Helps identify which factors (e.g., SystolicBP, BS) most influence the risk prediction
- Risk distribution pie chart (risk_distribution.png)
  - Shows the proportion of high risk, mid risk, and low risk cases in the dataset
  - Useful for assessing class balance
- Confusion matrix heatmap (confusion_matrix.png)
  - Visualizes the model's performance by showing correct and incorrect predictions
  - The example shows the number of instances correctly classified (e.g., 41 high risk cases predicted as high risk) and misclassified (e.g., 4 low risk cases predicted as high risk)
- Accuracy line graph (accuracy_line_graph.png)
  - Plots the training and validation accuracy over training iterations
  - Training accuracy is drawn in blue and validation accuracy in red
  - The example shows training accuracy reaching ~0.95 and validation accuracy stabilizing around 0.85-0.90 after 40 iterations
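One way to obtain per-iteration accuracy without retraining is to convert a recorded error history into accuracy, since accuracy equals 1 minus the classification error. The sketch below assumes a history shaped like the dictionary returned by `XGBClassifier.evals_result()` when fitting with an `eval_set` and the `merror` metric. The numbers are made up, the helper names are hypothetical, and the script itself may compute the curves differently.

```python
def accuracy_curves(history, metric="merror"):
    """Convert per-round error rates into accuracy (1 - error) per dataset."""
    return {
        name: [1.0 - err for err in metrics[metric]]
        for name, metrics in history.items()
    }


def best_round(values):
    """Return the index of the first round with the highest value."""
    return max(range(len(values)), key=lambda i: values[i])


# Made-up error history: validation_0 = training set, validation_1 = validation set.
history = {
    "validation_0": {"merror": [0.5, 0.25, 0.125, 0.0625]},
    "validation_1": {"merror": [0.5, 0.25, 0.125, 0.25]},
}

curves = accuracy_curves(history)
best = best_round(curves["validation_1"])
print(curves["validation_1"], best)
```

The resulting lists can be passed directly to a plotting call, one line per dataset, and `best` marks the round after which validation accuracy stops improving.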
The script includes a function predict_maternal_risk to predict risk levels for new data. An example is provided with input values:
Age=25, SystolicBP=130, DiastolicBP=80, BS=15, BodyTemp=98, HeartRate=86
The predicted risk level is printed to the console.
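The exact signature of predict_maternal_risk is not shown here, so the following is only a sketch of what such a function might look like. The name `predict_risk_sketch`, its keyword-argument interface, and the stand-in model are assumptions made for illustration. Any object with a scikit-learn style `predict` method would work in place of the stub.

```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder

FEATURES = ["Age", "SystolicBP", "DiastolicBP", "BS", "BodyTemp", "HeartRate"]


def predict_risk_sketch(model, encoder, **values):
    """Build a one-row frame in training column order and decode the prediction."""
    missing = [name for name in FEATURES if name not in values]
    if missing:
        raise ValueError(f"Missing feature values: {missing}")
    row = pd.DataFrame([[values[name] for name in FEATURES]], columns=FEATURES)
    encoded = np.asarray(model.predict(row)).astype(int).ravel()
    return encoder.inverse_transform(encoded)[0]


class AlwaysFirstClass:
    """Stand-in model for the demo; always predicts encoded class 0."""

    def predict(self, X):
        return np.zeros(len(X), dtype=int)


encoder = LabelEncoder().fit(["high risk", "mid risk", "low risk"])
label = predict_risk_sketch(
    AlwaysFirstClass(), encoder,
    Age=25, SystolicBP=130, DiastolicBP=80, BS=15, BodyTemp=98, HeartRate=86,
)
print(label)
```

The printed label comes from the stub, not from a trained model, so it says nothing about the real prediction for these inputs. Building the row from a fixed column list keeps the feature order identical to the order used in training.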
- maternal_risk_model.pkl: The trained XGBoost model
- label_encoder.pkl: The label encoder used to transform risk levels
These can be loaded later for predictions without retraining.
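A minimal sketch of the save-and-reload round trip with joblib is shown below. It uses a freshly fitted label encoder and a temporary directory so that it runs on its own. The same `joblib.load` call would apply to the saved model file.

```python
import os
import tempfile

import joblib
from sklearn.preprocessing import LabelEncoder

# Fit an encoder on the three risk levels (stand-in for the script's encoder).
encoder = LabelEncoder().fit(["high risk", "mid risk", "low risk"])

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "label_encoder.pkl")
    joblib.dump(encoder, path)    # save to disk
    restored = joblib.load(path)  # load it back without refitting

# Encoded classes map back to the original labels (sorted alphabetically).
decoded = restored.inverse_transform([0, 1, 2]).tolist()
print(decoded)
```

Pickle-based files should only be loaded from trusted sources, and ideally with the same library versions that were used to save them.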
- The line graph tracks accuracy by retraining the model at each iteration with a validation set
- The default number of boosting rounds is 100, but this can be adjusted by modifying the n_estimators parameter in the XGBClassifier
- Visualizations are saved in the working directory for easy review
- Ensure the dataset file is correctly formatted; rows with missing values are dropped by the script by default
- Add early stopping to optimize the number of iterations
- Include hyperparameter tuning (e.g., using GridSearchCV) to improve model performance
- Extend visualizations to include learning curves or ROC curves for more detailed analysis
This project is for educational purposes. Feel free to modify and distribute, but please credit the original author.




