Support a sampling strategy for multiple training datasets#107
Conversation
|
@zerovl thanks, couldn't this logic be placed in a dataset wrapper so we don't have repeat the train loop and incur more long term maintenance? Either one that covers both csv & wds or a separate one for each, that handles all of the length calcs, sampling, etc internally for each batch grabbed ... |
|
@rwightman Thanks for replying. We agree with that the logic should be placed in a dataset wrapper. We are working on implementing it, and making sure experiment results are correct. |
|
@rwightman hi, the implementation is done, and the log is attached. It seems that results are almost the same with the former version. Would you check the code when you are available? |
|
@zerovl thanks for updating this and your other PR, I'll try to find some time to take a closer look next week. |
|
@rwightman thanks for your time. I am willing to discuss about implementation details. |
Proposing the debiased sampling method proposed in the ZeroVL paper. When training multiple datasets, the debiased sampling improves the accuracy of CLIP model. It includes a new flag:
Introduction of Debiased Sampling
As shown in Fig2, random sampling is the most intuitive sampling method, which randomly constructs training batches with all available data. However, as shown in Fig3, random sampling leads to biased feature distributions on both image and text modalities.
Debiased sampling ensures instances within each batch come from the same dataset. Training with debiased sampling improves the quality of learned representations, and contributes to better results on many downstream tasks.
Experiments on sampling methods
We use two datasets, CC3M and SBU, to show the improvements of debiased sampling.
Experiment1: Random Sampling
1. Setting & Acc
dataset: CC3M + SBU (2.79M + 0.86M)
batchsize: 2048 (256 per GPU, 8 V100 32GB)
learning rate: 1e-3
weight decay: 0.1
sampling: random
zero-shot acc on ImageNet: top1 21.36, top5 40.98
2. Training script
torchrun --nproc_per_node 8 -m training.main --train-data "/data/cc3m/cc3m_sbu_train_anno.csv" \ --dataset-type auto \ --batch-size 256 \ --precision amp \ --workers 4 \ --imagenet-val "/data/ILSVRC/Data/CLS-LOC/val" \ --csv-separator , \ --lr=1e-3 \ --wd=0.1'/data/cc3m/cc3m_sbu_train_anno.csv' contains all samples from CC3M and SBU.
3. Log
cc3m+sbu+random_sample.log
Experiment2: Debiased Sampling
1. Setting
dataset: CC3M + SBU (2.79M + 0.86M)
batchsize: 2048 (256 per GPU, 8 V100 32GB)
learning rate: 1e-3
weight decay: 0.1
sampling: debias
zero-shot acc on ImageNet: top1 22.33, top5 42.29
2. Training script
torchrun --nproc_per_node 8 -m training.main --train-data "/data/cc3m/cc3m_train_anno.csv, /data/sbu/sbu_train_anno.csv" \ --dataset-type auto \ --batch-size 256 \ --precision amp \ --workers 4 \ --imagenet-val "/data/ILSVRC/Data/CLS-LOC/val" \ --csv-separator , \ --lr=1e-3 \ --wd=0.1 \ --debias-sample3. Log
cc3m+sbu+debias_sample.log