Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 0cf2275b authored by rickymwalsh's avatar rickymwalsh
Browse files

add option to change IoU threshold

parent ad296fb0
Branches
No related tags found
3 merge requests!13Adding IoU thresh option and thresh=0.5 evaluation,!12Add eval 0.5thresh,!11add option to change IoU threshold
......@@ -84,9 +84,9 @@ def trapezoidal_interpolation(lower_value, upper_value, lower_metric, upper_metr
return lower_value
def filter_table_and_get_metrics(args, tableGT, tablePred, threshold, i=0):
def filter_table_and_get_metrics(args, tableGT, tablePred, threshold, i=0, iou_threshold=0.2):
filtered_tableGT, filtered_tablePred = filter_by_proba_threshold(tableGT, tablePred, threshold)
image_level_df = compute_image_level_metrics(args, filtered_tableGT, filtered_tablePred)
image_level_df = compute_image_level_metrics(args, filtered_tableGT, filtered_tablePred, iou_threshold)
if args.debug:
image_tsv_path = os.path.join(args.output_path, f'image_level_{threshold}_iteration_{i}.tsv')
......@@ -95,7 +95,7 @@ def filter_table_and_get_metrics(args, tableGT, tablePred, threshold, i=0):
return calculate_average_metrics(image_level_df)
def binary_search_for_threshold(args, target_metric, tableGT, tablePred, metric_type):
def binary_search_for_threshold(args, target_metric, tableGT, tablePred, metric_type, iou_threshold=0.2):
# Extract sorted unique probability values from predicted probabilities
unique_probas = sorted(tablePred['predicted_proba'].unique())
unique_probas.insert(0,0)
......@@ -114,7 +114,8 @@ def binary_search_for_threshold(args, target_metric, tableGT, tablePred, metric_
proba_threshold = unique_probas[mid]
# Filter tables by current threshold, and calculate image-level and average metrics
mean_sensitivity, mean_precision, mean_f1_score, mean_fp_per_image = filter_table_and_get_metrics(args, tableGT, tablePred, proba_threshold, i)
mean_sensitivity, mean_precision, mean_f1_score, mean_fp_per_image = \
filter_table_and_get_metrics(args, tableGT, tablePred, proba_threshold, i, iou_threshold)
# Select the metric of interest
current_metric = {
......@@ -131,7 +132,8 @@ def binary_search_for_threshold(args, target_metric, tableGT, tablePred, metric_
proba_threshold, mean_sensitivity, mean_precision, mean_f1_score, mean_fp_per_image
mid -= 1
proba_threshold = unique_probas[mid]
mean_sensitivity, mean_precision, mean_f1_score, mean_fp_per_image = filter_table_and_get_metrics(args, tableGT, tablePred, proba_threshold)
mean_sensitivity, mean_precision, mean_f1_score, mean_fp_per_image = \
filter_table_and_get_metrics(args, tableGT, tablePred, proba_threshold, i, iou_threshold)
current_metric = {
'mean_fp_per_image': mean_fp_per_image,
'mean_sensitivity': mean_sensitivity,
......@@ -161,19 +163,20 @@ def binary_search_for_threshold(args, target_metric, tableGT, tablePred, metric_
return proba_threshold, mean_sensitivity, mean_precision, mean_f1_score
def single_binary_search(args, gt_path, pred_path, target_metric, metric_type, output_path):
def single_binary_search(args, gt_path, pred_path, target_metric, metric_type, output_path, iou_threshold=0.2):
# Load data
tableGT, tablePred = load_data(gt_path, pred_path)
# Binary search for optimal threshold based on target metric
final_threshold, mean_sensitivity, mean_precision, mean_f1_score = binary_search_for_threshold(
args, target_metric, tableGT, tablePred, metric_type
args, target_metric, tableGT, tablePred, metric_type, iou_threshold
)
if args.debug:
# Save the final image-level metrics to TSV
filtered_tableGT_final, filtered_tablePred_final = filter_by_proba_threshold(tableGT, tablePred, final_threshold)
image_level_df_final = compute_image_level_metrics(args, filtered_tableGT_final, filtered_tablePred_final)
image_level_df_final = compute_image_level_metrics(args, filtered_tableGT_final, filtered_tablePred_final,
iou_threshold)
image_tsv_path = os.path.join(output_path, f'image_level_{metric_type}_{target_metric}_{final_threshold}.tsv')
image_level_df_final.to_csv(image_tsv_path, sep='\t', index=False)
......
......@@ -8,14 +8,14 @@ import matplotlib.pyplot as plt
from .binary_search import single_binary_search
def calculate_froc(args, gt_path, pred_path, target_metric_list, metric_type, output_path):
def calculate_froc(args, gt_path, pred_path, target_metric_list, metric_type, output_path, iou_threshold=0.2):
mean_sensitivities = []
results = {}
# Run binary search for each target metric and collect sensitivities
for target_metric in target_metric_list:
final_threshold, mean_sensitivity, mean_precision, mean_f1_score = single_binary_search(
args, gt_path, pred_path, target_metric, metric_type, output_path
args, gt_path, pred_path, target_metric, metric_type, output_path, iou_threshold
)
mean_sensitivities.append(mean_sensitivity)
result_tmp = {
......@@ -68,6 +68,7 @@ if __name__ == "__main__":
parser.add_argument('-m', '--metric_type', required=False, type=str, default='mean_fp_per_image', help='Metrics to reach.')
parser.add_argument('-tm', '--target_metric_list', required=False, type=list, default=[0.25, 0.5, 1, 2, 3], help='Value to reach for the selected target metrics.')
parser.add_argument('-o', '--output_path', required=False, default='/dataEvalChallenge/tmp', help='Path to the output folder.')
parser.add_argument('-debug', '--debug', action='store_true', help='Printing some intermediate information.')
args = parser.parse_args()
......
......@@ -6,7 +6,8 @@ from functions.froc_calculation import calculate_froc
def main(args):
generate_lesion_level_files(args, args.gt_path, args.pred_path, args.csv_path, args.iou_tresh, args.output_path)
calculate_froc(args, args.gt_tsv_path, args.pred_tsv_path, args.target_metric_list, args.metric_type, args.output_path)
calculate_froc(args, args.gt_tsv_path, args.pred_tsv_path, args.target_metric_list, args.metric_type,
args.output_path, args.iou_tresh)
if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment