-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_localizer.py
More file actions
156 lines (131 loc) · 5.03 KB
/
train_localizer.py
File metadata and controls
156 lines (131 loc) · 5.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
#!/usr/bin/env python3
"""Train the BeesBook localizer model on the feeder bee dataset.
The localizer is a lightweight fully-convolutional heatmap model (~248K params)
that classifies 128x128 grayscale patches as containing a bee (positive) or not
(negative), with per-class output channels.
Usage
-----
# Train on CVAT patches (default)
python train_localizer.py --dataset /path/to/feeder_bee_datasets_v1
# Train on merged patches (CVAT + HDF5)
python train_localizer.py --dataset /data --variant merged
# Fine-tune from pretrained weights
python train_localizer.py --dataset /data --weights /path/to/localizer_2019_weights.pt
"""
import argparse
from datetime import datetime
from pathlib import Path
import mosaic.tracking.pose_training as pose
from mosaic.tracking.pose_training import LocalizerAugmentConfig
import config
def parse_args(argv=None):
    """Parse command-line arguments for localizer training.

    Hyper-parameter defaults are read from ``config.LOCALIZER_DEFAULTS``.
    After parsing, ``--device`` falls back to ``config.auto_device()`` and
    ``--name`` falls back to ``"<variant>_<YYYYmmdd_HHMMSS>"`` when omitted.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly like ``ArgumentParser.parse_args``.

    Returns:
        argparse.Namespace with every option populated.

    Raises:
        SystemExit: via ``ArgumentParser.error`` when ``--epochs`` or
            ``--batch-size`` is below 1, or ``--lr`` is not positive.
    """
    p = argparse.ArgumentParser(
        description="Train the BeesBook localizer on feeder bee patch data.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    d = config.LOCALIZER_DEFAULTS

    # NOTE: ArgumentDefaultsHelpFormatter only appends "(default: ...)" to
    # options that have a help string, so every option below gets one.
    p.add_argument(
        "--dataset", required=True,
        help="Path to extracted feeder_bee_datasets_v1/ directory",
    )
    p.add_argument(
        "--variant", default="cvat", choices=["cvat", "merged"],
        help="Dataset variant to train on",
    )
    p.add_argument(
        "--epochs", type=int, default=d["epochs"],
        help="Number of training epochs",
    )
    p.add_argument(
        "--batch-size", type=int, default=d["batch_size"],
        help="Training batch size",
    )
    p.add_argument("--lr", type=float, default=d["lr"], help="Learning rate")
    p.add_argument(
        "--patience", type=int, default=d["early_stopping_patience"],
        help="Early stopping patience (epochs)",
    )
    p.add_argument(
        "--lr-patience", type=int, default=d["lr_patience"],
        help="ReduceLROnPlateau patience (epochs)",
    )
    p.add_argument(
        "--device", default=None,
        help="Device: '0' (cuda), 'mps', 'cpu'; auto-detected if omitted",
    )
    p.add_argument(
        "--freeze-encoder", action="store_true",
        help="Freeze encoder weights (train head only)",
    )
    p.add_argument(
        "--weights", default=None,
        help="Pretrained weights path (.pt or .h5). "
             "Keras .h5 files are auto-converted to PyTorch.",
    )
    p.add_argument("--name", default=None, help="Run name (auto-generated if omitted)")
    p.add_argument(
        "--output-dir", default=None,
        help="Base output directory "
             "(if omitted: under the dataset, models/localizer/<variant>/runs/)",
    )
    p.add_argument("--seed", type=int, default=42, help="Random seed")

    args = p.parse_args(argv)

    # Reject values that cannot produce a meaningful training run before any
    # dataset loading or weight conversion happens.
    if args.epochs < 1:
        p.error("--epochs must be >= 1")
    if args.batch_size < 1:
        p.error("--batch-size must be >= 1")
    if not args.lr > 0:  # written this way so NaN is rejected too
        p.error("--lr must be > 0")

    if args.device is None:
        args.device = config.auto_device()
    if args.name is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        args.name = f"{args.variant}_{timestamp}"
    return args
def _resolve_weights(weights):
    """Map the ``--weights`` value to a path ``pose.train_localizer`` can load.

    Keras ``.h5`` weights (suffix matched case-insensitively) are converted to
    a sibling ``.pt`` file via ``pose.convert_keras_weights``. If that ``.pt``
    already exists, it is reused instead of converting again. Any other value
    is passed through unchanged.

    Args:
        weights: The ``--weights`` string, or ``None``.

    Returns:
        The weights path as a string, or ``None`` when no weights were given.

    Raises:
        SystemExit: if a ``.h5`` file needs converting but does not exist.
    """
    if weights is None:
        return None

    weights_path = Path(weights)
    if weights_path.suffix.lower() != ".h5":
        return weights

    pt_path = weights_path.with_suffix(".pt")
    if pt_path.exists():
        # NOTE(review): a cached .pt is reused even when the .h5 is newer;
        # delete the .pt to force a fresh conversion.
        print(f"Using existing PyTorch weights: {pt_path}")
        return str(pt_path)

    if not weights_path.exists():
        # Fail with a clear message instead of an opaque converter error.
        raise SystemExit(f"Keras weights not found: {weights_path}")

    print(f"Converting Keras weights: {weights_path} -> {pt_path}")
    converted = pose.convert_keras_weights(
        weights_path,
        output_pt_path=pt_path,
        num_classes=config.NUM_CLASSES,
        initial_channels=config.INITIAL_CHANNELS,
    )
    return str(converted)


def main():
    """Parse CLI arguments, prepare weights, train the localizer, print a summary."""
    args = parse_args()

    # Resolve dataset
    dataset_dir = config.resolve_localizer_data(args.dataset, args.variant)
    print(f"Dataset: {dataset_dir}")
    print(f"Variant: {args.variant}")
    print(f"Device: {args.device}")
    print()

    # Handle pretrained weights (Keras .h5 is converted to PyTorch .pt)
    weights = _resolve_weights(args.weights)

    # Augmentation config (matches notebook defaults)
    augment = LocalizerAugmentConfig(flip_h=True, flip_v=True, rotate_90=True)

    # Resolve output directory (default: under the dataset)
    if args.output_dir is None:
        output_dir = str(config.resolve_localizer_output(args.dataset, args.variant))
    else:
        output_dir = str(Path(args.output_dir).resolve())

    # Train
    print(f"Starting training: {args.name}")
    print(f"Output: {Path(output_dir) / args.name}")
    print()
    result = pose.train_localizer(
        dataset_dir=dataset_dir,
        num_classes=config.NUM_CLASSES,
        initial_channels=config.INITIAL_CHANNELS,
        weights=weights,
        freeze_encoder=args.freeze_encoder,
        epochs=args.epochs,
        batch_size=args.batch_size,
        lr=args.lr,
        early_stopping_patience=args.patience,
        lr_patience=args.lr_patience,
        device=args.device,
        project=output_dir,
        name=args.name,
        seed=args.seed,
        augment=augment,
    )

    # Summary
    print("\nTraining complete.")
    print(f" Best model: {result.best_model_path}")
    # Shown 1-based; assumes the trainer reports best_epoch 0-based — TODO confirm
    print(f" Best epoch: {result.best_epoch + 1}")
    print(f" Best val loss: {result.best_val_loss:.4f}")
    print(f" Run dir: {result.run_dir}")


if __name__ == "__main__":
    main()