Inference

dplabtools offers a simple class for performing inference on WSIs using trained PyTorch models. The inference process is integrated with the other classes in the package, which handle patch location computation, patch extraction and heatmap generation.

WSIInference class features:

  • Support for segmentation and classification models.

  • Support for single- and multi-resolution patches using dedicated Datasets.

  • Support for multi-class model output.

  • Configurable GPU/CPU processing.

  • Inference output integrated with the Heatmaps class.

  • Ability to save results as images, also containing embedded resolution information.

  • Parallelization based on PyTorch DataLoaders.

Basic usage

Assuming that the variables dataset, model and classifier already exist and hold initialized objects:

from dplabtools.slides.processing import WSIInference

inference = WSIInference(
    model=model,
    classifier=classifier,
    level_or_minsize=0,
    num_classes=3,
    num_workers=12,
    batch_size=256,
)
inference.process_dataset(dataset)

Full example of inference process

The full inference process consists of the following steps:

  1. WSI mask generation

  2. Calculating patch location

  3. Creating WSI dataset

  4. Initializing PyTorch model/classifier

  5. Creating WSIInference object

  6. Processing WSI dataset

  7. Handling inference output

from dplabtools.slides.processing import WSITissueMask, WSIDataset, WSIInference
from dplabtools.slides.patches import WholeImageGridPatches

wsi_file = "/tmp/wsi1.svs"

# Step 1
mask = WSITissueMask(wsi_file=wsi_file, level_or_minsize=2)

# Step 2
patches = WholeImageGridPatches(wsi_file=wsi_file, mask_data=mask.array)

# Step 3
dataset = WSIDataset(patches=patches)

# Step 4
model = get_model()
classifier = get_classifier()

# Step 5
inference = WSIInference(model=model, classifier=classifier, level_or_minsize=2,
                         num_classes=2, num_workers=6, batch_size=128)

# Step 6
inference.process_dataset(dataset)

# Step 7
print(inference.classes_array)

Note

When processing multiple WSIs in one run, steps 4 and 5 should be performed only once (rather than separately for each WSI) to avoid the overhead of repeatedly initializing the model and the WSIInference object.
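The loop described in the note above can be sketched as follows. This is a minimal illustration, not part of dplabtools: run_inference_on_slides, build_dataset and handle_output are hypothetical names, where build_dataset is assumed to wrap steps 1-3 for a single WSI file and handle_output is assumed to wrap step 7. The sketch also assumes that results should be handled before the next WSI is processed.

```python
from typing import Any, Callable, Iterable


def run_inference_on_slides(
    wsi_files: Iterable[str],
    build_dataset: Callable[[str], Any],
    inference: Any,
    handle_output: Callable[[str, Any], None],
) -> None:
    """Run one shared inference object over several WSIs.

    build_dataset and handle_output are hypothetical caller-supplied
    callables: the first wraps steps 1-3, the second wraps step 7.
    """
    for wsi_file in wsi_files:
        dataset = build_dataset(wsi_file)    # steps 1-3, repeated for each WSI
        inference.process_dataset(dataset)   # step 6, reusing the same object
        handle_output(wsi_file, inference)   # step 7, before the next WSI starts
```

The model, classifier and WSIInference object (steps 4 and 5) are created once by the caller and passed in as inference, so only the per-slide work is repeated inside the loop.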

Warning

When running inference with num_workers=0, the corresponding WSI dataset class must be created with zero_workers=True (see Parameters common in all WSI dataset classes).

Model initialization checklist

  • The PyTorch model must be properly initialized before it is passed to WSIInference. This can be accomplished with a simple convenience function, get_model:

    import torch
    
    from mymodels import MyPyTorchModel
    
    MODEL_PATH = "/tmp/modelfile.pth"
    
    def get_model():
        model = MyPyTorchModel()
        model.load_state_dict(torch.load(MODEL_PATH))
        ...
        return model
    

    Note

    model.eval() will be called automatically during the inference process.

  • When using CUDA processing (WSIInference default mode), the model should be loaded into GPU memory inside get_model by calling cuda():

    def get_model():
        model = MyPyTorchModel()
        model.load_state_dict(torch.load(MODEL_PATH))
        model.cuda()
        ...
        return model
    
  • When using CPU processing, WSIInference must be initialized with use_cuda=False, and the model weights must be loaded with map_location="cpu":

    def get_model():
        model = MyPyTorchModel()
        model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
        ...
        return model
    
  • If the data processed during inference requires prior transformations, those transformations should be specified in the corresponding WSI dataset class through the transform_fn parameter (see Parameters common in all WSI dataset classes).

Note

The error "Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same" indicates a mismatch between how the model was initialized inside get_model (CUDA or CPU) and how WSIInference was created (use_cuda=True or use_cuda=False).
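One way to catch this mismatch before inference starts is to compare the device of the model parameters with the intended use_cuda value. The helpers below are a hypothetical sketch and not part of dplabtools; they rely only on standard PyTorch calls (parameters() and the is_cuda tensor attribute) and assume that all model parameters live on the same device.

```python
import torch


def model_is_on_cuda(model: torch.nn.Module) -> bool:
    """Return True if the model's first parameter lives on a CUDA device."""
    first_param = next(model.parameters(), None)
    return first_param is not None and first_param.is_cuda


def check_device_consistency(model: torch.nn.Module, use_cuda: bool) -> None:
    """Raise ValueError if the model device does not match use_cuda."""
    on_cuda = model_is_on_cuda(model)
    if on_cuda != use_cuda:
        raise ValueError(
            f"Model is on {'CUDA' if on_cuda else 'CPU'} "
            f"but use_cuda={use_cuda} was requested."
        )
```

Calling check_device_consistency(model, use_cuda) with the same use_cuda value that is later passed to WSIInference would surface the problem with a clearer message.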

Classifier initialization

A classifier is an optional additional processing layer applied to the model output. It can be any callable that returns a torch.Tensor, which includes PyTorch models as well as plain Python functions. The most basic classifier applies the softmax transformation to the model output and is returned by a get_classifier function:

import torch

def classifier_fn(result):
    probabilities = torch.nn.functional.softmax(result, dim=1)
    return probabilities

def get_classifier():
    return classifier_fn

classifier = get_classifier()

Additional notes:

  • When no classifier is needed, its value should be set to None.

  • When the classifier requires a call to .eval() (for example, when it is a PyTorch model), that call should be made manually.
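As an illustration of the second note, the sketch below uses a PyTorch module as the classifier and calls .eval() manually inside get_classifier. It assumes the model output has shape (batch, num_classes), so that dim=1 is the class dimension; the logits tensor is made-up sample data.

```python
import torch


def get_classifier():
    # torch.nn.Softmax is a standard PyTorch module; dim=1 normalizes
    # across the class dimension of a (batch, num_classes) tensor.
    classifier = torch.nn.Softmax(dim=1)
    classifier.eval()  # manual call, as required for module-based classifiers
    return classifier


classifier = get_classifier()

# Made-up model output for two patches and three classes.
logits = torch.tensor([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]])
probabilities = classifier(logits)  # each row sums to 1
```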

Class details

class dplabtools.slides.processing.WSIInference(...)

Class for running inference on WSIs using PyTorch models.

Parameters:
  • model (callable) – PyTorch model properly initialized.

  • classifier (callable) – PyTorch model or function capable of processing the model’s output.

  • level_or_minsize (int) – WSI level or minimal desired size (in pixels) of inference output array. This parameter controls the dimensions of the output array created during the inference process.

  • num_classes (int) – Number of classes present in model output.

  • num_workers (int) – Number of worker processes used in data loading (recommended value is greater than zero).

  • batch_size (int) – Number of samples per batch to load into GPU.

  • use_cuda (bool, default=True) – Whether the model uses CUDA/GPU for processing. False selects pure CPU processing.

  • seed (int or float, optional) – Custom seed for random number generators.

classmethod set_interpolation_method(interpolation_method)

Set the interpolation method used for patch resizing on the output of segmentation models.

Parameters:

interpolation_method (cv2 enum, default=cv2.INTER_LINEAR) – Interpolation method to be used. Available methods: https://docs.opencv.org/4.13.0/da/d54/group__imgproc__transform.html

process_dataset(wsi_dataset, save_outputs_dir=None)

Compute the model/classifier output values for the whole WSI dataset.

Parameters:
  • wsi_dataset (WSIDataset or WSIMultiResDataset) – Dataset representing patches from one WSI.

  • save_outputs_dir (str, optional) – Directory for saving model/classifier outputs as compressed NumPy arrays (using key “data”), one array file per processed patch. This feature should only be used for troubleshooting inference problems.
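When troubleshooting, the saved outputs can be read back with NumPy. The helper below is a hypothetical sketch: load_saved_outputs is not part of dplabtools, and it assumes the files in save_outputs_dir carry the .npz extension. The "data" key is the one named in the parameter description above.

```python
from pathlib import Path

import numpy as np


def load_saved_outputs(save_outputs_dir):
    """Load every NPZ file in the directory into a {file name: array} dict."""
    outputs = {}
    for npz_path in sorted(Path(save_outputs_dir).glob("*.npz")):
        # NpzFile supports the context manager protocol; indexing reads
        # the array into memory before the archive is closed.
        with np.load(npz_path) as archive:
            outputs[npz_path.name] = archive["data"]
    return outputs
```

Inspecting the shape and value range of a few of these arrays is usually enough to tell whether the model or the classifier produced the unexpected values.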

save_class_array(class_index, array_file)

Save the probabilities for one class as a compressed NumPy array.

Parameters:
  • class_index (int) – Index of class to be saved.

  • array_file (str) – File name or path for saving the NPZ file.

save_class_png(class_index, png_file)

Save the probabilities for one class as a PNG image.

Parameters:
  • class_index (int) – Index of class to be saved.

  • png_file (str) – File name or path for saving the PNG file.

save_class_tif(class_index, tif_file, wsi_file, downsample_factor=None, jpeg_compression=True)

Save the probabilities for one class as a TIF image with embedded resolution information.

Parameters:
  • class_index (int) – Index of the class to be saved.

  • tif_file (str) – File name or path for saving the TIF file.

  • wsi_file (str) – WSI file name or path used in the inference (for extracting resolution information).

  • downsample_factor (float, optional) – Downsample factor used for resolution information. If not provided, the value is determined from the level_or_minsize inference parameter.

  • jpeg_compression (bool, default=True) – Whether internal JPEG compression should be used or not.

save_classes_array(array_file)

Save the probabilities for all classes as a compressed NumPy array.

Parameters:

array_file (str) – File name or path for saving the NPZ file.

property classes_array

Return the probabilities array for all classes.

property interpolation_method

Return the interpolation_method value set by set_interpolation_method (default=cv2.INTER_LINEAR).

property model

Return the model used during inference.

property torch_device

Return the torch device name.

See also

level_or_minsize

Additional configuration

By default, the module that defines the WSIInference class sets cudnn.benchmark = True. If this setting is not desirable, it can be reverted as follows:

from dplabtools.slides.processing.inference import cudnn
cudnn.benchmark = False