Skip to content

Commit 32e67b5

Browse files
committed
add data processing script
1 parent 43c0f51 commit 32e67b5

1 file changed

Lines changed: 80 additions & 0 deletions

File tree

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
#!/usr/bin/python
2+
# -*- encoding: utf-8 -*-
3+
'''
4+
@File : merge_ovdg_preds.py
5+
@Version : 1.0
6+
@Time : 2024/09/12 16:10:56
7+
@E-mail : daodao123@sjtu.edu.cn
8+
@Introduction : None
9+
'''
10+
11+
import os
12+
import numpy as np
13+
import torch
14+
from mmrotate.structures.bbox import RotatedBoxes
15+
import json
16+
import argparse
17+
from tqdm import tqdm
18+
import torch
19+
20+
if __name__ == '__main__':
21+
parser = argparse.ArgumentParser(description="Merge OVD pseudo labels.")
22+
parser.add_argument("--ann_path", type=str, required=True, help="Path to the original json path")
23+
parser.add_argument("--pred_path", type=str, required=True, help="Path to the prediction path (.box.json)")
24+
parser.add_argument("--save_path", type=str, required=True, help="Path where the pseudo labels will be saved")
25+
parser.add_argument("--score_thr", type=float, default=0.02, help="Scores lower than score_thr will be filtered.")
26+
parser.add_argument("--topk", type=int, default=None, help="Choose top-k bboxes.")
27+
28+
args = parser.parse_args()
29+
30+
id2pred = {}
31+
with open(args.pred_path, 'r') as f:
32+
data_list = json.loads(f.read())
33+
34+
for item in data_list:
35+
id = item["image_id"]
36+
if id not in id2pred:
37+
id2pred[id] = {
38+
'bboxes': [item['bbox']],
39+
'scores': [item['score']]
40+
}
41+
else:
42+
id2pred[id]['bboxes'].append(item['bbox'])
43+
id2pred[id]['scores'].append(item['score'])
44+
45+
with open(args.ann_path, 'r') as fp:
46+
data = json.load(fp)
47+
48+
ann_id = 0
49+
for meta in tqdm(data["images"]):
50+
cid = meta["category_id"]
51+
image_id = meta["id"]
52+
if image_id in id2pred:
53+
scores = torch.tensor(id2pred[image_id]["scores"])
54+
if args.topk is None:
55+
filter = scores > args.score_thr
56+
else:
57+
filter = scores.topk(args.topk).indices
58+
bboxes = RotatedBoxes(torch.tensor(id2pred[image_id]["bboxes"]))[filter]
59+
areas = bboxes.areas.tolist()
60+
qboxes = bboxes.convert_to('qbox').tensor.tolist()
61+
hboxes = bboxes.convert_to('hbox').tensor.tolist()
62+
for area, qbox, hbox in zip(areas, qboxes, hboxes):
63+
qbox = list(map(int, qbox))
64+
hbox = list(map(int, hbox))
65+
data["annotations"].append({
66+
"id": ann_id,
67+
"area": area,
68+
"category_id": cid,
69+
"segmentation": [qbox],
70+
"iscrowd": 0,
71+
"bbox": hbox,
72+
"image_id": image_id
73+
})
74+
ann_id += 1
75+
76+
os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
77+
print(f"Total of {len(data['annotations'])} instances.")
78+
79+
with open(args.save_path, 'w') as fp:
80+
json.dump(data, fp)

0 commit comments

Comments
 (0)