CNNs for radiology images#

In this notebook, we apply our new CNN knowledge to classify medical images. We use PneumoniaMNIST, a chest X-ray dataset from MedMNIST that was derived from a larger pediatric chest X-ray dataset. The task is to classify each image as normal (healthy) or pneumonia.

Goals:

  • load X-ray data

  • build and train CNN for pneumonia classification

  • evaluate the model

  • have a first look at explainability

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms
import torch.nn.functional as F

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, ConfusionMatrixDisplay
from sklearn.metrics import classification_report, accuracy_score, precision_score, recall_score, f1_score

1) Load MedMNIST dataset#

We use the medmnist package, which provides programmatic access to MedMNIST cohorts.

!pip install medmnist # should take around 2-3 min
import medmnist
from medmnist import INFO
data = "pneumoniamnist"

info = INFO[data]
DataClass = getattr(medmnist, info["python_class"])

print(info["task"])
print(info["label"])

transform = transforms.Compose([
    transforms.ToTensor()
])
binary-class
{'0': 'normal', '1': 'pneumonia'}

3) Train/validation/test split#

train_dataset = DataClass(split="train", transform=transform, download=True)
val_dataset   = DataClass(split="val", transform=transform, download=True)
test_dataset  = DataClass(split="test", transform=transform, download=True)

print(len(train_dataset), len(val_dataset), len(test_dataset))
4708 524 624

Let's have a look at some example images. As you can see, the images are very pixelated (28x28 pixels). This is, of course, far below the resolution used in clinical practice.

class_names = [info["label"][str(i)] for i in range(len(info["label"]))]

plt.figure(figsize=(10, 4))
for i in range(4):
    img, label = train_dataset[i + 5]
    plt.subplot(2, 4, i + 1)
    plt.imshow(img.squeeze(), cmap="gray")
    # label is a length-1 array; take its single element before converting to int
    plt.title(class_names[int(label[0])])
    plt.axis("off")
plt.tight_layout()
plt.show()
[Figure: four example training images, each titled with its class label]
BATCH_SIZE = 64

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader  = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
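Before building a model, it can help to check what a single batch looks like. The sketch below does not touch the real data: it uses random stand-in tensors (an assumption) shaped like the 28x28 grayscale images, with each label stored as a length-1 vector, which is what the DeprecationWarning-prone `int(label)` call above suggests for this dataset. The names `fake_images`, `fake_labels` and `demo_loader` are hypothetical.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in data (assumption): random tensors shaped like the 1x28x28 images,
# with one integer label per image stored in a length-1 vector.
fake_images = torch.rand(10, 1, 28, 28)
fake_labels = torch.randint(0, 2, (10, 1))
demo_loader = DataLoader(TensorDataset(fake_images, fake_labels),
                         batch_size=4, shuffle=False)

# Pull one batch and inspect its shapes.
images, labels = next(iter(demo_loader))
print(images.shape)  # (batch, channels, height, width)
print(labels.shape)  # (batch, 1): one label per image

# Binary losses such as nn.BCEWithLogitsLoss expect float targets,
# so integer labels need a cast before computing the loss.
targets = labels.float()
print(targets.dtype)
```

Running the same two `print` calls on a batch from `train_loader` is a quick way to confirm that the real data has the shapes you expect.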

Exercise: We have already discussed how to build CNNs, and you have worked through your first steps with deep learning. Fill in the empty code cells below and carry out the respective tasks.

4) Define a CNN#

# add your CNN here
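If you get stuck, the following is one possible starting point, not the reference solution. It is a minimal sketch that assumes 1x28x28 inputs and a single output logit for the binary task; the class name `SmallCNN`, the channel counts and the hidden size are arbitrary illustrative choices.

```python
import torch
import torch.nn as nn

class SmallCNN(nn.Module):
    """Hypothetical minimal CNN for 1x28x28 inputs; returns one logit per image."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),   # 1x28x28  -> 16x28x28
            nn.ReLU(),
            nn.MaxPool2d(2),                              # 16x28x28 -> 16x14x14
            nn.Conv2d(16, 32, kernel_size=3, padding=1),  # 16x14x14 -> 32x14x14
            nn.ReLU(),
            nn.MaxPool2d(2),                              # 32x14x14 -> 32x7x7
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),                                 # 32x7x7 -> 1568
            nn.Linear(32 * 7 * 7, 64),
            nn.ReLU(),
            nn.Linear(64, 1),                             # single logit (no sigmoid here)
        )

    def forward(self, x):
        return self.classifier(self.features(x))

# Shape check with a random batch of four stand-in images.
model = SmallCNN()
logits = model(torch.rand(4, 1, 28, 28))
print(logits.shape)
```

The model returns raw logits rather than probabilities, which pairs with `nn.BCEWithLogitsLoss`. An alternative design is two output units combined with `nn.CrossEntropyLoss`.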

5) Training loop#

# add your training code here
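As a hint for this step, here is a hedged sketch of a basic training loop. It assumes a model that outputs one logit per image and labels of shape (batch, 1); the function name `train`, the choice of Adam, and the tiny linear `demo_model` trained on random stand-in data are all illustrative assumptions, so the printed losses carry no meaning.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

def train(model, loader, epochs=3, lr=1e-3, device="cpu"):
    """Train a single-logit binary classifier; return the mean loss per epoch."""
    model.to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    history = []
    for _ in range(epochs):
        model.train()
        running_loss, n_seen = 0.0, 0
        for images, labels in loader:
            images = images.to(device)
            targets = labels.float().to(device)  # BCE expects float targets
            optimizer.zero_grad()
            loss = criterion(model(images), targets)
            loss.backward()
            optimizer.step()
            # Weight the batch loss by batch size to get a per-sample mean.
            running_loss += loss.item() * images.size(0)
            n_seen += images.size(0)
        history.append(running_loss / n_seen)
    return history

# Demo on random stand-in data (assumption), only to show that the loop runs.
torch.manual_seed(0)
demo_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 1))
demo_data = TensorDataset(torch.rand(32, 1, 28, 28), torch.randint(0, 2, (32, 1)))
demo_loader = DataLoader(demo_data, batch_size=8, shuffle=True)
train_losses = train(demo_model, demo_loader, epochs=2)
print(train_losses)
```

In the notebook you would pass your own CNN and `train_loader` instead, and you may want to compute a validation loss on `val_loader` after each epoch so that both curves can be plotted later.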

6) Plot loss over epochs and evaluate the model on the test dataset#

# add your plots and evaluations here
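For the evaluation part, the sketch below shows one way to turn logits into metrics with scikit-learn. It assumes a single-logit model; the helper name `evaluate_logits`, the 0.5 threshold and the four toy values are made up for illustration and are not results from this dataset.

```python
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score

def evaluate_logits(logits, y_true, threshold=0.5):
    """Compute accuracy, ROC AUC and the confusion matrix from raw logits."""
    logits = np.asarray(logits, dtype=float).ravel()
    y_true = np.asarray(y_true).ravel()
    probs = 1.0 / (1.0 + np.exp(-logits))          # sigmoid: logit -> probability
    y_pred = (probs >= threshold).astype(int)      # hard class decision
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "auc": roc_auc_score(y_true, probs),       # AUC uses scores, not hard labels
        "confusion_matrix": confusion_matrix(y_true, y_pred),
    }

# Toy values (made up) to show the call; replace with test-set logits and labels.
toy_logits = [-2.0, 0.5, 1.5, -0.5]
toy_labels = [0, 0, 1, 1]
metrics = evaluate_logits(toy_logits, toy_labels)
print(metrics["accuracy"], metrics["auc"])
print(metrics["confusion_matrix"])
```

To use this on the real model, collect the logits and labels over `test_loader` inside `torch.no_grad()` with the model in `eval()` mode, then pass them to the helper as NumPy arrays.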

Once you are done with the exercises, have a look at solutions.ipynb, where we added a small example of CNN explainability, a topic that will be important in the next session.
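To give a first impression of what explainability can look like, here is a sketch of a plain gradient saliency map. This is an assumption about a possible approach and not necessarily the method used in solutions.ipynb; the function name `gradient_saliency` and the untrained `demo_model` are hypothetical.

```python
import torch
import torch.nn as nn

def gradient_saliency(model, image):
    """Absolute gradient of the output logit with respect to each input pixel."""
    model.eval()
    x = image.detach().clone().unsqueeze(0)  # add batch dimension: 1x1x28x28
    x.requires_grad_(True)
    logit = model(x).sum()                   # scalar, so backward() needs no argument
    logit.backward()
    # Large values mark pixels whose change would move the logit the most.
    return x.grad.abs().squeeze(0).squeeze(0)

# Untrained stand-in model and random image (assumptions), for the shape only.
demo_model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(4 * 28 * 28, 1),
)
saliency = gradient_saliency(demo_model, torch.rand(1, 28, 28))
print(saliency.shape)
```

With a trained CNN and a real X-ray, the map can be shown next to the image using `plt.imshow(saliency, cmap="hot")`. Note that `backward()` also accumulates gradients on the model parameters, so call `model.zero_grad()` before any further training.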