"""
Main module for predictions.
This module contains generic prediction functions:
- pred_single
- pred
- pred_multiple
and interface prediction functions made for the CLI:
- pred_seg
- pred_seg_eval
- pred_seg_eval_single
"""
import os
import argparse
import pathlib
from typing import Optional
from biom3d.builder import Builder
from biom3d.utils import deprecated, versus_one, dice, DataHandlerFactory
from biom3d.eval import eval
#---------------------------------------------------------------------------
# prediction base functions
def pred_single(log:str|list[str],
                img_path:str,
                out_path:str,
                skip_preprocessing:bool=False,
                )->tuple[int,str]:
    """
    Predict segmentation or classification on a single image.

    Parameters:
    -----------
    log : str or list of str
        Path to the model/log directory, or a list of such paths to predict
        with several models.
    img_path : str
        Path to the input image file.
    out_path : str
        Directory where the prediction output will be saved.
    skip_preprocessing : bool, default=False
        If True, skips preprocessing step.

    Returns:
    --------
    num_classes: int
        Number of classes + 1 (the background).
    path_out: str
        Path to the saved mask output (the handler's ``msk_outpath``).
    """
    if not isinstance(log,list): log=str(log)
    builder = Builder(config=None,path=log, training=False)

    # With several models Builder.config is a list of configs. Normalise to a
    # list so single- and multi-model runs share the same code path (the
    # original return statement crashed on the multi-model case).
    configs = builder.config if isinstance(builder.config,list) else [builder.config]

    handler = DataHandlerFactory.get(
        img_path,
        output=out_path,
        msk_outpath=out_path,
        model_name=configs[-1].DESC,
    )
    img = builder.run_prediction_single(handler, return_logit=False, skip_preprocessing=skip_preprocessing)
    handler.save(handler.images[0], img, "pred")

    # NUM_CLASSES is read from the first config, as pred_seg_eval does; this
    # assumes all models share the same number of classes — TODO confirm.
    # The +1 (background) value is what pred_seg_eval_single consumes.
    return configs[0].NUM_CLASSES+1, handler.msk_outpath
def pred(log:str|list[str],
         path_in:str,
         path_out:str,
         skip_preprocessing:bool=False,
         )->str:
    """
    Predict on all images in a collection.

    Predictions are written under ``path_out/<model folder name>``, where the
    model folder name is the last component of ``log`` (of ``log[0]`` when a
    list of models is given).

    Parameters:
    -----------
    log : str or list of str
        Path to the model/log directory or configuration, or a list of such
        paths to predict with several models.
    path_in : str
        Path to collection containing input images.
    path_out : str
        Path to collection to save prediction outputs.
    skip_preprocessing : bool, default=False
        If True, skips preprocessing step.

    Returns:
    --------
    str
        Path to the output directory containing predictions, as returned by
        ``Builder.run_prediction_folder``.
    """
    if not isinstance(log,list): log=str(log)
    path_in=str(path_in)
    path_out=str(path_out)

    # Name the prediction folder after the (first) model folder. normpath
    # strips any trailing separator: os.path.split("logs/my_model/")[-1] is ""
    # which would drop the predictions directly into path_out.
    model_dir = log[0] if isinstance(log,list) else log
    path_out = os.path.join(path_out, os.path.basename(os.path.normpath(model_dir)))

    builder = Builder(config=None,path=log, training=False)
    return builder.run_prediction_folder(
        path_in=path_in,
        path_out=path_out,
        return_logit=False,
        skip_preprocessing=skip_preprocessing,
    )
@deprecated("This method is no longer used as it is the default behaviour of DataHandlers.")
def pred_multiple(log:str|list[str],
                  path_in:str,
                  path_out:str,
                  skip_preprocessing:bool=False,
                  )->str:
    """
    Predict on multiple folders of images (DEPRECATED).

    DataHandlers now handle multi-folder collections by default, so this
    function simply forwards every argument to ``pred()``.

    Parameters:
    -----------
    Same as pred().

    Returns:
    --------
    Same as pred().
    """
    return pred(
        log=log,
        path_in=path_in,
        path_out=path_out,
        skip_preprocessing=skip_preprocessing,
    )
