-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference.py
More file actions
56 lines (48 loc) · 1.66 KB
/
Copy pathinference.py
File metadata and controls
56 lines (48 loc) · 1.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import hydra
from hydra.utils import instantiate
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.trainer import Trainer
from src.data_module import DiffusionTrackerDataModule
from src.trainers import BaseTrainer
from src.utils import (
build_training_modules,
load_best_checkpoint,
log_run_metadata,
process_hparams,
resolve_epoch_len,
split_trainer_config,
)
@hydra.main(version_base=None, config_name="inference", config_path="src/configs")
def main(cfg) -> None:
    """Run test-set evaluation for the diffusion tracker from a Hydra config.

    Builds the data module (test split only), the model modules and a
    ``BaseTrainer`` LightningModule from the ``inference`` config. It then
    optionally restores the best checkpoint and calls ``Trainer.test``.

    Args:
        cfg: Config composed by Hydra from ``src/configs/inference`` plus any
            command-line overrides.

    Optional keys read from the Lightning part of ``hparams.trainer``
    (as split by ``split_trainer_config``):
        train_mode: forwarded to ``build_training_modules``
            (default ``"inference"``).
        test_with_best_checkpoint: if truthy, call ``load_best_checkpoint``
            before testing (default ``True``).
        test_epoch_len: passed to ``resolve_epoch_len`` together with the
            number of test batches; the result becomes ``limit_test_batches``
            (default ``1.0``).
        accelerator: Lightning accelerator name (default ``"gpu"``).
    """
    hparams = process_hparams(cfg, print_hparams=False)

    # The experiment logger is optional: only build it when the config
    # provides a (non-empty) `logger` section.
    logger = instantiate(hparams.logger) if hparams.get("logger") else None
    if logger is not None:
        log_run_metadata(logger, hparams)

    dm = DiffusionTrackerDataModule(hparams.dataset.data, hparams.dataloaders)
    dm.setup("test")

    # The trainer config mixes Lightning-level options (pl_trainer_cfg) with
    # options forwarded to BaseTrainer (module_trainer_cfg).
    pl_trainer_cfg, module_trainer_cfg = split_trainer_config(hparams.trainer)
    modules = build_training_modules(
        hparams, pl_trainer_cfg.get("train_mode", "inference")
    )
    diff_trainer = BaseTrainer(
        cfg_metrics=hparams.metrics,
        vis_cfg=hparams.visual,
        **modules,
        **module_trainer_cfg,
    )

    # NOTE(review): presumably restores the best saved weights into
    # diff_trainer in place — confirm against load_best_checkpoint.
    if bool(pl_trainer_cfg.get("test_with_best_checkpoint", True)):
        load_best_checkpoint(diff_trainer)

    # NOTE(review): resolve_epoch_len appears to turn a fraction/count into a
    # value suitable for `limit_test_batches`, given the total batch count —
    # confirm.
    test_epoch_len = resolve_epoch_len(
        pl_trainer_cfg.get("test_epoch_len", 1.0),
        len(dm.test_dataloader()),
    )

    trainer = Trainer(
        # Configurable so inference can run without a GPU; defaults to the
        # previously hard-coded "gpu".
        accelerator=pl_trainer_cfg.get("accelerator", "gpu"),
        logger=logger,
        callbacks=[RichProgressBar(leave=True)],
        enable_progress_bar=True,
        limit_test_batches=test_epoch_len,
    )
    trainer.test(diff_trainer, datamodule=dm)


if __name__ == "__main__":
    # Hydra's decorator supplies `cfg` from the config files and CLI overrides.
    main()