-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.py
More file actions
41 lines (31 loc) · 1.19 KB
/
main.py
File metadata and controls
41 lines (31 loc) · 1.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
from configs import parse_pretrain_config, parse_adaption_config
from cores.pretrain_trainer import Pretrainer
from downstream.adapt_trainer import AdaptTrainer
import argparse
def pretrain_main(remaining_argv=None):
    """Run the pretraining entry point.

    Builds a configuration with ``parse_pretrain_config``, prints every
    attribute stored on it, constructs a ``Pretrainer`` from it, calls its
    ``train()`` method and finally calls ``torch.cuda.empty_cache()``.

    Args:
        remaining_argv: Forwarded unchanged to ``parse_pretrain_config``.
            Defaults to ``None``; how ``None`` is interpreted is decided by
            that parser (presumably "read from sys.argv" -- TODO confirm).
    """
    config = parse_pretrain_config(remaining_argv)

    # Echo the resolved settings so each run records what it was launched with.
    settings = config.__dict__
    print("Final Configuration:")
    for name, value in settings.items():
        print(f" {name}: {value}")

    Pretrainer(config).train()

    # Runs only after train() returns normally.
    torch.cuda.empty_cache()
def transfer_main(remaining_argv=None):
    """Run the adaption entry point.

    Builds a configuration with ``parse_adaption_config``, prints every
    attribute stored on it, constructs an ``AdaptTrainer`` from it, calls its
    ``train()`` method and finally calls ``torch.cuda.empty_cache()``.

    Args:
        remaining_argv: Forwarded unchanged to ``parse_adaption_config``.
            Defaults to ``None``; how ``None`` is interpreted is decided by
            that parser (presumably "read from sys.argv" -- TODO confirm).
    """
    config = parse_adaption_config(remaining_argv)

    # Echo the resolved settings so each run records what it was launched with.
    settings = config.__dict__
    print("Final Configuration:")
    for name, value in settings.items():
        print(f" {name}: {value}")

    AdaptTrainer(config).train()

    # Runs only after train() returns normally.
    torch.cuda.empty_cache()
if __name__ == '__main__':
    # Script entry point: consume only --run_type here and pass every other
    # command-line argument on to the selected workflow's own config parser.
    cli = argparse.ArgumentParser(description="Pretrain or Adaption Command")
    cli.add_argument(
        "--run_type",
        type=str,
        default="pretrain",
        choices=["pretrain", "adapt"],
    )
    # parse_known_args returns the recognised options plus the leftover argv.
    known, passthrough = cli.parse_known_args()

    # Map each accepted --run_type value to the function that runs it.
    entry_points = {
        "pretrain": pretrain_main,
        "adapt": transfer_main,
    }
    entry = entry_points.get(known.run_type)
    if entry is None:
        # argparse's `choices` already rejects other values; kept as a guard.
        raise ValueError("Invalid run type")
    entry(passthrough)