#---------------------------------------------------------------------------
# main unet segmentation interface
def pred_seg(log:pathlib.Path|str|list[str]=pathlib.Path.home(),
             path_in:pathlib.Path | str =pathlib.Path.home(),
             path_out:pathlib.Path | str =pathlib.Path.home(),
             skip_preprocessing:bool=False
             )->None:
    """
    CLI entry point: predict on every image of a collection.

    Thin wrapper around ``pred()`` that discards the output path.

    Parameters:
    -----------
    log : pathlib.Path, str or list of str, default=home directory
        Path to the model or log directory (or a list of them).
    path_in : pathlib.Path or str, default=home directory
        Path to collection containing images.
    path_out : pathlib.Path or str, default=home directory
        Path to collection where predictions will be saved.
    skip_preprocessing : bool, default=False
        If True, skips preprocessing step.

    Returns
    -------
    None
    """
    pred(
        log=log,
        path_in=path_in,
        path_out=path_out,
        skip_preprocessing=skip_preprocessing,
    )
# TODO remove eval only, we have a module for that
def pred_seg_eval(log:pathlib.Path|str|list[str]=pathlib.Path.home(),
                  path_in:pathlib.Path | str =pathlib.Path.home(),
                  path_out:pathlib.Path | str =pathlib.Path.home(),
                  path_lab:Optional[pathlib.Path | str]=None,
                  eval_only:bool=False,
                  skip_preprocessing:bool=False
                  )->None:
    """
    Run prediction on a folder of images and optionally evaluate segmentation (with dice).

    Predictions are written under ``path_out/<model folder name>``, where the
    model folder name is the last component of ``log`` (of ``log[0]`` when a
    list of models is given).

    Parameters:
    -----------
    log : pathlib.Path, str or list of str, default=home directory
        Path to the model or log directory, or a list of such paths.
    path_in : pathlib.Path or str, default=home directory
        Path to collection containing images.
    path_out : pathlib.Path or str, default=home directory
        Path to collection where predictions will be saved.
    path_lab : pathlib.Path or str, optional
        Path to collection containing ground-truth label masks for evaluation.
        If None, no evaluation is run.
    eval_only : bool, default=False
        If True, skips prediction and only evaluates predictions previously
        saved under ``path_out/<model folder name>``. Requires ``path_lab``;
        without it a warning is printed and nothing is done.
    skip_preprocessing : bool, default=False
        If True, skips preprocessing step.

    Returns
    -------
    None
    """
    if eval_only and path_lab is None:
        # Nothing to predict and nothing to evaluate: don't load the model(s).
        print("[Warning] eval_only is set but no label collection (path_lab) was given: nothing to do.")
        return

    # Same normalisation as pred(): a single model path is handed to Builder as str.
    if not isinstance(log,list): log=str(log)
    builder_pred = Builder(
        config=None,
        path=log,
        training=False)

    # Name the prediction folder after the (first) model folder. normpath
    # strips a trailing separator so "logs/my_model/" yields "my_model"
    # rather than "" (which would put predictions directly in path_out).
    model_dir = log[0] if isinstance(log,list) else log
    path_out = os.path.join(str(path_out), os.path.basename(os.path.normpath(model_dir)))

    if not eval_only:
        print("Start inference")
        path_out = builder_pred.run_prediction_folder(
            path_in=str(path_in),
            path_out=path_out,
            return_logit=False,
            skip_preprocessing=skip_preprocessing) # run the predictions
        print("Inference done!")

    if path_lab is None:
        return

    # With several models Builder.config is a list; NUM_CLASSES is read from
    # the first config, assuming all models agree — TODO confirm.
    configs = builder_pred.config if isinstance(builder_pred.config,list) else [builder_pred.config]
    eval(path_lab, path_out, num_classes=configs[0].NUM_CLASSES)
