Inference

dplabtools offers a simple class for running inference on whole-slide images (WSIs) with trained PyTorch models. The inference process is integrated with the other classes in the package that handle patch location computation, patch extraction and heatmap generation.

WSIInference class features:

  • Support for segmentation and classification models.

  • Support for single- and multi-resolution patches using dedicated Datasets.

  • Support for multi-class model output.

  • Configurable GPU/CPU processing.

  • Inference output integrated with the Heatmap class.

  • Ability to save results as images, including TIF images with embedded resolution information.

  • Parallelization based on PyTorch DataLoaders.

Basic usage

Assuming that dataset, model and classifier have already been created (see the full example below):

from dplabtools.slides.processing import WSIInference

inference = WSIInference(
    model=model,
    classifier=classifier,
    level_or_minsize=0,
    num_classes=3,
    num_workers=12,
    batch_size=256,
)
inference.process_dataset(dataset)

Full example of inference process

The full inference process consists of the following steps:

  1. WSI mask generation

  2. Calculating patch location

  3. Creating WSI dataset

  4. Initializing PyTorch model/classifier

  5. Creating WSIInference object

  6. Processing WSI dataset

from dplabtools.slides.processing import WSITissueMask, WSIDataset, WSIInference
from dplabtools.slides.patches import WholeImageGridPatches

wsi_file = "/tmp/wsi1.svs"

# Step 1
mask = WSITissueMask(wsi_file=wsi_file, level_or_minsize=2)

# Step 2
patches = WholeImageGridPatches(wsi_file=wsi_file, mask_data=mask.array)

# Step 3
dataset = WSIDataset(patches=patches)

# Step 4 (get_model and get_classifier are user-defined helpers, described below)
model = get_model()
classifier = get_classifier()

# Step 5
inference = WSIInference(model=model, classifier=classifier, level_or_minsize=0, num_classes=2)

# Step 6
inference.process_dataset(dataset)

Note

When processing multiple WSIs in one go, perform steps 4 and 5 only once for better performance, rather than repeating them for each WSI; only steps 1-3 and 6 need to run per slide.
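A minimal sketch of that structure, assuming the inference object can be reused across slides as the note implies (each process_dataset call is assumed to replace the previous slide's results). The function name process_slides, the make_dataset callable wrapping steps 1-3, and the PNG naming scheme are hypothetical; only process_dataset and save_class_png come from the WSIInference API documented on this page.

```python
from pathlib import Path


def process_slides(wsi_files, inference, make_dataset, output_dir, class_index=1):
    """Run one pre-built WSIInference object over several slides (hypothetical helper).

    inference    -- WSIInference created once (steps 4 and 5)
    make_dataset -- hypothetical callable wrapping steps 1-3 for one slide
    Returns the list of PNG paths passed to save_class_png.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    written = []
    for wsi_file in wsi_files:
        dataset = make_dataset(wsi_file)      # steps 1-3: repeated per slide
        inference.process_dataset(dataset)    # step 6: repeated per slide
        # Hypothetical naming scheme: <slide stem>_class<index>.png
        png_file = output_dir / f"{Path(wsi_file).stem}_class{class_index}.png"
        inference.save_class_png(class_index, str(png_file))
        written.append(png_file)
    return written
```

With the classes from the full example, make_dataset would chain WSITissueMask, WholeImageGridPatches and WSIDataset for the given wsi_file, while inference would be built once from get_model() and get_classifier().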

Model initialization checklist

  • The PyTorch model must be properly initialized before it is passed to WSIInference. This can be done with a simple convenience function, get_model:

    import torch
    
    from mymodels import MyPyTorchModel
    
    MODEL_PATH = "/tmp/modelfile.pth"
    
    def get_model():
        model = MyPyTorchModel()
        model.load_state_dict(torch.load(MODEL_PATH))
        ...
        return model
    
  • Models containing layers such as dropout or batch normalization must be switched to evaluation mode for inference; do this inside get_model by calling eval():

    def get_model():
        model = MyPyTorchModel()
        model.load_state_dict(torch.load(MODEL_PATH))
        model.eval()
        ...
        return model
    
  • When using CUDA processing (the WSIInference default), the model should be moved to GPU memory inside get_model by calling cuda():

    def get_model():
        model = MyPyTorchModel()
        model.load_state_dict(torch.load(MODEL_PATH))
        model.cuda()
        ...
        return model
    
  • When using CPU processing, WSIInference must be created with use_cuda=False, and the model weights should be loaded with map_location="cpu" so that weights saved on a GPU are mapped to CPU memory:

    def get_model():
        model = MyPyTorchModel()
        model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
        ...
        return model
    

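The checklist items can be folded into one get_model that follows the use_cuda setting. This is a sketch under assumptions: build_model is a hypothetical stand-in for MyPyTorchModel, and the model_path and use_cuda arguments are additions not present in the get_model shown above. Loading the weights to the CPU first and moving the model with cuda() afterwards is one common way to make the same function serve both modes.

```python
import os
import tempfile

import torch
from torch import nn


def build_model() -> nn.Module:
    # Hypothetical stand-in for MyPyTorchModel: a tiny 3-class image classifier.
    return nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.BatchNorm2d(8),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, 3),
    )


def get_model(model_path: str, use_cuda: bool) -> nn.Module:
    """Load weights and prepare the model to match WSIInference(use_cuda=...)."""
    model = build_model()
    # map_location="cpu" lets GPU-saved weights load on a CPU-only machine.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()  # disable dropout, use running batch-norm statistics
    if use_cuda:
        model.cuda()  # only when WSIInference is created with use_cuda=True
    return model
```

Passing the same use_cuda value to both get_model and WSIInference keeps the two consistent.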
Note

The error "Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same" indicates a mismatch between the device the model was placed on inside get_model (CUDA or CPU) and the way WSIInference was created (use_cuda=True|False).
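One way to catch this mismatch before any slide is processed is to compare the device of the model's parameters with the intended use_cuda value. The helper check_model_device below is hypothetical, not part of dplabtools; it only inspects standard torch.nn.Module parameters.

```python
import torch
from torch import nn


def check_model_device(model: nn.Module, use_cuda: bool) -> None:
    """Raise early if the model's parameters are not on the device implied by use_cuda."""
    expected = "cuda" if use_cuda else "cpu"
    found = {p.device.type for p in model.parameters()}
    if not found:
        return  # a parameter-free model has no device to check
    if found != {expected}:
        raise RuntimeError(
            f"model parameters are on {sorted(found)}, but use_cuda={use_cuda} "
            f"expects '{expected}'; fix get_model or the WSIInference arguments"
        )
```

Calling it right after get_model, with the same use_cuda value later given to WSIInference, turns the low-level PyTorch error into an explicit message.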

Classifier initialization

The classifier is an optional additional processing step applied to the model output. It can be any callable that returns a torch.Tensor, including PyTorch models and plain Python functions. The most basic classifier applies the softmax transformation to the model output; here it is wrapped in a get_classifier function:

import torch

def classifier_fn(result):
    probabilities = torch.nn.functional.softmax(result, dim=1)
    return probabilities

def get_classifier():
    return classifier_fn

classifier = get_classifier()

When no classifier is needed, set classifier=None.
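Because any callable returning a torch.Tensor qualifies, a torch.nn.Softmax module is interchangeable with classifier_fn. The sketch below checks that equivalence on random tensors shaped like classification (N, C) and segmentation (N, C, H, W) outputs; those shapes follow the usual PyTorch convention and are an assumption about what the model returns.

```python
import torch
from torch import nn


def classifier_fn(result: torch.Tensor) -> torch.Tensor:
    # Same function as above: softmax over the class dimension (dim=1).
    return torch.nn.functional.softmax(result, dim=1)


# A PyTorch module also works as a classifier, since it is callable.
classifier_module = nn.Softmax(dim=1)

# dim=1 is the class axis both for classification outputs shaped (N, C)
# and for segmentation outputs shaped (N, C, H, W).
logits_cls = torch.randn(4, 3)
logits_seg = torch.randn(4, 3, 8, 8)

for logits in (logits_cls, logits_seg):
    probs = classifier_module(logits)
    # Probabilities over the classes sum to 1 at every sample/pixel.
    print(probs.shape, probs.sum(dim=1).mean().item())
```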

Class details

class dplabtools.slides.processing.WSIInference(...)

Class for running inference on WSIs using PyTorch models.

Parameters:
  • model (callable) – Properly initialized PyTorch model.

  • classifier (callable or None) – PyTorch model or function capable of processing the model’s output; None disables this step.

  • level_or_minsize (int) – WSI level or minimal desired size (in pixels) of inference output array.

  • num_classes (int) – Number of classes present in model output.

  • num_workers (int) – Number of worker processes used in data loading.

  • batch_size (int) – Number of samples per batch to load into GPU.

  • use_cuda (bool, default=True) – Whether the model uses CUDA/GPU for processing; False selects pure CPU processing.

  • seed (int or float, optional) – Custom seed for random number generators.

process_dataset(wsi_dataset, save_outputs_dir=None)

Compute the model/classifier output values for the whole WSI dataset.

Parameters:
  • wsi_dataset (WsiDataset or WsiMultiResDataset) – Dataset representing patches from one WSI.

  • save_outputs_dir (str, optional) – Directory for saving model/classifier outputs as compressed NumPy arrays (under the key “data”), one file per processed patch. Intended only for troubleshooting inference problems.
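For troubleshooting, the per-patch files can be read back with NumPy. A minimal sketch: load_saved_outputs is a hypothetical helper, and the assumptions that the files carry the .npz extension and can be identified by their file stem are not confirmed by this documentation; only the “data” key is.

```python
from pathlib import Path

import numpy as np


def load_saved_outputs(outputs_dir):
    """Load every per-patch .npz file from outputs_dir (hypothetical helper).

    Returns a dict mapping each file stem to the array stored under key "data".
    The *.npz pattern and stem-based keys are assumptions about file naming.
    """
    outputs = {}
    for npz_path in sorted(Path(outputs_dir).glob("*.npz")):
        with np.load(npz_path) as npz:
            outputs[npz_path.stem] = npz["data"]
    return outputs
```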

save_class_array(class_index, array_file)

Save the probabilities for one class as a compressed NumPy array.

Parameters:
  • class_index (int) – Index of class to be saved.

  • array_file (str) – File name or path for saving the NPZ file.
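The key used inside the file written by save_class_array is not specified in this documentation, so a loader that does not depend on the key name is a safer sketch for reading it back. load_single_array is hypothetical and assumes the file holds exactly one array; the same may apply to save_classes_array, which is also described as writing a single array.

```python
import numpy as np


def load_single_array(npz_file):
    """Return the only array stored in an .npz file, whatever its key is."""
    with np.load(npz_file) as npz:
        if len(npz.files) != 1:
            raise ValueError(f"expected one array, found keys {npz.files}")
        return npz[npz.files[0]]
```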

save_class_png(class_index, png_file)

Save the probabilities for one class as a PNG image.

Parameters:
  • class_index (int) – Index of class to be saved.

  • png_file (str) – File name or path for saving the PNG file.

save_class_tif(class_index, tif_file, wsi_file, downsample_factor=None, jpeg_compression=True)

Save the probabilities for one class as a TIF image with embedded resolution information.

Parameters:
  • class_index (int) – Index of the class to be saved.

  • tif_file (str) – File name or path for saving the TIF file.

  • wsi_file (str) – WSI file name or path used in the inference (for extracting resolution information).

  • downsample_factor (float, optional) – Downsample factor used for resolution information. If not provided, then the value will be determined based on level_or_minsize inference parameter.

  • jpeg_compression (bool, default=True) – Whether internal JPEG compression should be used or not.
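For reference, a downsample factor relates the inference output to level 0 of the WSI: each output pixel spans downsample_factor level-0 pixels along each axis, so its physical size is the level-0 resolution multiplied by that factor. A minimal sketch of that arithmetic follows; whether save_class_tif derives the factor exactly this way is an assumption, and the function name and example numbers are illustrative.

```python
def output_resolution(mpp_level0, wsi_width_level0, output_width):
    """Return (downsample_factor, microns_per_pixel) for an inference output array.

    mpp_level0       -- level-0 resolution of the WSI in microns per pixel
    wsi_width_level0 -- level-0 width of the WSI in pixels
    output_width     -- width of the inference output array in pixels
    """
    downsample_factor = wsi_width_level0 / output_width
    return downsample_factor, mpp_level0 * downsample_factor


# A 40,000 px wide slide at 0.25 um/px with a 1,250 px wide output:
print(output_resolution(0.25, 40000, 1250))
```

In this example each output pixel covers 32 level-0 pixels, i.e. 8 µm.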

save_classes_array(array_file)

Save the probabilities for all classes as a compressed NumPy array.

Parameters:
  • array_file (str) – File name or path for saving the NPZ file.

classmethod set_interpolation_method(interpolation_method)

Set the interpolation method used for patch resizing on the output of segmentation models.

Parameters:
  • interpolation_method (cv2 enum, default=cv2.INTER_LINEAR) – Interpolation method to be used. Available methods: https://docs.opencv.org/4.9.0/da/d54/group__imgproc__transform.html

property classes_array

Return the probabilities array for all classes.

property interpolation_method

Return the interpolation_method value set by set_interpolation_method (default=cv2.INTER_LINEAR).

property torch_device

Return the torch device name.

See also

level_or_minsize

Additional configuration

By default, the module that defines the WSIInference class sets cudnn.benchmark = True. If this setting is not desirable, it can be reverted as follows:

from dplabtools.slides.processing.inference import cudnn
cudnn.benchmark = False
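If the goal is repeatable results rather than maximum speed, benchmark can be disabled together with PyTorch's deterministic cuDNN flag. A minimal sketch, assuming the cudnn object imported from dplabtools above is torch.backends.cudnn (so setting the flags directly on torch.backends.cudnn is equivalent); run it after importing dplabtools, since that import is what sets benchmark = True. The seed parameter of WSIInference can complement these settings.

```python
import torch

# Assumption: dplabtools' cudnn is torch.backends.cudnn, so these flags
# override the benchmark = True set when the dplabtools module is imported.
torch.backends.cudnn.benchmark = False      # no autotuned (possibly varying) kernels
torch.backends.cudnn.deterministic = True   # request deterministic cuDNN algorithms
```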