-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
65 lines (49 loc) · 1.95 KB
/
Copy pathtest.py
File metadata and controls
65 lines (49 loc) · 1.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
from torchvision import transforms
from tqdm import tqdm
import glob as glob
from model.model import CSRNet
import json
import dataset
import Config as cfg
from argparse import ArgumentParser
from rich.progress import track
import pytorch_lightning as pl
def parse_args(argv=None):
    """Parse the command-line options of the test script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (handy when calling from code or tests).

    Returns:
        argparse.Namespace with a ``model`` attribute holding the path of
        the checkpoint to evaluate.
    """
    parser = ArgumentParser(description='Evaluate a CSRNet checkpoint on the test set.')
    # required=True: test() passes args.model directly to
    # CSRNet.load_from_checkpoint, so a missing path would otherwise only
    # fail later with a much less helpful error.
    parser.add_argument('--model', type=str, required=True, help='Path to the model')
    return parser.parse_args(argv)
def test(args):
    """Evaluate a CSRNet checkpoint on the test split and print its MAE.

    The list of test samples is loaded from the JSON file at ``cfg.test_json``
    and wrapped in ``dataset.listDataset`` with mean/std normalization. The
    model is restored from the checkpoint at ``args.model``.

    Args:
        args: Parsed command-line namespace; only ``args.model`` is read.

    Returns:
        float: Mean absolute error, over test images, between the predicted
        count (sum of the model output) and the ground-truth count (sum of
        the target). The value is also printed.

    Raises:
        ValueError: If the test loader yields no images.
    """
    # ================== Data ==================
    with open(cfg.test_json, encoding='utf-8') as f:
        val_list = json.load(f)
    val_dataset = dataset.listDataset(
        val_list,
        shuffle=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            # Standard ImageNet per-channel mean/std.
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]),
        train=False,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=True,
    )

    # ================== Model ==================
    model = CSRNet.load_from_checkpoint(args.model, learning_rate=cfg.learning_rate)
    model.eval()

    # ================== Test ==================
    abs_error_sum = 0.0
    num_images = 0
    with torch.no_grad():
        for img, target in track(val_loader, total=len(val_loader)):
            output = model(img)
            # Reduce each sample to a scalar count by flattening every
            # dimension except the leading batch dimension (added to target
            # by DataLoader collation; assumes model(img) also keeps batch
            # first). Summing the whole batch instead would let over- and
            # under-counts of different images cancel when batch_size > 1.
            pred_counts = output.reshape(output.shape[0], -1).sum(dim=1)
            gt_counts = target.reshape(target.shape[0], -1).sum(dim=1)
            gt_counts = gt_counts.to(pred_counts.device)
            abs_error_sum += (pred_counts - gt_counts).abs().sum().item()
            num_images += gt_counts.numel()

    if num_images == 0:
        raise ValueError('test set is empty; cannot compute MAE')
    mae = abs_error_sum / num_images

    # Test results
    print("----- Test results -----")
    print(f"MAE: {mae}")
    return mae
if __name__ == '__main__':
    # Script entry point: read the command-line flags, then run evaluation.
    args = parse_args()
    test(args)