def pred_seg_eval_single(log:str|list[str],
                         img_path:str,
                         out_path:str,
                         msk_path:str,
                         skip_preprocessing:bool=False
                         )->None:
    """
    Run prediction on a single image and compute evaluation metric (dice) against mask.

    The dice score is printed, not returned.

    Parameters:
    -----------
    log : str or list of str
        Path to the model or log directory, or a list of such paths.
    img_path : str
        Path to the input image file.
    out_path : str
        Directory where prediction output will be saved.
    msk_path : str
        Path to the ground-truth mask for evaluation.
    skip_preprocessing : bool, default=False
        If True, skips preprocessing step.

    Returns:
    --------
    None

    Raises:
    -------
    ValueError
        If ``msk_path`` is None (checked before running the prediction).
    """
    # Fail fast: the CLI forwards --path_lab here, which defaults to None, and
    # without this check the error only surfaced after a full prediction run.
    if msk_path is None:
        raise ValueError("pred_seg_eval_single requires a ground-truth mask path (msk_path / --path_lab).")

    print("Run prediction for:", img_path)
    num_classes,out = pred_single(log, img_path, out_path,skip_preprocessing=skip_preprocessing)
    print("Done! Prediction saved in:", out)

    pred_handler = DataHandlerFactory.get(
        out,
        read_only=True,
        eval="pred",
    )
    label_handler = DataHandlerFactory.get(
        msk_path,
        read_only=True,
        eval="label",
    )
    print("Metric computation with mask:", msk_path)

    # load() is indexed with [0]; presumably it returns (image, metadata) — TODO confirm.
    pred_img = pred_handler.load(pred_handler.images[0])[0]
    label_img = label_handler.load(label_handler.images[0])[0]
    # NOTE(review): num_classes here is NUM_CLASSES+1 (from pred_single) whereas
    # pred_seg_eval passes NUM_CLASSES to eval(); confirm each helper expects its convention.
    dice_score = versus_one(fct=dice, input_img=pred_img, target_img=label_img, num_classes=num_classes)
    print("Metric result:", dice_score)
#---------------------------------------------------------------------------
if __name__=='__main__':
    # CLI method names -> prediction functions.
    valid_names = {
        "seg": pred_seg,
        "seg_eval": pred_seg_eval,
        "seg_multiple": pred_multiple,
        "seg_single": pred_single,
        "seg_eval_single": pred_seg_eval_single,
    }

    # parser
    parser = argparse.ArgumentParser(description="Main prediction file.")
    # `choices` makes argparse reject unknown names with a proper usage error
    # (the previous `assert` was stripped under `python -O`).
    parser.add_argument("-n", "--name", type=str, default="seg", choices=list(valid_names),
        help="Name of the tested method. Valid names: {}".format(", ".join(valid_names)))
    parser.add_argument("-l", "--log", type=str, nargs='+',required=True,
        help="Path of the builder directory/directories. You can pass several paths to make a prediction using several models.")
    parser.add_argument("-i", "--path_in","--dir_in",dest="path_in",type=str,required=True,
        help="Path to the input image collection")
    parser.add_argument("-o", "--path_out","--dir_out",dest="path_out", type=str,required=True,
        help="Path to the output prediction collection")
    parser.add_argument("-a", "--path_lab","--dir_lab",dest="path_lab", type=str, default=None,
        help="Path to the input label collection")
    parser.add_argument("-e", "--eval_only", default=False, action='store_true', dest='eval_only',
        help="Do only the evaluation and skip the prediction (predictions must have been done already.)")
    parser.add_argument("--skip_preprocessing", default=False, action='store_true',dest="skip_preprocessing",
        help="(default=False) Skip preprocessing; assumes the preprocessing has already been done and may crash otherwise")
    args = parser.parse_args()

    # --log is required with nargs='+', so it is always a non-empty list;
    # a single model is passed on as a plain path rather than a 1-element list.
    if len(args.log)==1:
        args.log = args.log[0]

    # run the method (--log is required, so there is no interactive fallback)
    method = valid_names[args.name]
    if args.name=="seg_eval":
        method(args.log,
               args.path_in,
               args.path_out,
               args.path_lab,
               args.eval_only,
               skip_preprocessing=args.skip_preprocessing)
    elif args.name=="seg_eval_single":
        method(args.log,
               args.path_in,
               args.path_out,
               args.path_lab,
               skip_preprocessing=args.skip_preprocessing)
    else:
        method(args.log,
               args.path_in,
               args.path_out,
               skip_preprocessing=args.skip_preprocessing)
#---------------------------------------------------------------------------