Train Your First Medical AI Model with MONAI

We briefly introduced MONAI in the previous Getting Started AI in Healthcare with MONAI post. Now we will train our first, simple medical AI model with MONAI Core. We can leverage MONAI's workflows to quickly set up a robust training or evaluation program for research experiments.

The design principle of MONAI is to provide flexible and lightweight APIs for users with varying expertise. All the core components are independent modules, which can be easily integrated into any existing PyTorch program.

Model Training Pipeline with MONAI

Let’s get started with our first model training pipeline in MONAI Core! We can split the whole pipeline into these small steps:

  1. Importing Python Modules
  2. Setting Environment Parameters
  3. Defining Transforms for Preprocessing
  4. Splitting & Preparing Dataset
  5. Defining Model Structure & Hyperparameters
  6. Training Model

1. Importing Python Modules

import os
import torch

from monai.apps import MedNISTDataset
from monai.data import DataLoader
from monai.networks.nets import DenseNet121
from monai.inferers import SimpleInferer
from monai.engines import SupervisedTrainer
from monai.utils import set_determinism
from monai import transforms as trf

We will use only the torch and monai modules for this introductory MONAI tutorial. Some datasets and model architectures, such as MedNIST and DenseNet, come built into MONAI Core. We can use them directly by importing them and adjusting their parameters. We will dive into the details in the upcoming sections.

2. Setting Environment Parameters

set_determinism(seed=42)
threads = 30
device = torch.device("cuda:0")

The set_determinism method fixes the environment's randomness for reproducibility. The threads variable is the number of CPU threads available for parallel processing. Lastly, we assigned our GPU to the device variable. We will use this GPU for memory allocation and computation throughout the pipeline.
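If a CUDA-capable GPU is not available on your machine, the same line can be written with a CPU fallback (an optional variant, not required for this tutorial, although training will be noticeably slower on CPU):

# Optional variant: fall back to the CPU when no CUDA device is available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")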

3. Defining MONAI Transforms for Preprocessing

Medical images require highly specialized methods for I/O, preprocessing, and augmentation. They are often in specialized formats with rich meta-information, and the data volumes are often high-dimensional. 

Transforms are one of the most powerful features of MONAI Core. MONAI provides powerful and flexible image transformations that enable user-friendly, reproducible, and optimized medical data processing pipelines.

We can either use the built-in transforms or create our own custom transform methods. We can chain transforms sequentially to build preprocessing or postprocessing pipelines.

transform = trf.Compose([
    trf.LoadImaged(keys="image"),
    trf.EnsureChannelFirstd(keys="image"),
    trf.ScaleIntensityd(keys="image"),
    trf.ToTensord(keys=["image", "label"])
])

Here, trf.Compose is used to chain together a list of transformations. The transformations will be applied in the order they are listed. The LoadImaged transformation loads image data from a specified path. The keys argument specifies which key in the input dictionary contains the path to the image file. In this case, the code expects the input dictionary to have a key named “image” whose value is the path to the image file.

Medical images can have different formats. Some might have the channel (like RGB channels in color images) as the last dimension, while others might have it as the first. The EnsureChannelFirstd transformation ensures that the channel dimension is the first dimension. The keys argument specifies which data in the input dictionary this transformation should be applied to.

The ScaleIntensityd transformation scales the intensity values of the image to be between 0 and 1. This is often done to normalize the data before feeding it into a neural network. Again, the keys argument specifies which data in the input dictionary this transformation should be applied to.
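Conceptually, the default behavior is a min-max rescaling of each image. A rough, illustrative equivalent in plain NumPy (not MONAI's exact implementation) looks like this:

import numpy as np

def scale_intensity(img: np.ndarray) -> np.ndarray:
    # Map the smallest value to 0 and the largest to 1 (assumes the image is not constant).
    return (img - img.min()) / (img.max() - img.min())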

The ToTensord transformation converts the specified data in the input dictionary to PyTorch tensors. This is necessary because deep learning frameworks like PyTorch expect data to be in tensor format. Here, both the “image” and “label” data in the input dictionary will be converted to tensors.

Why Do Transform Names End with the Letter “d”?

In MONAI, the “d” at the end of transformation names stands for “dictionary”. MONAI provides two types of transformations:

  1. Array-based transformations: These transformations operate directly on arrays (like numpy arrays). They don’t expect any specific data structure.
  2. Dictionary-based transformations: These transformations expect the input data to be in the form of a dictionary. The keys of the dictionary are specified when setting up the transformation (as shown in the code above). The advantage of dictionary-based transformations is that they allow for more structured data handling, especially when dealing with multiple items like images and labels, or multiple modalities, etc.

By convention, dictionary-based transformations in MONAI have names that end with the letter “d” to distinguish them from array-based transformations. This naming convention helps users quickly identify the expected input format for a given transformation.
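For example, LoadImage and ScaleIntensity are the array-based counterparts of LoadImaged and ScaleIntensityd. The short sketch below contrasts the two styles; the file path and label are placeholders:

# Array-based: the transform receives the data itself.
image = trf.LoadImage(image_only=True)("path/to/image.jpeg")  # placeholder path
image = trf.ScaleIntensity()(image)

# Dictionary-based: the transform receives a dictionary and works on the given keys.
sample = {"image": "path/to/image.jpeg", "label": 0}           # placeholder values
sample = trf.LoadImaged(keys="image")(sample)
sample = trf.ScaleIntensityd(keys="image")(sample)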

4. Splitting & Preparing Datasets

training_set, validation_set, test_set = [MedNISTDataset(
    root_dir=os.getcwd(), 
    transform=transform, 
    section=set_name, 
    num_workers=threads
) for set_name in ['training', 'validation', 'test']]

This code creates three separate datasets (for training, validation, and testing) from the MedNIST dataset, applies the specified transformations which we defined above, and assigns them to the respective variables. 

If the dataset is not found in the given directory, it is downloaded automatically. The MedNIST dataset is already split into training, validation, and test sets, so we do not have to shuffle and split it ourselves.
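To sanity-check the splits, we can print their sizes and inspect one transformed sample; the exact numbers depend on the MedNIST release that gets downloaded, so treat this as illustrative:

# Quick sanity check of the three splits and one transformed sample.
print(len(training_set), len(validation_set), len(test_set))
sample = training_set[0]
print(sample["image"].shape, sample["label"])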

training_loader = DataLoader(training_set, batch_size=512, shuffle=True, num_workers=threads)
validation_loader = DataLoader(validation_set, batch_size=512, shuffle=False, num_workers=threads)
test_loader = DataLoader(test_set, batch_size=512, shuffle=False, num_workers=threads)

This code creates data loaders for the training, validation, and test datasets using PyTorch’s DataLoader class. Only the training data needs to be shuffled; the validation and test loaders keep their original order. These data loaders can then be used in the training loop and evaluation phases of a deep learning workflow.
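To verify the loaders, we can pull a single batch and check its shape; with the transforms above, the images come out as a [batch, channel, height, width] tensor:

# Peek at one training batch to confirm tensor shapes before training.
batch = next(iter(training_loader))
print(batch["image"].shape, batch["label"].shape)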

5. Defining Model Structure & Hyperparameters

You can define a model with only one line of code in MONAI.

model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=6).to(device)

This code initializes a DenseNet121 model created for 2D data and then moves it to our training device which we defined before. 

  • spatial_dims=2 specifies that the data is 2D. This is important because medical imaging, which MONAI focuses on, can involve 2D, 3D, or even higher-dimensional data.
  • in_channels=1 indicates that the input data has one channel. For grayscale medical images, this is typically the case.
  • out_channels=6 specifies that the output of the network has 6 channels. This could mean, for example, that there are 6 classes in a classification task.
  • .to(device) moves the initialized model to the specified device we defined above. This step is essential for ensuring that the computations during training or inference are performed on the desired hardware.

optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
loss = torch.nn.CrossEntropyLoss()
inferer = SimpleInferer()

This code initializes three essential components for training a deep learning model using PyTorch. First, it sets up the Adam optimizer with a learning rate of 1×10^−5 to adjust the model’s weights based on the gradients during training. 

Second, it defines the loss function as CrossEntropyLoss, which is commonly used for classification tasks. Finally, it creates an instance of SimpleInferer, a utility for performing inference (forward passes) with the model.
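Under the hood, SimpleInferer simply runs a forward pass of the network on a batch of inputs. As a rough illustration (not part of the training loop itself):

# Illustration only: one forward pass through the inferer, without gradient tracking.
batch = next(iter(training_loader))
with torch.no_grad():
    outputs = inferer(batch["image"].to(device), model)
print(outputs.shape)  # expected: [batch_size, 6] class scores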

MONAI Model Zoo

The monai.networks.nets module contains various neural network architectures that are commonly used or adapted for medical imaging tasks. These architectures range from standard ones like DenseNet or UNet to more specialized architectures tailored for specific medical imaging tasks.

DenseNet121 is a model from the DenseNet family, which is known for its densely connected layers. DenseNets are efficient in terms of parameter count and have shown good performance on various tasks. We used this model to keep the tutorial short, but you can use a different model or create a new one from scratch. Look at the MONAI Documentation for the details.
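For instance, swapping in a deeper member of the same family only changes the constructor call; everything else in the pipeline stays as it is (a sketch, assuming the default constructor arguments fit your data):

from monai.networks.nets import DenseNet169

# A deeper DenseNet variant; the rest of the pipeline is unchanged.
model = DenseNet169(spatial_dims=2, in_channels=1, out_channels=6).to(device)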

6. Training Model

trainer = SupervisedTrainer(
    max_epochs=3,
    train_data_loader=training_loader,
    network=model,
    optimizer=optimizer,
    loss_function=loss,
    inferer=inferer,
    device=device,
)

trainer.run()

This code initializes a SupervisedTrainer, which is a utility for training supervised deep learning models. The trainer is configured with a maximum of 3 training epochs, using the previously defined training_loader for loading training data. 

It also receives the model as the neural network to be trained, the optimizer for weight updates, the loss function to compute the training loss, the inferer for performing model inference, and the target device for computation. After configuring the trainer, the run() method is called to start the training process over the specified epochs.
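Note that the trainer only uses the training loader; the test_loader we created earlier is still unused. A minimal sketch of a plain PyTorch evaluation loop that reports test accuracy could look like this:

# Minimal evaluation sketch: classification accuracy on the held-out test set.
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for batch in test_loader:
        images = batch["image"].to(device)
        labels = batch["label"].to(device)
        predictions = inferer(images, model).argmax(dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.numel()
print(f"Test accuracy: {correct / total:.4f}")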

Conclusion

MONAI is a specialized deep learning framework tailored for medical imaging, offering a comprehensive suite of tools and components that streamline the development and deployment of medical imaging models. Its modular and extensible design allows researchers and clinicians to build end-to-end training pipelines with ease, from data loading and preprocessing to model definition and training. 

Through the provided code snippets, we observed how MONAI facilitates the creation of a training pipeline, encompassing dataset preparation, model initialization, optimizer and loss function setup, and the training process itself. We will prepare more complex tutorials in the upcoming days.
