diff --git a/test/utils/test_eval.py b/test/utils/test_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..118689322659d75057126feb22e738457f13a414 --- /dev/null +++ b/test/utils/test_eval.py @@ -0,0 +1,119 @@ +import sys +sys.path.append('/Users/emoebel/code/python/deep-finder/deepfinder/') + +import numpy as np +import utils.objl as ol +import utils.eval as ev + +import unittest +import copy + +objl_true = ol.read('../../examples/training/in/object_list_tomo0.xml') + +tomoid_list = [obj['tomo_idx'] for obj in objl_true] +tomoid_list = np.unique(tomoid_list) +dset_true = {} +for tomoid in tomoid_list: + dset_true[tomoid] = {'object_list': ol.get_tomo(objl_true, tomoid)} + +dset_pred = dset_true + + + + +# Create dummy inputs: +def create_dummy_objl(n_obj=100, mono_class=True): + objl = [] + for _ in range(n_obj): + x = np.random.randint(0, 500) + y = np.random.randint(0, 500) + z = np.random.randint(0, 200) + if mono_class: + label = 1 + else: + label = np.random.randint(1, 4) + cluster_size = np.random.uniform(0, 1) + objl = ol.add_obj(objl, label=label, coord=(z,y,x), cluster_size=cluster_size) + return objl + +def create_dummy_data_set(n_tomos=5, n_obj=100, mono_class=True): + dset = {} + for idx in range(n_tomos): + key = 'tomo'+str(idx) + dset[key] = {'object_list': create_dummy_objl(n_obj, mono_class)} + return dset + +dset_true = create_dummy_data_set(n_tomos=5, n_obj=100, mono_class=False) +dset_pred = dset_true + +detect_eval = ev.Evaluator(dset_true, dset_pred, dist_thr=5).get_evaluation(score_thr=None) + +class TestEvaluator(unittest.TestCase): + + def test_identity(self): # here we test dset_true to itself. Should give perfect scores + dset_true = create_dummy_data_set(n_tomos=5, n_obj=100, mono_class=False) + dset_pred = dset_true + + detect_eval = ev.Evaluator(dset_true, dset_pred, dist_thr=0).get_evaluation(score_thr=None) + + self.assertEqual(detect_eval['global']['f1s'][1], 1) # f1 scores for all classes should be =1 + self.assertEqual(detect_eval['global']['f1s'][2], 1) + self.assertEqual(detect_eval['global']['f1s'][3], 1) + + self.assertEqual( + len(detect_eval['tomo0']['objl_tp']), # n_tp should be = n_true + len(detect_eval['tomo0']['objl_true']), + ) + self.assertEqual(len(detect_eval['tomo0']['objl_fp']), 0) # n_fp should be 0 + self.assertEqual(len(detect_eval['tomo0']['objl_fn']), 0) # n_fn should be 0 + + def test_recall(self): + dset_true = create_dummy_data_set(n_tomos=1, n_obj=100, mono_class=True) + dset_pred = copy.deepcopy(dset_true) + + # Delete 10 elements from pred: + n_delete = 10 + for _ in range(n_delete): + del dset_pred['tomo0']['object_list'][0] + + detect_eval = ev.Evaluator(dset_true, dset_pred, dist_thr=0).get_evaluation(score_thr=None) + + # Hence recall should be 0.9 and precision should be 1 + self.assertEqual(detect_eval['tomo0']['rec'][1], 0.9) + self.assertEqual(detect_eval['tomo0']['pre'][1], 1.0) + + # Also, following should be true: n_tp=90, n_fp=0, n_fn=10: + self.assertEqual(len(detect_eval['tomo0']['objl_tp']), 90) + self.assertEqual(len(detect_eval['tomo0']['objl_fp']), 0) + self.assertEqual(len(detect_eval['tomo0']['objl_fn']), 10) + + # Global scores should also be as follow: + self.assertEqual(detect_eval['global']['rec'][1], 0.9) + self.assertEqual(detect_eval['global']['pre'][1], 1.0) + + def test_precision(self): + dset_true = create_dummy_data_set(n_tomos=1, n_obj=100, mono_class=True) + dset_pred = copy.deepcopy(dset_true) + + # Delete 10 elements from true: + n_delete = 10 + for _ in range(n_delete): + del dset_true['tomo0']['object_list'][0] + + detect_eval = ev.Evaluator(dset_true, dset_pred, dist_thr=0).get_evaluation(score_thr=None) + + # Hence recall should be 1 and precision should be 0.9 + self.assertEqual(detect_eval['tomo0']['rec'][1], 1.0) + self.assertEqual(detect_eval['tomo0']['pre'][1], 0.9) + + # Also, following should be true: n_tp=90, n_fp=10, n_fn=0: + self.assertEqual(len(detect_eval['tomo0']['objl_tp']), 90) + self.assertEqual(len(detect_eval['tomo0']['objl_fp']), 10) + self.assertEqual(len(detect_eval['tomo0']['objl_fn']), 0) + + # Global scores should also be as follow: + self.assertEqual(detect_eval['global']['rec'][1], 1.0) + self.assertEqual(detect_eval['global']['pre'][1], 0.9) + +if __name__ == '__main__': + unittest.main()