Inference¶
dplabtools offers a simple class for performing inference on WSIs using trained PyTorch models.
The inference process is integrated with other classes included in the package: patch location computing,
patch extraction and heatmap generation.
WSIInference class features:
Support for segmentation and classification models.
Support for single- and multi-resolution patches using dedicated Datasets.
Support for multi-class model output.
Configurable GPU/CPU processing.
Inference output integrated with the Heatmap class.
Ability to save results as images, also containing embedded resolution information.
Parallelization based on PyTorch DataLoaders.
Basic usage¶
Assuming that the variables dataset, model and classifier already exist and hold initialized objects:
from dplabtools.slides.processing import WSIInference
inference = WSIInference(
    model=model,
    classifier=classifier,
    level_or_minsize=0,
    num_classes=3,
    num_workers=12,
    batch_size=256,
)
inference.process_dataset(dataset)
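Once process_dataset has finished, the results can be written out with the save methods documented under Class details. The following is a minimal sketch rather than part of dplabtools: export_results is a hypothetical helper, the file names are arbitrary, and inference is assumed to be a WSIInference object that has already processed a dataset.

```python
from pathlib import Path


def export_results(inference, out_dir, num_classes):
    """Save one PNG per class plus a single NPZ file with all classes.

    `inference` is assumed to be a WSIInference object on which
    process_dataset() has already been called (hypothetical helper).
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    written = []
    for class_index in range(num_classes):
        # One image per class, using the documented save_class_png method.
        png_file = out_dir / f"class_{class_index}.png"
        inference.save_class_png(class_index, str(png_file))
        written.append(png_file)

    # All classes together, using the documented save_classes_array method.
    array_file = out_dir / "all_classes.npz"
    inference.save_classes_array(str(array_file))
    written.append(array_file)
    return written
```

With the objects from the example above, this could be called as export_results(inference, "/tmp/wsi1_results", num_classes=3).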
Full example of inference process¶
The full inference process consists of the following steps:
1. WSI mask generation
2. Calculating patch location
3. Creating WSI dataset
4. Initializing PyTorch model/classifier
5. Creating WSIInference object
6. Processing WSI dataset
from dplabtools.slides.processing import WSITissueMask, WSIDataset, WSIInference
from dplabtools.slides.patches import WholeImageGridPatches
wsi_file = "/tmp/wsi1.svs"
# Step 1
mask = WSITissueMask(wsi_file=wsi_file, level_or_minsize=2)
# Step 2
patches = WholeImageGridPatches(wsi_file=wsi_file, mask_data=mask.array)
# Step 3
dataset = WSIDataset(patches=patches)
# Step 4
model = get_model()
classifier = get_classifier()
# Step 5
inference = WSIInference(model=model, classifier=classifier, level_or_minsize=0, num_classes=2)
# Step 6
inference.process_dataset(dataset)
Note
When processing multiple WSIs in one run, steps 4 and 5 should be performed only once rather than separately for each WSI, because repeating model initialization for every slide degrades performance.
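The note above can be sketched as a loop that reuses one WSIInference object across several slides. This is an illustration under stated assumptions, not a dplabtools function: process_slides and make_dataset are hypothetical names, and saving the results before moving on to the next slide is a cautious choice, since the source does not say whether results from an earlier slide are kept or overwritten.

```python
from pathlib import Path


def process_slides(wsi_files, inference, make_dataset, out_dir):
    """Run one shared inference object over several slides.

    make_dataset is a caller-supplied function covering the mask,
    patch location and dataset steps for a single WSI file.
    """
    out_dir = Path(out_dir)
    saved = []
    for wsi_file in wsi_files:
        dataset = make_dataset(wsi_file)    # per-slide preparation
        inference.process_dataset(dataset)  # per-slide inference
        # Save right away (assumption: results may be replaced by the next slide).
        array_file = out_dir / f"{Path(wsi_file).stem}.npz"
        inference.save_classes_array(str(array_file))
        saved.append(array_file)
    return saved
```

The model, the classifier and the WSIInference object are created once, before process_slides is called, and only the dataset changes from slide to slide.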
Model initialization checklist¶
The PyTorch model must be properly initialized before it is passed to WSIInference. This can be accomplished with a simple convenience function, get_model:
import torch
from mymodels import MyPyTorchModel
MODEL_PATH = "/tmp/modelfile.pth"
def get_model():
    model = MyPyTorchModel()
    model.load_state_dict(torch.load(MODEL_PATH))
    ...
    return model
Some PyTorch models must be set to evaluation mode when running inference. This should be done inside get_model by calling eval():
def get_model():
    model = MyPyTorchModel()
    model.load_state_dict(torch.load(MODEL_PATH))
    model.eval()
    ...
    return model
When using CUDA processing (the WSIInference default mode), the model should be loaded into GPU memory inside get_model by calling cuda():
def get_model():
    model = MyPyTorchModel()
    model.load_state_dict(torch.load(MODEL_PATH))
    model.cuda()
    ...
    return model
When using CPU processing, WSIInference must first be initialized with use_cuda=False, and the model weights must then be loaded with map_location="cpu":
def get_model():
    model = MyPyTorchModel()
    model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
    ...
    return model
Note
The error "Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same" indicates
a mismatch between how the model was initialized (CUDA/CPU) inside get_model and how WSIInference
was created (use_cuda=True|False).
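One way to catch this mismatch before inference starts is to compare the device of the model weights with the use_cuda flag. The helper below is a hypothetical sketch, not part of dplabtools; it relies only on the standard parameters() method of PyTorch modules and the is_cuda attribute of tensors.

```python
def check_device(model, use_cuda):
    """Raise early if the model weights and the use_cuda flag disagree.

    Hypothetical helper: `model` is expected to expose parameters()
    like a torch.nn.Module, yielding tensors with an is_cuda attribute.
    """
    first_param = next(iter(model.parameters()), None)
    if first_param is None:
        return  # model without parameters: nothing to compare

    if bool(first_param.is_cuda) != bool(use_cuda):
        where = "GPU" if first_param.is_cuda else "CPU"
        raise ValueError(
            f"model weights are on {where} but use_cuda={use_cuda}"
        )
```

Calling check_device(model, use_cuda=True) right after get_model() and before creating WSIInference would report the problem with a clearer message than the tensor type error above.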
Classifier initialization¶
A classifier is an optional additional processing layer applied to the model output. It can be
any callable that returns torch.Tensor output. This includes PyTorch models
as well as native Python functions. The most basic classifier performs the softmax transformation on the model
output, wrapped in a get_classifier function:
import torch
def classifier_fn(result):
    probabilities = torch.nn.functional.softmax(result, dim=1)
    return probabilities
def get_classifier():
    return classifier_fn
classifier = get_classifier()
If a classifier is not needed, its value should be set to None.
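To make the effect of the softmax classifier concrete, the plain-Python sketch below reproduces what softmax with dim=1 computes for a two-dimensional input. It assumes, for illustration only, that the model output has one row per patch and one column per class; softmax_rows is a hypothetical name and is not meant to replace the torch call.

```python
import math


def softmax_rows(logits):
    """Apply softmax to each row of a list of per-class scores."""
    result = []
    for row in logits:
        shift = max(row)  # subtract the row maximum for numerical stability
        exps = [math.exp(value - shift) for value in row]
        total = sum(exps)
        result.append([value / total for value in exps])
    return result


# Two patches, three classes (made-up scores).
batch = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
probabilities = softmax_rows(batch)
# Each row now sums to 1; the all-zero row becomes three equal values.
```

Each row of the result can be read as per-class probabilities for one patch, which is the form that methods such as save_class_array describe as "probabilities".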
Class details¶
- class dplabtools.slides.processing.WSIInference(...)¶
Class for running inference on WSIs using PyTorch models.
- Parameters:
model (callable) – PyTorch model properly initialized.
classifier (callable) – PyTorch model or function capable of processing the model’s output.
level_or_minsize (int) – WSI level or minimal desired size (in pixels) of inference output array.
num_classes (int) – Number of classes present in model output.
num_workers (int) – Number of worker processes used in data loading.
batch_size (int) – Number of samples per batch to load into GPU.
use_cuda (bool, default=True) – Whether the model will use CUDA/GPU for processing. False indicates pure CPU processing.
seed (int or float, optional) – Custom seed for random number generators.
- process_dataset(wsi_dataset, save_outputs_dir=None)¶
Compute the model/classifier output values for the whole WSI dataset.
- Parameters:
wsi_dataset (WsiDataset or WsiMultiResDataset) – Dataset representing patches from one WSI.
save_outputs_dir (str, optional) – Directory for saving model/classifier outputs as compressed NumPy arrays (using key “data”), one array file per one processed patch. This feature should only be used for troubleshooting inference problems.
- save_class_array(class_index, array_file)¶
Save the probabilities for one class as a compressed NumPy array.
- Parameters:
class_index (int) – Index of class to be saved.
array_file (str) – File name or path for saving the NPZ file.
- save_class_png(class_index, png_file)¶
Save the probabilities for one class as a PNG image.
- Parameters:
class_index (int) – Index of class to be saved.
png_file (str) – File name or path for saving the PNG file.
- save_class_tif(class_index, tif_file, wsi_file, downsample_factor=None, jpeg_compression=True)¶
Save the probabilities for one class as a TIF image with embedded resolution information.
- Parameters:
class_index (int) – Index of the class to be saved.
tif_file (str) – File name or path for saving the TIF file.
wsi_file (str) – WSI file name or path used in the inference (for extracting resolution information).
downsample_factor (float, optional) – Downsample factor used for resolution information. If not provided, then the value will be determined based on the level_or_minsize inference parameter.
jpeg_compression (bool, default=True) – Whether internal JPEG compression should be used or not.
- save_classes_array(array_file)¶
Save the probabilities for all classes as a compressed NumPy array.
- Parameters:
array_file (str) – File name or path for saving the NPZ file.
- classmethod set_interpolation_method(interpolation_method)¶
Set the interpolation method used for patch resizing on the output of segmentation models.
- Parameters:
interpolation_method (cv2 enum, default=cv2.INTER_LINEAR) – Interpolation method to be used. Available methods: https://docs.opencv.org/4.9.0/da/d54/group__imgproc__transform.html
- property classes_array¶
Return the probabilities array for all classes.
- property interpolation_method¶
Return the interpolation_method value set by set_interpolation_method (default=cv2.INTER_LINEAR).
- property torch_device¶
Return the torch device name.
Additional configuration¶
By default the module storing the WSIInference class sets cudnn.benchmark = True. In cases when this setting is
not desirable, its value can be reverted in the following way:
from dplabtools.slides.processing.inference import cudnn
cudnn.benchmark = False