forked from ANDYZAQ/GF-SAM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain_eval.py
More file actions
119 lines (90 loc) · 4.51 KB
/
main_eval.py
File metadata and controls
119 lines (90 loc) · 4.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
r""" Matcher testing code for one-shot segmentation """
import argparse
import os
import torch
import torch.nn.functional as F
import numpy as np
import sys
# Put the current working directory on the import path before the `matcher`
# imports below run (presumably the script is launched from the repository
# root, where the `matcher` package lives -- TODO confirm).
sys.path.append('./')
from matcher.common.logger import Logger, AverageMeter
from matcher.common.vis import Visualizer
from matcher.common.evaluation import Evaluator
from matcher.common import utils
from matcher.data.dataset import FSSDataset
from matcher.GFSAM import build_model
import random
# Seed Python's `random` module as soon as this module is loaded.
random.seed(0)
def test(GFSAM, dataloader, args=None):
    r""" Run GFSAM over every batch of ``dataloader`` and report IoU metrics.

    Parameters:
        GFSAM: model object; ``set_reference``, ``set_target``, ``predict``
            and ``clear`` are each called once per batch.
        dataloader: iterable of batches that supports ``len()`` and exposes
            ``.dataset`` (handed to ``AverageMeter``). Each batch is indexed
            with the keys ``'query_img'``, ``'query_mask'``,
            ``'support_imgs'``, ``'support_masks'`` and ``'class_id'``.
        args: not referenced inside this function.

    Returns:
        ``(miou, fb_iou)`` -- the first two values returned by
        ``AverageMeter.compute_iou()``.
    """
    # Freeze randomness during testing for reproducibility (follows HSNet)
    utils.fix_randseed(0)
    meter = AverageMeter(dataloader.dataset)

    for step, batch in enumerate(dataloader):
        batch = utils.to_cuda(batch)
        query_img = batch['query_img']
        query_mask = batch['query_mask']
        support_imgs = batch['support_imgs']
        support_masks = batch['support_masks']

        # 1. Hand the support (reference) set and the query (target) to GFSAM
        GFSAM.set_reference(support_imgs, support_masks)
        GFSAM.set_target(query_img)

        # 2. Predict the target mask, then reset the model for the next batch
        pred_mask, _unused = GFSAM.predict()
        GFSAM.clear()

        # Sanity check: the prediction must have the same size as the
        # ground-truth query mask before it is scored.
        assert pred_mask.size() == query_mask.size(), \
            'pred {} ori {}'.format(pred_mask.size(), query_mask.size())

        # 3. Score the prediction; a clone is passed so that `pred_mask`
        #    itself is what gets visualised below.
        area_inter, area_union = Evaluator.classify_prediction(pred_mask.clone(), batch)
        meter.update(area_inter, area_union, batch['class_id'], loss=None)
        meter.write_process(step, len(dataloader), epoch=-1, write_batch_idx=1)

        # Visualisation is optional; skip straight to the next batch if off.
        if not Visualizer.visualize:
            continue
        Visualizer.visualize_prediction_batch(
            batch['support_imgs'], batch['support_masks'],
            batch['query_img'], batch['query_mask'],
            pred_mask, batch['class_id'], step,
            # NOTE(review): index 1 is presumably the foreground entry -- confirm
            area_inter[1].float() / area_union[1].float(),
        )

    # Write evaluation results
    meter.write_result('Test', 0)
    miou, fb_iou, _rest = meter.compute_iou()
    return miou, fb_iou
if __name__ == '__main__':
    # Script entry point: parse CLI options, build the model and the test
    # dataloader, run `test` under `torch.no_grad()` and log the resulting IoUs.

    # Arguments parsing
    parser = argparse.ArgumentParser(description='GFSAM Pytorch Implementation for One-shot Segmentation')

    # Dataset parameters
    parser.add_argument('--datapath', type=str, default='datasets')
    parser.add_argument('--benchmark', type=str, default='coco',
                        choices=['fss', 'coco', 'pascal', 'lvis', 'paco_part', 'pascal_part', 'deepglobe', 'isic', 'isaid'])
    parser.add_argument('--bsz', type=int, default=1)
    parser.add_argument('--nworker', type=int, default=0)
    parser.add_argument('--fold', type=int, default=0)
    parser.add_argument('--nshot', type=int, default=1)
    parser.add_argument('--img-size', type=int, default=1024)
    parser.add_argument('--use_original_imgsize', action='store_true')
    parser.add_argument('--log-root', type=str, default='output/debug')
    parser.add_argument('--visualize', type=int, default=0)

    # DINOv2 and SAM parameters
    parser.add_argument('--dinov2-size', type=str, default="vit_large")
    parser.add_argument('--sam-size', type=str, default="vit_h")
    parser.add_argument('--dinov2-weights', type=str, default="models/dinov2_vitl14_pretrain.pth")
    parser.add_argument('--sam-weights', type=str, default="models/sam_vit_h_4b8939.pth")

    args = parser.parse_args()

    # Create the log directory. `exist_ok=True` replaces the previous
    # "check, then create" sequence, which could fail if another process
    # created the directory between the check and the call.
    os.makedirs(args.log_root, exist_ok=True)
    Logger.initialize(args, root=args.log_root)

    # Device setup: first CUDA device when available, otherwise CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    args.device = device  # attached to args before it is passed to build_model
    Logger.info('# available GPUs: %d' % torch.cuda.device_count())

    # Model initialization
    GFSAM = build_model(args)

    # Helper classes (for testing) initialization
    Evaluator.initialize()
    Visualizer.initialize(args.visualize)

    # Dataset initialization
    FSSDataset.initialize(img_size=args.img_size, datapath=args.datapath, use_original_imgsize=args.use_original_imgsize)
    dataloader_test = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'test', args.nshot)

    # Test GFSAM (no gradients are needed for evaluation)
    with torch.no_grad():
        test_miou, test_fb_iou = test(GFSAM, dataloader_test, args=args)
    Logger.info('Fold %d mIoU: %5.2f \t FB-IoU: %5.2f' % (args.fold, test_miou.item(), test_fb_iou.item()))
    Logger.info('==================== Finished Testing ====================')