|
#!/usr/bin/python
# -*- encoding: utf-8 -*-
'''
@File : merge_ovdg_preds.py
@Version : 1.0
@Time : 2024/09/12 16:10:56
@E-mail : daodao123@sjtu.edu.cn
@Introduction : Merge detector predictions (a COCO-style ``.bbox.json``
    result list) back into the original annotation file as pseudo labels.
    Per image, predictions are kept either by score threshold
    (``--score_thr``) or, when ``--topk`` is given, as the top-k highest
    scoring boxes. Kept rotated boxes are converted to quadrilateral
    segmentations and horizontal bboxes and appended to
    ``data["annotations"]``.
'''

import os
import numpy as np
import torch
from mmrotate.structures.bbox import RotatedBoxes
import json
import argparse
from tqdm import tqdm

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Merge OVD pseudo labels.")
    parser.add_argument("--ann_path", type=str, required=True, help="Path to the original json path")
    parser.add_argument("--pred_path", type=str, required=True, help="Path to the prediction path (.box.json)")
    parser.add_argument("--save_path", type=str, required=True, help="Path where the pseudo labels will be saved")
    parser.add_argument("--score_thr", type=float, default=0.02, help="Scores lower than score_thr will be filtered.")
    parser.add_argument("--topk", type=int, default=None, help="Choose top-k bboxes.")

    args = parser.parse_args()

    # Group the flat prediction list by image id:
    # image_id -> {'bboxes': [...], 'scores': [...]}.
    id2pred = {}
    with open(args.pred_path, 'r') as f:
        data_list = json.load(f)

    for item in data_list:
        image_id = item["image_id"]  # don't shadow builtin `id`
        pred = id2pred.setdefault(image_id, {'bboxes': [], 'scores': []})
        pred['bboxes'].append(item['bbox'])
        pred['scores'].append(item['score'])

    with open(args.ann_path, 'r') as fp:
        data = json.load(fp)

    # NOTE(review): ann ids restart at 0 — assumes data["annotations"] starts
    # empty (pure pseudo-label generation); ids would collide otherwise.
    ann_id = 0
    for meta in tqdm(data["images"]):
        # Each image entry carries its own category (single-category,
        # open-vocabulary setup) — every pseudo box inherits it.
        cid = meta["category_id"]
        image_id = meta["id"]
        if image_id not in id2pred:
            continue
        scores = torch.tensor(id2pred[image_id]["scores"])
        if args.topk is None:
            # Boolean mask: keep predictions above the score threshold.
            keep = scores > args.score_thr
        else:
            # Index selection: top-k by score, clamped so images with
            # fewer than k predictions don't raise.
            keep = scores.topk(min(args.topk, scores.numel())).indices
        bboxes = RotatedBoxes(torch.tensor(id2pred[image_id]["bboxes"]))[keep]
        areas = bboxes.areas.tolist()
        # Quad corners for COCO "segmentation", axis-aligned box for "bbox".
        qboxes = bboxes.convert_to('qbox').tensor.tolist()
        hboxes = bboxes.convert_to('hbox').tensor.tolist()
        for area, qbox, hbox in zip(areas, qboxes, hboxes):
            qbox = list(map(int, qbox))
            hbox = list(map(int, hbox))
            data["annotations"].append({
                "id": ann_id,
                "area": area,
                "category_id": cid,
                "segmentation": [qbox],
                "iscrowd": 0,
                "bbox": hbox,
                "image_id": image_id
            })
            ann_id += 1

    # dirname is "" when save_path is a bare filename — nothing to create.
    save_dir = os.path.dirname(args.save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    print(f"Total of {len(data['annotations'])} instances.")

    with open(args.save_path, 'w') as fp:
        json.dump(data, fp)