"""
Evaluation module.
Used to compare predictions against the ground truth.
Examples
--------
.. code-block:: bash
python -m biom3d.eval -p MyPred -l MyMasks --num_classes 2
python -m biom3d.eval -p MyPred -l MyMasks -f iou --num_classes 2
Or in python
.. code-block:: python
print(eval("./MyMasks","./MyPred",2))
print(eval("./MyMasks","./MyPred",2,iou))
"""
import numpy as np
import argparse
from functools import partial
from biom3d.utils import versus_one, dice, iou, DataHandlerFactory, MONAIMetricFactory, absolute_volume_difference
from typing import Callable
[docs]
def robust_sort(str_list: list[str]) -> list[str]:
    """
    Perform a robust sorting of a list of strings, useful for sorting file paths.

    Strings are compared as if each one were left-padded with ``'0'`` up to
    the length of the longest string (so ``"img_2"`` sorts before
    ``"img_10"``). The original, unpadded strings are returned.

    Parameters
    ----------
    str_list : list of str
        Strings to sort. Any iterable of strings is accepted.

    Returns
    -------
    list of str
        The sorted strings. Every input element is kept, including exact
        duplicates and strings that only differ by leading zeros (e.g.
        ``"01"`` and ``"1"``). An empty input gives an empty list.
    """
    items = list(str_list)
    if not items:
        # max() of an empty sequence would raise ValueError.
        return []
    max_len = max(len(s) for s in items)
    # Sort on the zero-padded form. Ties (strings identical once padded, or
    # exact duplicates) are broken by the string itself, so the order is
    # deterministic and no element is lost. (The previous dict-based version
    # collapsed such ties into a single entry.)
    return sorted(items, key=lambda s: (s.rjust(max_len, '0'), s))
[docs]
def _is_per_class_result(res) -> bool:
    """Return True when a metric result holds several (per-class) values.

    Same criterion as before (a ``list`` or ``np.ndarray`` with more than one
    entry), but a 0-d ndarray is treated as a scalar instead of making
    ``len`` raise ``TypeError``.
    """
    if isinstance(res, np.ndarray):
        return res.ndim > 0 and len(res) > 1
    return isinstance(res, list) and len(res) > 1


def eval(path_lab: str,
         path_out: str,
         num_classes: int,
         fct: Callable = dice,
         ) -> tuple:
    """
    Evaluate segmentation results by comparing predictions to labels using a given metric.

    Label and prediction collections are paired in iteration order. For each
    pair, ``versus_one`` is called with the label as ``input_img`` and the
    prediction as ``target_img``.

    Parameters
    ----------
    path_lab : str
        Path to the folder containing label images.
    path_out : str
        Path to the folder containing predicted images.
    num_classes : int
        Number of classes for evaluation, background excluded
        (``num_classes + 1`` is forwarded to ``versus_one``).
    fct : Callable, optional
        Metric function to compute (default is `dice`).

    Returns
    -------
    results : list
        Metric result per image pair, as returned by ``versus_one``.
    mean : float
        Average of all results.
    mean_class : np.ndarray, optional
        Average per class. Only returned (as a third tuple element) when at
        least one result contained several per-class values.

    Raises
    ------
    AssertionError
        If the label and prediction collections do not contain the same
        number of images.
    """
    print("Start evaluation")
    handler_lab = DataHandlerFactory.get(
        path_lab,
        read_only=True,
        eval='label',
    )
    handler_out = DataHandlerFactory.get(
        path_out,
        read_only=True,
        eval='pred',
    )

    # Explicit raise instead of ``assert`` so the check still runs under
    # ``python -O``; otherwise zip() below would silently truncate to the
    # shorter collection. AssertionError is kept for backward compatibility.
    if len(handler_lab) != len(handler_out):
        raise AssertionError(
            f"[Error] Not the same number of labels and predictions! "
            f"'{len(handler_lab)}' for '{len(handler_out)}'")

    results = []
    is_multi_class = False
    # Each handler item is unpacked as a 3-tuple; only the first element is
    # used here, as the key passed to ``load``.
    for (img_lab, _, _), (img_out, _, _) in zip(handler_lab, handler_out):
        print("Metric computation for:", img_lab, img_out)
        res = versus_one(
            fct=fct,
            input_img=handler_lab.load(img_lab)[0],
            target_img=handler_out.load(img_out)[0],
            # +1 presumably accounts for the background class, which
            # ``num_classes`` excludes (see the --num_classes CLI help).
            num_classes=num_classes + 1,
            single_class=None)
        if _is_per_class_result(res):
            is_multi_class = True
        print("Metric result:", res)
        results.append(res)

    mean = np.mean(results)
    if is_multi_class:
        mean_class = np.mean(results, axis=0)
        print(f"Evaluation done! Per class-results: {mean_class}, Average result: {mean}")
        return results, mean, mean_class
    print("Evaluation done! Average result:", mean)
    return results, mean
if __name__ == '__main__':
    import sys

    supported_function = {
        "dice": dice,
        "equals": np.equal,
        "iou": iou,
        "avd": absolute_volume_difference,
        "rve": partial(absolute_volume_difference, relative=True),
    }

    parser = argparse.ArgumentParser(description="Prediction evaluation.")
    # Both collections are mandatory: with a None path, eval() could only
    # fail later with a less helpful error.
    parser.add_argument("-p", "--path_pred", "--dir_pred", dest="path_pred", type=str, required=True,
                        help="Path to the prediction collection")
    parser.add_argument("-l", "--path_lab", "--dir_lab", dest="path_lab", type=str, required=True,
                        help="Path to the label collection")
    parser.add_argument("-f", "--function", dest="function", type=str, default='dice',
                        help=f"(default=dice) Function used for evaluation. "
                             f"Supported : {', '.join(supported_function.keys())} "
                             f"or one of MONAI metrics classes found in https://monai-dev.readthedocs.io/en/latest/metrics.html"
                        )
    parser.add_argument("--num_classes", type=int, default=1,
                        help="(default=1) Number of classes (types of objects) in the dataset. The background is not included.")
    args = parser.parse_args()

    if args.function in supported_function:
        metric = supported_function[args.function]
    else:
        # Only the metric lookup is guarded. Evaluation errors (missing
        # folder, label/prediction count mismatch, ...) must not be reported
        # as an unsupported function.
        # NOTE(review): assumes MONAIMetricFactory raises on unknown names
        # at construction time — TODO confirm.
        try:
            metric = MONAIMetricFactory(args.function)
        except Exception as e:
            print(
                f"Function '{args.function}' not supported. "
                f"Supported functions :'{', '.join(supported_function)}' "
                f"or one of MONAI metrics classes found in https://monai-dev.readthedocs.io/en/latest/metrics.html",
                file=sys.stderr)
            print(f"Error while creating the metric: {e}", file=sys.stderr)
            sys.exit(1)

    eval(args.path_lab, args.path_pred, args.num_classes, metric)