This repository provides a PyTorch implementation of our ICLR'26 paper "Task-Aware Data Selection via Proxy-Label Enhanced Distribution Matching for LLM Finetuning". [paper]
The pipeline consists of 4 main steps:
- Step 1: Dataset splitting and target annotation
- Step 2: Tag clustering and propagation
- Step 3: Keyword extraction and data scoring
- Step 4: Quality-based task-oriented selection
Prepare the evaluation and training data:

```bash
# eval data
bash data/prepare_eval_data.sh
# train data
bash data/prepare_train_data.sh
```
Step 1: Dataset splitting and target annotation

```bash
cd step1_generate_labels
# Split evaluation datasets into target (20%) and evaluation (80%) sets.
python 1dataset_splitter.py
# Merge all target datasets into a unified parquet format.
python 2target_split_merge.py
# Annotate the target dataset with tags using Qwen2.5-7B-Instruct.
python 3target_annotation.py
```
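The 20%/80% split performed by `1dataset_splitter.py` can be sketched in a few lines. This is a minimal illustration only — `split_target_eval`, its arguments, and the fixed seed are hypothetical, not the script's actual interface:

```python
import random

def split_target_eval(samples, target_ratio=0.2, seed=42):
    """Shuffle an evaluation dataset and split it into a target pool
    (used for annotation) and a held-out evaluation set.

    Hypothetical sketch: the real 1dataset_splitter.py reads and writes
    the repo's dataset files rather than in-memory lists.
    """
    rng = random.Random(seed)  # fixed seed keeps the partition reproducible
    idx = list(range(len(samples)))
    rng.shuffle(idx)
    n_target = int(len(samples) * target_ratio)
    target = [samples[i] for i in idx[:n_target]]
    evaluation = [samples[i] for i in idx[n_target:]]
    return target, evaluation
```

A fixed seed keeps the target/evaluation partition stable across runs, so downstream annotation and scoring always see the same target pool.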
Step 2: Tag clustering and propagation

```bash
cd ../step2_clustering_and_propagating
# Cluster and deduplicate tags from the target annotation.
python 1testset_tag_cluster_merge.py
# Generate embeddings for training-set content using the BGE-M3 model.
python 2training_set_content_embedding_cache.py
# Propagate clustered tags to the training set using semantic similarity.
python 3propagating_tags_using_cached_embedding.py
```
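Conceptually, the propagation step is a nearest-neighbor lookup over the cached embeddings: each training sample inherits the tag of its most similar target sample. A dependency-free sketch (the function names, the cosine implementation, and the 0.5 threshold are illustrative assumptions; the real scripts operate on BGE-M3 embeddings cached under `data/train_embeds_and_tags/`):

```python
import math

def cosine(u, v):
    """Cosine similarity between two dense vectors (assumed nonzero)."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv)

def propagate_tags(train_embeds, target_embeds, target_tags, threshold=0.5):
    """Assign each training sample the tag of its most similar target
    sample; samples below the similarity threshold get no tag (None).

    Hypothetical sketch of the step-2 propagation, not the repo's code.
    """
    labels = []
    for emb in train_embeds:
        sims = [cosine(emb, t) for t in target_embeds]
        best = max(range(len(sims)), key=sims.__getitem__)
        labels.append(target_tags[best] if sims[best] >= threshold else None)
    return labels
```

The threshold keeps clearly dissimilar training samples untagged instead of forcing a nearest match onto them.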
Step 3: Keyword extraction and data scoring

```bash
cd ../step3_tag_clustering_label_training_set
# Extract keywords from clustered tags using vLLM and Qwen2.5-7B-Instruct.
python 1keyword_extraction_vllm.py
# Score training data using keyword mapping and Qwen2.5-7B-Instruct.
python 2score_based_on_anchors.py
# Filter out-of-distribution samples based on score thresholds.
python 3filter_ood.py
```
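The scoring and OOD filtering of step 3 can be approximated with a simple keyword-hit score. This is a stand-in sketch: the actual `2score_based_on_anchors.py` scores with Qwen2.5-7B-Instruct, and `keyword_score`, this `filter_ood` signature, and the 0.2 threshold are all hypothetical:

```python
def keyword_score(text, keywords):
    """Fraction of task keywords that appear in a sample's text —
    a toy stand-in for the keyword-mapping score used in step 3."""
    text = text.lower()
    hits = sum(1 for kw in keywords if kw.lower() in text)
    return hits / len(keywords) if keywords else 0.0

def filter_ood(samples, keywords, threshold=0.2):
    """Drop samples whose score falls below the threshold; those are
    treated as out-of-distribution for the target task."""
    return [s for s in samples if keyword_score(s, keywords) >= threshold]
```

In the pipeline the threshold trades recall for precision: a higher value keeps only samples strongly related to the target task.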
Step 4: Quality-based task-oriented selection

```bash
cd ../step4_quality_task_orient
# Select high-quality samples using distribution-based sampling.
python 1mitigating_domain_shift.py
# Match multiple label domains via joint filter fusion.
python 2joint_filter_fusion.py
```
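The distribution-based sampling of step 4 can be sketched as drawing training samples so that the selected tag proportions match the target task's tag distribution. An illustrative sketch only — `select_matching_distribution`, `target_dist`, and `budget` are assumed names, not the repo's API:

```python
import random

def select_matching_distribution(samples, tags, target_dist, budget, seed=0):
    """Pick about `budget` samples so that the selected tag proportions
    approximately follow `target_dist` (tag -> desired fraction).

    Hypothetical sketch of the step-4 selection, not the actual script.
    """
    rng = random.Random(seed)
    by_tag = {}
    for s, t in zip(samples, tags):
        by_tag.setdefault(t, []).append(s)
    selected = []
    for tag, frac in target_dist.items():
        pool = by_tag.get(tag, [])
        # take the desired share of the budget, capped by pool size
        k = min(len(pool), round(budget * frac))
        selected.extend(rng.sample(pool, k))
    return selected
```

A per-tag quality score could replace the random draw inside each pool, which is closer in spirit to the quality-based selection the step performs.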
Repository structure:

```
/
├── step1_generate_labels/
│   ├── 1dataset_splitter.py                        # Split evaluation datasets
│   ├── 2target_split_merge.py                      # Merge target datasets
│   └── 3target_annotation.py                       # Annotate target data
├── step2_clustering_and_propagating/
│   ├── 1testset_tag_cluster_merge.py               # Cluster and deduplicate tags
│   ├── 2training_set_content_embedding_cache.py    # Cache embeddings
│   └── 3propagating_tags_using_cached_embedding.py # Propagate tags
├── step3_tag_clustering_label_training_set/
│   ├── 1keyword_extraction_vllm.py                 # Extract keywords
│   ├── 2score_based_on_anchors.py                  # Score data
│   └── 3filter_ood.py                              # Filter OOD samples
├── step4_quality_task_orient/
│   ├── 1mitigating_domain_shift.py                 # Quality-based selection
│   └── 2joint_filter_fusion.py                     # Match multiple label domains
├── data/                                           # Data directory
│   ├── eval/                                       # Evaluation datasets
│   ├── train_embeds_and_tags/                      # Cached embeddings
│   ├── prepare_*_data.sh                           # Dataset preparation scripts
│   └── *.pt, *.json                                # Processed data files
└── consistency_precision_reuslt/                   # Consistency and precision results
    ├── consistency_gpt_eval.xlsx
    ├── consistency_human_eval.xlsx
    ├── precision_gpt_eval.xlsx
    └── precision_human_eval.xlsx
```
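The cached artifacts under `data/` follow a compute-once pattern: embeddings are built on the first run and reloaded afterwards. A dependency-free sketch of that pattern (the repo stores tensors as `.pt` files; `get_or_build_embeddings` and the pickle format here are illustrative only):

```python
import os
import pickle

def get_or_build_embeddings(cache_path, texts, embed_fn):
    """Load cached embeddings if the cache file exists; otherwise
    compute them with embed_fn and write the cache.

    Illustrative sketch of the step-2 caching pattern; the actual repo
    caches BGE-M3 tensors as .pt files rather than pickled lists.
    """
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            return pickle.load(f)
    embeds = [embed_fn(t) for t in texts]
    with open(cache_path, "wb") as f:
        pickle.dump(embeds, f)
    return embeds
```

Caching matters here because embedding the full training set is the most expensive part of step 2; later steps only read the cache.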
Finetuning and evaluation of TADS are based on open-instruct.