-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval.py
More file actions
83 lines (59 loc) · 2.77 KB
/
eval.py
File metadata and controls
83 lines (59 loc) · 2.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import tensorflow as tf
from tgs import config
from tgs import data
from tgs import model as m
from tgs import analyze as a
import os
# Show INFO-level TensorFlow log messages, including the tf.logging.info
# progress messages emitted by evaluate() and main().
tf.logging.set_verbosity(tf.logging.INFO)
# TODO: parameterize these
# Input image side length. evaluate() subtracts it from the configured
# data.ext.resize_dim to derive the padding applied to the first two axes.
# NOTE(review): the identical padding on both axes suggests square images — confirm.
IMG_DIM = 101
def _symmetric_padding(target_dim, source_dim):
    """Return the padding spec that grows ``source_dim`` to ``target_dim``.

    The difference ``target_dim - source_dim`` is split between the two sides
    of each of the first two axes; when it is odd, the extra unit goes to the
    trailing side. The third axis gets no padding.

    Args:
      target_dim: side length to reach (``data.ext.resize_dim``).
      source_dim: side length of the incoming images.

    Returns:
      ``[[before, after], [before, after], [0, 0]]``.
    """
    diff = target_dim - source_dim
    before = diff // 2
    after = diff - before
    return [[before, after], [before, after], [0, 0]]


def evaluate(cfg, checkpoint_path, hooks=None, img_dim=None):
    """Evaluate a checkpoint with a ``tf.estimator.Estimator`` over one epoch of data.

    Args:
      cfg: run configuration. The keys read here are ``data.class``, ``data``,
        ``batch_size``, ``model.label_cnt``, ``model.class``, ``model``,
        ``l2_normalize``, ``data.ext.resize_method``, ``data.ext.resize_dim``,
        ``metric.accuracy``, ``metric.map_iou`` and ``valid_steps``.
      checkpoint_path: checkpoint passed through to ``Estimator.evaluate``.
      hooks: optional list of session-run hooks passed to ``Estimator.evaluate``.
      img_dim: side length of the input images, used to compute the
        ``resize_param`` padding handed to ``dataset.input_fn``. Defaults to the
        module-level ``IMG_DIM``, read at call time.

    Returns:
      The metrics dict returned by ``tf.estimator.Estimator.evaluate``.
    """
    data_class = cfg.get('data.class')
    tf.logging.info('Using data class: %s', data_class)
    dataset = data.DataInput.get(data_class)(cfg.get('data'),
                                             batch_size=cfg.get('batch_size'),
                                             num_epochs=1,
                                             label_cnt=cfg.get('model.label_cnt'))

    model_class = cfg.get('model.class')
    tf.logging.info('Using model: %s', model_class)
    model = m.BaseModel.get(model_class)(cfg.get('model'))

    params = {'l2_normalize': cfg.get('l2_normalize'),
              'resize_method': cfg.get('data.ext.resize_method')}
    # Optional metric settings are forwarded only when present in the config.
    for metric in ('accuracy', 'map_iou'):
        value = cfg.get('metric.' + metric)
        if value is not None:
            params[metric] = value

    if img_dim is None:
        img_dim = IMG_DIM
    resize = _symmetric_padding(cfg.get('data.ext.resize_dim'), img_dim)

    estimator = tf.estimator.Estimator(model_fn=model.model_fn, config=None, params=params)
    return estimator.evaluate(
        input_fn=lambda: dataset.input_fn(tf.estimator.ModeKeys.EVAL, resize_param=resize),
        steps=cfg.get('valid_steps'),
        hooks=hooks,
        checkpoint_path=checkpoint_path)
def main(_):
    """Run evaluation from the command-line flags, optionally followed by analysis.

    ``--analyze`` selects the post-evaluation analysis. ``'unet'`` calls
    ``analyze_unet`` and ``'mask'`` calls ``analyze_mask``. Each receives the
    results collected by an ``AnalyzeEvaluationHook``, the config, and the
    directory containing the checkpoint as ``output_dir``. Leaving the flag
    unset, or passing the literal string ``'None'`` as the flag help suggests,
    runs a plain evaluation without the hook.

    Raises:
      ValueError: if ``--analyze`` has any other value.
    """
    analyze = FLAGS.analyze
    if analyze == 'None':
        # The flag help lists 'None' as a choice, but a string flag delivers it
        # as the string 'None', not the None object.
        analyze = None
    # Validate before the (potentially long) evaluation so that a typo does not
    # evaluate with the hook attached and then silently skip the analysis.
    if analyze is not None and analyze not in ('unet', 'mask'):
        raise ValueError("Unknown --analyze value %r; expected 'unet' or 'mask'"
                         % (FLAGS.analyze,))

    tf.logging.info("Reading config file...")
    cfg = config.Config(FLAGS.config_file)

    if analyze is None:
        evaluate(cfg, FLAGS.checkpoint_path)
        return

    analyze_hook = a.AnalyzeEvaluationHook()
    evaluate(cfg, FLAGS.checkpoint_path, hooks=[analyze_hook])
    output_dir = os.path.dirname(FLAGS.checkpoint_path)
    if analyze == 'unet':
        a.analyze_unet(analyze_hook.results_dict, cfg, output_dir=output_dir)
    else:  # 'mask' — the only remaining value after validation above
        a.analyze_mask(analyze_hook.results_dict, cfg, output_dir=output_dir)
# Command-line flags consumed by main().
tf.app.flags.DEFINE_string(
    'config_file', None,
    'File containing the configuration for this evaluation run')
tf.app.flags.DEFINE_string(
    'checkpoint_path', None,
    'Full path to the checkpoint used to initialize the graph')
tf.app.flags.DEFINE_string(
    'analyze', None,
    "Type of analysis to run after evaluation: 'unet' or 'mask'. "
    "Leave unset (or pass 'None') to skip analysis")
FLAGS = tf.app.flags.FLAGS
if __name__ == '__main__':
    # Parse the flags defined above, then dispatch explicitly to main().
    tf.app.run(main=main)