- Benchmarking on FER2013
- Benchmarking on JAFFE
- Installation
- Download datasets
- Training on FER2013
- Evaluation results
| Model | Accuracy (%) |
|---|---|
| VGG19 | 70.77 |
| EfficientNet_b2b | 70.83 |
| Googlenet | 71.97 |
| Resnet34 | 72.42 |
| Inception_v3 | 72.72 |
| Resnet50 | 72.86 |
| Cbam_Resnet50 | 72.95 |
| Bam_Resnet50 | 73.14 |
| Densenet121 | 73.16 |
| Resnet152 | 73.22 |
| Resnet101 | 74.06 |
| ResMaskingNet | 74.14 |
| ResMaskingNet + 6 other models (ensemble) | 76.82 |

| Model | Accuracy (%) |
|---|---|
| EfficientNet_b2b | 90 |
| Resnet18 | 91.42 |
| Resnet101 | 92.86 |
| Resnet50 | 94.28 |
| Cbam_Resnet50 | 94.29 |
| ResMaskingNet | 97.1 |
| Ensemble of above CNNs | 98.2 |

| Model | Accuracy (%) |
|---|---|
| Resnet18 | 96.631 |
| Densenet121 | 97.573 |
| VGG19 | 98.058 |
| Resnet101 | 98.544 |
| ResNet50_pretrainedvgg | 98.544 |
| ResMaskingNet | 98.87 |
- Install PyTorch by selecting your environment on the PyTorch website and running the appropriate command.
- Clone this repository and install the package prerequisites listed below.
- Then download the datasets by following the instructions below.
- Python 3.6+
- PyTorch 1.3+
- Torchvision 0.4.0+
- Packages listed in `requirements.txt`
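A quick way to confirm the minimum versions above is a small check script. This is a sketch, not part of the repository; `parse_version` and `at_least` are hypothetical helper names. It compares versions numerically, because a plain string comparison would wrongly rank `1.13` below `1.3`.

```python
import re
import sys
from typing import Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Extract the leading numeric part, e.g. "1.13.1+cu117" -> (1, 13, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def at_least(installed: str, minimum: str) -> bool:
    """Numeric comparison, padding with zeros so "1.3" equals "1.3.0"."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


if __name__ == "__main__":
    print("Python >= 3.6:", sys.version_info >= (3, 6))
    try:
        import torch
        import torchvision
    except ImportError:
        print("PyTorch or Torchvision is not installed yet")
    else:
        print("PyTorch >= 1.3:", at_least(torch.__version__, "1.3"))
        print("Torchvision >= 0.4.0:", at_least(torchvision.__version__, "0.4.0"))
```

Local build tags such as `+cu117` and pre-release suffixes such as `a0` are ignored by design, since only the leading dotted integers matter for these minimums.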
- FER2013 Dataset (place it in `saved/data/fer2013`, like `saved/data/fer2013/train.csv`)
- JAFFE Dataset (place it in `saved/data/jaffe`, like `saved/data/jaffe/train.csv`)
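To sanity-check the CSV files after placing them, the following minimal sketch parses rows into 48x48 grayscale grids. It assumes the split files keep the original FER2013 layout: an `emotion` label column (0 to 6) and a `pixels` column of 2304 space-separated values. `load_fer_rows` and `parse_pixels` are hypothetical helpers, not the repository's data loader, and the JAFFE files may use a different layout.

```python
import csv
import io
from typing import Iterable, List, Tuple

# Standard FER2013 label order (assumed to be kept in the split CSV files).
EMOTIONS = ["Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral"]
IMAGE_SIZE = 48  # FER2013 images are 48x48 grayscale


def parse_pixels(pixels: str, size: int = IMAGE_SIZE) -> List[List[int]]:
    """Convert the space-separated pixel string into a size x size grid."""
    values = [int(v) for v in pixels.split()]
    if len(values) != size * size:
        raise ValueError(f"expected {size * size} pixels, got {len(values)}")
    return [values[row * size:(row + 1) * size] for row in range(size)]


def load_fer_rows(lines: Iterable[str], size: int = IMAGE_SIZE) -> List[Tuple[List[List[int]], int]]:
    """Read (image, label) pairs from CSV text with 'emotion' and 'pixels' columns."""
    reader = csv.DictReader(lines)
    return [(parse_pixels(row["pixels"], size), int(row["emotion"])) for row in reader]


if __name__ == "__main__":
    # Typical use with the layout above:
    # with open("saved/data/fer2013/train.csv", newline="") as f:
    #     samples = load_fer_rows(f)
    demo = io.StringIO("emotion,pixels\n3," + " ".join(["128"] * 4) + "\n")
    image, label = load_fer_rows(demo, size=2)[0]
    print(image, EMOTIONS[label])  # [[128, 128], [128, 128]] Happy
```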
- To train a network, specify the model name and other hyperparameters in a config file (located at `configs/*`), make sure that config is loaded in the main file, then start training by running the main file, for example:

```Shell
python main_fer.py  # Example for fer2013_config.json file
```
- The best checkpoint is chosen by validation accuracy and saved in `saved/checkpoints`.
- The TensorBoard training logs are located in `saved/logs`; to open them, use `tensorboard --logdir saved/logs/`.
- By default, training uses the `alexnet` model. You can switch to another model by editing the `configs/fer2013_config.json` file (e.g. to `resnet18`, `cbam_resnet50`, or my network `resmasking_dropout1`).

Follow a similar process for the JAFFE dataset.
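The rule of keeping only the checkpoint with the best validation accuracy can be sketched as a small tracker. `BestCheckpointTracker` and its `save_fn` hook are hypothetical names; in a PyTorch trainer, `save_fn` would typically wrap `torch.save(model.state_dict(), path)` with a path under `saved/checkpoints`. The repository's trainer may implement this differently.

```python
from typing import Callable, List, Optional


class BestCheckpointTracker:
    """Track the best validation accuracy and call a save hook on improvement."""

    def __init__(self, save_fn: Callable[[int, float], None]) -> None:
        self.save_fn = save_fn  # hypothetical hook, e.g. a wrapper around torch.save
        self.best_acc: Optional[float] = None
        self.best_epoch: Optional[int] = None

    def update(self, epoch: int, val_acc: float) -> bool:
        """Save and return True only if val_acc strictly beats the best so far."""
        if self.best_acc is None or val_acc > self.best_acc:
            self.best_acc = val_acc
            self.best_epoch = epoch
            self.save_fn(epoch, val_acc)
            return True
        return False


if __name__ == "__main__":
    saved_epochs: List[int] = []
    tracker = BestCheckpointTracker(lambda epoch, acc: saved_epochs.append(epoch))
    # Made-up validation accuracies, one per epoch.
    for epoch, acc in enumerate([0.50, 0.62, 0.60, 0.71, 0.69]):
        tracker.update(epoch, acc)
    print(tracker.best_epoch, tracker.best_acc, saved_epochs)  # 3 0.71 [0, 1, 3]
```

Using a strict `>` means a later epoch that only ties the best score does not overwrite the earlier checkpoint.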
Below is an example of generating a striking confusion matrix with its text typeset in LaTeX.
(Read this article for more information; there will be some bugs if you run the code blindly without reading it.)
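For intuition, the confusion matrix itself is just a table of counts: entry `(i, j)` is the number of samples of true class `i` predicted as class `j`. The pure-Python sketch below computes it and a row-normalized version, whose diagonal is per-class recall. Whether the repository's plotting script normalizes rows is an assumption here, and `confusion_matrix` and `normalize_rows` are hypothetical helper names.

```python
from typing import List, Sequence


def confusion_matrix(y_true: Sequence[int], y_pred: Sequence[int], num_classes: int) -> List[List[int]]:
    """counts[i][j] = number of samples whose true class is i and predicted class is j."""
    counts = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        counts[t][p] += 1
    return counts


def normalize_rows(counts: List[List[int]]) -> List[List[float]]:
    """Divide each row by its total so the diagonal holds per-class recall."""
    normalized = []
    for row in counts:
        total = sum(row)
        normalized.append([c / total if total else 0.0 for c in row])
    return normalized


if __name__ == "__main__":
    y_true = [0, 0, 1, 1, 1, 2]
    y_pred = [0, 1, 1, 1, 2, 2]
    cm = confusion_matrix(y_true, y_pred, num_classes=3)
    print(cm)  # [[1, 1, 0], [0, 2, 1], [0, 0, 1]]
    print([round(v, 2) for v in normalize_rows(cm)[1]])  # [0.0, 0.67, 0.33]
```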
```Shell
python ./Visualization/gen_confusion_matrix.py
```

Below is an example of generating saliency maps for JAFFE images, to find the most important parts of the image used by the model for prediction.
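For intuition, a gradient-based saliency map scores each pixel by the magnitude of the derivative of the predicted class score with respect to that pixel. The sketch below swaps autograd for central finite differences so that it runs with only the standard library; `saliency_map` and the toy score function are hypothetical, and it is not the repository's `gen_salientmap.py`. For RGB inputs the per-channel magnitudes are usually reduced with a max.

```python
from typing import Callable, List

Image = List[List[float]]  # single-channel H x W image


def saliency_map(score_fn: Callable[[Image], float], image: Image, eps: float = 1e-3) -> Image:
    """Approximate |d score / d pixel| with central finite differences.

    Gradient-based saliency normally gets this derivative from autograd
    (backpropagating the class score to the input); finite differences are
    used here only so the sketch runs without PyTorch.
    """
    height, width = len(image), len(image[0])
    saliency = [[0.0] * width for _ in range(height)]
    for i in range(height):
        for j in range(width):
            original = image[i][j]
            image[i][j] = original + eps
            plus = score_fn(image)
            image[i][j] = original - eps
            minus = score_fn(image)
            image[i][j] = original  # restore the pixel
            saliency[i][j] = abs(plus - minus) / (2 * eps)
    return saliency


if __name__ == "__main__":
    # Toy "class score": depends strongly on pixel (0, 0), weakly on (1, 1).
    def toy_score(img: Image) -> float:
        return 2.0 * img[0][0] - 0.5 * img[1][1]

    img = [[0.2, 0.4], [0.6, 0.8]]
    print([[round(v, 3) for v in row] for row in saliency_map(toy_score, img)])
    # [[2.0, 0.0], [0.0, 0.5]]
```

Finite differences need two forward passes per pixel, which is why real implementations use a single backward pass instead.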
```Shell
python ./Visualization/gen_salientmap.py
```

Below is an example of generating a Grad-CAM visualization, to find the most important parts of the image used by the model for prediction.
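For intuition, the core Grad-CAM step weights each feature map `A^k` of a convolutional layer by the spatial mean of the class score's gradient with respect to that map, sums the weighted maps, and applies a ReLU. The pure-Python sketch below assumes the activations and gradients have already been captured (e.g. with forward and backward hooks) and skips the usual upsampling to input size; `grad_cam` is a hypothetical helper, not the code in `gradCAM_resmasking.py`.

```python
from typing import List

Maps = List[List[List[float]]]  # K feature maps, each H x W


def grad_cam(activations: Maps, gradients: Maps) -> List[List[float]]:
    """Combine feature maps A^k and gradients dy^c/dA^k into a Grad-CAM heatmap.

    alpha_k = spatial mean of the gradients of map k
    cam     = ReLU(sum_k alpha_k * A^k), then scaled to [0, 1]
    """
    height, width = len(activations[0]), len(activations[0][0])
    cam = [[0.0] * width for _ in range(height)]
    for fmap, grad in zip(activations, gradients):
        alpha = sum(sum(row) for row in grad) / (height * width)
        for i in range(height):
            for j in range(width):
                cam[i][j] += alpha * fmap[i][j]
    cam = [[v if v > 0 else 0.0 for v in row] for row in cam]  # ReLU keeps positive evidence
    peak = max(max(row) for row in cam)
    if peak > 0:
        cam = [[v / peak for v in row] for row in cam]
    return cam


if __name__ == "__main__":
    acts = [[[1.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 1.0]]]
    grads = [[[0.5, 0.5], [0.5, 0.5]], [[-1.0, -1.0], [-1.0, -1.0]]]
    print(grad_cam(acts, grads))  # [[1.0, 0.0], [0.0, 0.0]]
```

In the demo, the second map has a negative weight, so its region is suppressed by the ReLU and only the first map's location lights up.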
```Shell
python ./Visualization/gradCAM_resmasking.py
```

I used an unweighted average ensemble to fuse 7 different models together. To reproduce the results, follow these steps:
- Download all the needed trained weights and place them in the `./saved/checkpoints/` directory. Download links can be found in the Benchmarking section.
- Edit the `gen_results` file and run it to generate results offline for each model.
- Run the `gen_ensemble.py` file to compute the accuracy of the example methods.
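The unweighted average ensemble can be sketched as follows: average each model's class probabilities sample by sample, then take the arg-max. It assumes each model's offline results are stored as one row of softmax probabilities per test sample in the same order; that storage format, and the names `average_ensemble` and `accuracy`, are hypothetical rather than taken from `gen_results` or `gen_ensemble.py`.

```python
from typing import List, Sequence

Probs = List[List[float]]  # one row of class probabilities per test sample


def average_ensemble(model_outputs: Sequence[Probs]) -> Probs:
    """Unweighted mean of each model's class probabilities, sample by sample."""
    num_models = len(model_outputs)
    return [
        [sum(values) / num_models for values in zip(*sample_rows)]
        for sample_rows in zip(*model_outputs)
    ]


def accuracy(probs: Probs, labels: Sequence[int]) -> float:
    """Fraction of samples whose highest-probability class equals the label."""
    preds = [max(range(len(row)), key=row.__getitem__) for row in probs]
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)


if __name__ == "__main__":
    # Made-up softmax outputs of three models on two samples (two classes).
    model_a = [[0.6, 0.4], [0.3, 0.7]]
    model_b = [[0.4, 0.6], [0.2, 0.8]]
    model_c = [[0.7, 0.3], [0.6, 0.4]]
    labels = [0, 1]
    fused = average_ensemble([model_a, model_b, model_c])
    print(accuracy(fused, labels))  # 1.0
    print([accuracy(m, labels) for m in (model_a, model_b, model_c)])  # [1.0, 0.5, 0.5]
```

Averaging probabilities lets a confident correct model outvote a less confident wrong one, which is how the fused accuracy can exceed most of the individual models.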
