Chest X-Ray abnormality classification#

In this notebook, we train a model that detects the following abnormalities on chest X-rays: Atelectasis, Cardiomegaly, Consolidation, Edema, Effusion, Emphysema, Fibrosis, Hernia, Infiltration, Mass, Nodule, Pleural_Thickening, Pneumonia, Pneumothorax.

For this, we will use the NIH Chest X-ray Dataset of 14 Common Thorax Disease Categories [1,2]:

[1] Xiaosong Wang, Yifan Peng, Le Lu, Zhiyong Lu, Mohammadhadi Bagheri, Ronald M. Summers, ChestX-ray8: Hospital-scale Chest X-ray Database and Benchmarks on Weakly-Supervised Classification and Localization of Common Thorax Diseases, IEEE CVPR, pp. 3462-3471, 2017

[2] Hoo-chang Shin, Kirk Roberts, Le Lu, Dina Demner-Fushman, Jianhua Yao, Ronald M. Summers, Learning to Read Chest X-Rays: Recurrent Neural Cascade Model for Automated Image Annotation, IEEE CVPR, pp. 2497-2506, 2016

For training, we need pairs of images and their associated findings, indicating whether each abnormality is present or not.

The images are stored under dataset_path; the findings are stored in the table Data_Entry_2017.csv.

Let’s start by extracting the data and importing the necessary libraries.

!bash /data/horse/ws/lazi257c-come2data_workshop/create_data_copy_radiology.sh
from pathlib import Path
import os
import math
import random
import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from torchsummary import summary
import matplotlib.pyplot as plt

from sklearn.metrics import roc_auc_score

Data Exploration 1#

Now, let’s visualize the table that contains the abnormality findings associated with the images.

#dataset_path = Path("/data/horse/ws/lazi257c-come2data_workshop/radiology_session/chest_x_ray")
dataset_path = Path(f"/tmp/{os.environ['USER']}.radiology/chest_x_ray")
table_path = dataset_path / "Data_Entry_2017.csv"

image_paths = list((dataset_path / 'preprocessed').rglob("*.png"))

path_dict = {p.name: p for p in image_paths}

table_df = pd.read_csv(table_path)

table_df["image_path"] = table_df["Image Index"].map(path_dict)

# Split labels into lists
table_df["Finding Labels"] = table_df["Finding Labels"].str.split("|")

# Collect all unique disease labels (excluding "No Finding")
all_labels = sorted(
    set(l for sublist in table_df["Finding Labels"] for l in sublist if l != "No Finding")
)

# Create one column per disease with 0/1
for label in all_labels:
    table_df[label] = table_df["Finding Labels"].apply(lambda x: 1 if label in x else 0)

cols_to_show = [
    "Image Index", "Patient ID",
    "Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Effusion",
    "Emphysema", "Fibrosis", "Hernia", "Infiltration", "Mass", "Nodule",
    "Pleural_Thickening", "Pneumonia", "Pneumothorax"
]
print("This is what the table looks like")
table_df[cols_to_show].head()

Data Exploration 2#

Let’s show how many patients and images are present for each abnormality class.

disease_cols = [
    "Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Effusion",
    "Emphysema", "Fibrosis", "Hernia", "Infiltration", "Mass", "Nodule",
    "Pleural_Thickening", "Pneumonia", "Pneumothorax"
]

# Collapse per patient (max across their images → whether patient ever had that finding)
per_patient = table_df.groupby("Patient ID")[disease_cols].max()

# Count patients and images per disease
patient_counts = per_patient.sum().astype(int)
image_counts   = table_df[disease_cols].sum().astype(int)

# Combine into a DataFrame for sorting
summary_df = pd.DataFrame({
    "patients": patient_counts,
    "images": image_counts
}).sort_values("patients", ascending=False)

# Pretty print
for disease, row in summary_df.iterrows():
    print(f"Number of patients (images) with {disease:<20}: {row['patients']} ({row['images']})")

Data Exploration 3#

Let’s show 5 random images and their associated finding(s).

# Pick 5 random samples
samples = table_df.sample(5, random_state=222)

plt.figure(figsize=(12, 4))
for i, (_, row) in enumerate(samples.iterrows()):
    img = Image.open(row["image_path"]).convert("L")
    findings = [d for d in disease_cols if row[d] == 1]
    title = ", ".join(findings) if findings else "No Finding"

    plt.subplot(1, 5, i+1)
    plt.imshow(img, cmap="gray")
    plt.axis("off")
    plt.title(title, fontsize=9)

plt.tight_layout()
plt.show()

Data Splitting#

Now, let’s split the data into a training set, a validation set and a test set. Each patient is assigned to exactly one of these sets, so that images from the same patient never appear in both training and evaluation data (which would be data leakage).

  • Training set: images used to teach the model to recognize patterns.

  • Validation set: images checked during training to see how well the model is learning.

  • Test set: images kept aside until the end to measure how well the model performs on unseen cases.

# ==== SPLIT: patient-level (80/10/10 train/val/test) ====
rng = np.random.default_rng(222)
patients = np.array(table_df["Patient ID"].unique())
rng.shuffle(patients)

n = len(patients)
n_train = int(0.8 * n)
n_val   = int(0.1 * n)
# test = rest

train_patients = set(patients[:n_train])
val_patients   = set(patients[n_train:n_train + n_val])
test_patients  = set(patients[n_train + n_val:])

train_df = table_df[table_df["Patient ID"].isin(train_patients)].reset_index(drop=True)
val_df   = table_df[table_df["Patient ID"].isin(val_patients)].reset_index(drop=True)
test_df  = table_df[table_df["Patient ID"].isin(test_patients)].reset_index(drop=True)

print(f"Patients: total={n} | train={len(train_patients)} | val={len(val_patients)} | test={len(test_patients)}")
print(f"Images: train={len(train_df)} | val={len(val_df)} | test={len(test_df)}")

Dataset and Dataloaders#

We define a custom CXRDataset for loading chest X-ray images and their multi-label annotations.

  • Images are converted to grayscale and then expanded to 3 channels for ResNet compatibility.

  • Transforms normalize images to ImageNet statistics (no augmentation).

  • DataLoaders are created for the train/val/test splits.

# ==== DATASET ====
class CXRDataset(Dataset):
    def __init__(self, df: pd.DataFrame, all_labels, transform=None):
        self.df = df
        self.all_labels = all_labels
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img = Image.open(row["image_path"]).convert("L")  # grayscale
        if self.transform:
            img = self.transform(img)
        # Labels as float tensor (multi-label)
        y = torch.tensor(row[self.all_labels].values.astype(np.float32))
        return img, y

# Transforms: no augmentation; just make 3-channel + normalize to ImageNet
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # ResNet expects 3 channels
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std =[0.229, 0.224, 0.225]),
])

train_ds = CXRDataset(train_df, all_labels, transform=transform)
val_ds   = CXRDataset(val_df,   all_labels, transform=transform)
test_ds   = CXRDataset(test_df,   all_labels, transform=transform)

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)
val_loader   = DataLoader(val_ds,   batch_size=64, shuffle=False, num_workers=2, pin_memory=True)
test_loader   = DataLoader(test_ds,   batch_size=64, shuffle=False, num_workers=2, pin_memory=True)
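
As a quick sanity check (a small sketch), we can pull one batch from the training loader and confirm that the tensor shapes match what the ResNet expects: 3-channel images and one 14-dimensional multi-hot label vector per image.

# Sanity check: inspect the shapes of one training batch
imgs, ys = next(iter(train_loader))
print(f"image batch: {tuple(imgs.shape)} | label batch: {tuple(ys.shape)}")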

Model Loading#

  • Loads a ResNet-18 model (a well-known image classifier).

  • Replaces its last layer so it can predict our 14 abnormality findings.

  • Moves the model to the GPU (if available) for faster training.

  • Defines the loss function (how the model measures the discrepancy between its predictions and the true labels) and the optimizer (how the model’s weights are updated to reduce that loss); a small toy sketch of the multi-label loss follows after this list.

  • Finally, shows a layer-by-layer summary of the model architecture and the number of parameters it will train.
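
Because each image can carry several findings at once, the 14 outputs are treated as independent binary decisions. A tiny sketch with made-up logits and targets (3 classes instead of 14) shows how BCEWithLogitsLoss scores such multi-label predictions:

# Toy sketch: BCEWithLogitsLoss on a multi-label example (made-up values, 3 classes instead of 14)
toy_logits  = torch.tensor([[2.0, -1.0, 0.5]])  # raw model outputs for one image
toy_targets = torch.tensor([[1.0,  0.0, 1.0]])  # multi-hot ground truth: findings 1 and 3 present
toy_loss = nn.BCEWithLogitsLoss()
print(toy_loss(toy_logits, toy_targets))  # applies a sigmoid per class, then averages the binary cross-entropies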

# ==== MODEL ====
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, len(all_labels))  # multi-label logits
model = model.to(device)

criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
warmup_steps = 200
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda step: min(1.0, float(step + 1) / warmup_steps),
)

summary(model, input_size=(3, 224, 224), device=str(device))  # pass the device so the summary also works without a GPU

Model Evaluation & Training#

This cell handles two key parts:

Evaluation

  • Tests the model on validation data.

  • Reports how well it separates positive vs. negative cases using the AUROC metric.

Training

  • Trains step by step and checks progress on the validation set every 100 steps.

  • Stops early if performance no longer improves (to avoid overfitting).

  • Saves the best version of the model.

  • Records results in a log file for later plotting.

AUROC (area under the receiver operating characteristic curve) is a metric commonly used for evaluating classification performance: 0.5 corresponds to random guessing, and 1.0 means positive cases are always ranked above negative ones.
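
As a quick, self-contained illustration (toy labels and scores, not data from this dataset), roc_auc_score can be called directly on binary labels and predicted probabilities:

# Toy AUROC illustration (made-up values): 3 of the 4 positive/negative pairs are ranked correctly
from sklearn.metrics import roc_auc_score
toy_true  = [0, 0, 1, 1]
toy_score = [0.1, 0.4, 0.35, 0.8]
print(roc_auc_score(toy_true, toy_score))  # 0.75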

# ==== METRICS ====
@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    all_logits = []
    all_targets = []
    running_val_loss = 0.0
    n_batches = 0
    for imgs, ys in loader:
        imgs = imgs.to(device, non_blocking=True)
        ys   = ys.to(device, non_blocking=True)
        logits = model(imgs)
        loss = criterion(logits, ys)
        running_val_loss += loss.item()
        n_batches += 1
        all_logits.append(torch.sigmoid(logits).detach().cpu().numpy())
        all_targets.append(ys.detach().cpu().numpy())
    val_loss = running_val_loss / max(1, n_batches)
    y_prob = np.concatenate(all_logits, axis=0)
    y_true = np.concatenate(all_targets, axis=0)

    # Per-class AUROC with safe handling for degenerate classes
    aurocs = []
    for ci in range(y_true.shape[1]):
        y_c = y_true[:, ci]
        p_c = y_prob[:, ci]
        # skip if only one class present
        if len(np.unique(y_c)) < 2:
            aurocs.append(np.nan)
            continue
        try:
            auc = roc_auc_score(y_c, p_c)
        except Exception:
            auc = np.nan
        aurocs.append(auc)
    mean_auroc = np.nanmean(aurocs)  # mean over classes that have both labels
    return val_loss, mean_auroc, aurocs

# ==== TRAIN LOOP (step-based) ====
max_steps = 5000
validate_every = 100
patience = 5  # early stopping patience on val AUROC
best_val_auroc = -math.inf
since_improve = 0

history = []  # list of dicts: step, train_loss, val_loss, val_auroc
global_step = 0

model.train()
running_train_loss = 0.0
batches_since_log = 0

train_iter = iter(train_loader)

while global_step < max_steps:
    try:
        imgs, ys = next(train_iter)
    except StopIteration:
        train_iter = iter(train_loader)
        imgs, ys = next(train_iter)

    imgs = imgs.to(device, non_blocking=True)
    ys   = ys.to(device, non_blocking=True)

    logits = model(imgs)
    loss = criterion(logits, ys)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    scheduler.step()

    running_train_loss += loss.item()
    batches_since_log += 1
    global_step += 1

    # Validation & logging
    if global_step % validate_every == 0 or global_step == max_steps:
        train_loss = running_train_loss / max(1, batches_since_log)
        running_train_loss = 0.0
        batches_since_log = 0

        val_loss, val_mean_auroc, _ = evaluate(model, val_loader)

        history.append({
            "step": global_step,
            "train_loss": float(train_loss),
            "val_loss": float(val_loss),
            "val_mean_auroc": float(val_mean_auroc),
        })
        print(f"[step {global_step:5d}] train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | val_mean_AUROC={val_mean_auroc:.4f}")

        # Early stopping on val AUROC
        if val_mean_auroc > best_val_auroc + 1e-6:
            best_val_auroc = val_mean_auroc
            since_improve = 0
            # save best
            torch.save({
                "model_state_dict": model.state_dict(),
                "all_labels": all_labels,
                "val_mean_auroc": best_val_auroc,
                "step": global_step,
            }, "best_resnet18_multilabel.pt")
        else:
            since_improve += 1
            if since_improve >= patience:
                print(f"Early stopping at step {global_step} (no improvement in {patience} validations).")
                break

# Save training log for plotting later
log_df = pd.DataFrame(history)
log_csv_path = "training_log.csv"
log_df.to_csv(log_csv_path, index=False)
print(f"Saved log to {log_csv_path}")

Training Visualization#

The figure shows how the model’s training loss, validation loss, and validation AUROC change over time.

Overfitting:

  • Training loss: keeps decreasing.

  • Validation loss: decreases at first, then starts increasing.

  • Validation AUROC (or accuracy/other metric): peaks and then declines.

  • Interpretation: The model is learning patterns specific to the training set (memorization), losing the ability to generalize.

Underfitting

  • Training loss: remains high or decreases slowly.

  • Validation loss: also high, often decreasing in parallel with training loss (not yet separated).

  • Validation AUROC/metric: remains low, may still be rising, has not plateaued.

  • Interpretation: The model is too simple (or hasn’t trained enough) to capture meaningful patterns.

Good fit

  • Validation loss: minimum reached.

  • Validation AUROC/metric: maximum reached (green dot).

  • Interpretation: The model can generalize to unseen data.

Did the model over/underfit or was the training successful?

log_df = pd.read_csv(log_csv_path)


# --- Plot Loss and AUROC in two columns ---
fig, axes = plt.subplots(1, 2, figsize=(12, 4), dpi=150)

# --- Left plot: Training & Validation Loss ---
axes[0].plot(log_df["step"], log_df["train_loss"], label="Training loss")
axes[0].plot(log_df["step"], log_df["val_loss"], label="Validation loss")
axes[0].set_xlabel("Step")
axes[0].set_ylabel("Loss")
axes[0].set_title("Training & Validation Loss")
axes[0].legend()

# --- Right plot: Validation AUROC ---
axes[1].plot(log_df["step"], log_df["val_mean_auroc"], label="Validation AUROC", color="C2")
best_idx = log_df["val_mean_auroc"].idxmax()
best_step = log_df.loc[best_idx, "step"]
best_auroc = log_df.loc[best_idx, "val_mean_auroc"]
axes[1].scatter(best_step, best_auroc, color="C2", zorder=5)
axes[1].annotate(f"Best AUROC = {best_auroc:.3f}\n(step {best_step})",
                 (best_step, best_auroc),
                 textcoords="offset points", xytext=(-20, -30),
                 ha="left", color="C2")
axes[1].set_xlabel("Step")
axes[1].set_ylabel("AUROC")
axes[1].set_title("Validation AUROC")
axes[1].legend()

plt.tight_layout()
plt.savefig("loss_and_auroc.png")
plt.show()
plt.close()

print(f"Saved: 'loss_and_auroc.png'")

Model Evaluation on Test Set#

Now, we evaluate the trained model on the independent test set. This is the only way to know how well the model performs on unseen data. We also visualize the ROC curves, which show the trade-off between sensitivity and specificity. Sensitivity reflects how well the model finds abnormal cases, while specificity reflects how well it rules out healthy cases. These two measures are always in tension: improving sensitivity usually lowers specificity, and vice versa. The red dot marks a decision threshold at which both sensitivity and specificity are reasonably high, representing a good compromise; a small sketch of how these quantities are computed follows below.
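
As a small, self-contained sketch (made-up labels and probabilities, not this dataset), sensitivity and specificity at a fixed threshold follow directly from the confusion matrix, and Youden's J = sensitivity + specificity - 1 is the quantity maximized to place the red dot:

# Toy sketch: sensitivity, specificity and Youden's J at a single threshold (made-up values)
import numpy as np

toy_true = np.array([0, 0, 0, 1, 1, 1])               # ground truth (1 = abnormal)
toy_prob = np.array([0.2, 0.4, 0.7, 0.3, 0.6, 0.9])   # predicted probabilities
toy_pred = (toy_prob >= 0.5).astype(int)               # binarize at threshold 0.5

tp = np.sum((toy_pred == 1) & (toy_true == 1))
fn = np.sum((toy_pred == 0) & (toy_true == 1))
tn = np.sum((toy_pred == 0) & (toy_true == 0))
fp = np.sum((toy_pred == 1) & (toy_true == 0))

sensitivity = tp / (tp + fn)  # true positive rate: found 2 of 3 abnormal cases
specificity = tn / (tn + fp)  # true negative rate: ruled out 2 of 3 healthy cases
print(f"sensitivity={sensitivity:.2f}, specificity={specificity:.2f}, Youden J={sensitivity + specificity - 1:.2f}")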

# ==== FINAL TEST EVAL & PLOTTING ====
import json
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt

# Load best checkpoint (if available)
ckpt_path = "best_resnet18_multilabel.pt"
ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
model.load_state_dict(ckpt["model_state_dict"])
print(f"Loaded best checkpoint from step {ckpt.get('step', 'N/A')} with val_mean_auroc={ckpt.get('val_mean_auroc', None)}")

@torch.no_grad()
def predict_probs(model, loader):
    model.eval()
    probs_list = []
    targets_list = []
    for imgs, ys in loader:
        imgs = imgs.to(device, non_blocking=True)
        ys   = ys.to(device, non_blocking=True)
        logits = model(imgs)
        probs = torch.sigmoid(logits)
        probs_list.append(probs.cpu().numpy())
        targets_list.append(ys.cpu().numpy())
    return np.concatenate(probs_list, axis=0), np.concatenate(targets_list, axis=0)

y_prob_test, y_true_test = predict_probs(model, test_loader)
print(f"Test shapes: probs={y_prob_test.shape}, targets={y_true_test.shape}")

per_class_results = []
mean_aurocs = []

# --- Plot all ROC curves in one figure (2 rows × 7 cols) ---
n_classes = len(all_labels)
fig, axes = plt.subplots(2, 7, figsize=(20, 8), dpi=150)

for ci, cls in enumerate(all_labels):
    row, col = divmod(ci, 7)
    ax = axes[row, col]

    y_c = y_true_test[:, ci]
    p_c = y_prob_test[:, ci]

    # Skip degenerate classes
    if len(np.unique(y_c)) < 2:
        ax.set_title(f"{cls}\n(skipped)")
        ax.axis("off")
        continue

    fpr, tpr, thresholds = roc_curve(y_c, p_c)
    class_auc = auc(fpr, tpr)

    # Youden's J statistic
    J = tpr - fpr
    j_idx = int(np.argmax(J))
    thr_opt = thresholds[j_idx]
    tpr_opt = tpr[j_idx]
    fpr_opt = fpr[j_idx]

    # Plot ROC and Youden point
    ax.plot(fpr, tpr, label=f"AUC = {class_auc:.3f}")
    ax.plot([0, 1], [0, 1], linestyle="--", color="gray")
    ax.scatter([fpr_opt], [tpr_opt], s=30, marker="o", color="red")  # Youden point
    ax.set_title(cls, fontsize=10)
    ax.set_xlabel("1 – Specificity")
    ax.set_ylabel("Sensitivity")
    ax.legend(loc="lower right", fontsize=8)

    per_class_results.append({
        "class": cls,
        "auroc": float(class_auc),
        "opt_threshold": float(thr_opt),
        "opt_tpr": float(tpr_opt),
        "opt_fpr": float(fpr_opt),
    })
    mean_aurocs.append(float(class_auc))

out_path = "roc_all_classes.png"
plt.tight_layout()
plt.savefig(out_path)  # save before show so the written file is not blank
plt.show()
plt.close()
print(f"Saved combined ROC plot: {out_path}")

# Mean AUROC across classes that had both labels
test_mean_auroc = float(np.nanmean(mean_aurocs)) if len(mean_aurocs) else float("nan")
print(f"TEST mean AUROC (across valid classes): {test_mean_auroc:.4f}")
summary_csv = "test_results_summary.csv"
summary_df.to_csv(summary_csv, index=False)
print(f"Saved per-class summary: {summary_csv}")

Explainability#

Now, we will apply an explainability method that should tell us why the model predicted the way it did. A simple and commonly used approach is to create heatmaps with Grad-CAM. Put simply, Grad-CAM highlights the areas of the X-ray image that contributed most to the model’s decision, showing us whether the model is focusing on clinically relevant regions.

We will visualize the heatmaps from five random pneumothorax cases. Do the heatmaps make any sense, i.e. do they highlight meaningful areas?
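
For reference, Grad-CAM (as defined in the original paper) averages the gradients of the class score with respect to the feature maps of a chosen convolutional layer into per-channel weights, and then forms a ReLU-clipped weighted sum of those maps:

$$\alpha_k^c = \frac{1}{Z}\sum_{i}\sum_{j}\frac{\partial y^c}{\partial A^k_{ij}}, \qquad L^c_{\text{Grad-CAM}} = \mathrm{ReLU}\Big(\sum_k \alpha_k^c A^k\Big)$$

Here $y^c$ is the logit of the target class (Pneumothorax below), $A^k$ is the $k$-th feature map of the chosen layer (model.layer4[-1] below), and $Z$ is the number of spatial positions in that map.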

# ==== GRAD-CAM (Captum) — 5 random Pneumothorax test images (categorical titles via ROC threshold) ====
from captum.attr import LayerGradCam, LayerAttribution
import matplotlib.pyplot as plt
import numpy as np
import math

assert "Pneumothorax" in all_labels, "Pneumothorax not found in all_labels!"
cls_idx = all_labels.index("Pneumothorax")

# Build a lookup for optimal thresholds computed above (Youden's J)
thr_lookup = {r["class"]: r["opt_threshold"] for r in per_class_results if not np.isnan(r["opt_threshold"])}
thr_pneu = thr_lookup["Pneumothorax"]
# Select only true Pneumothorax cases from the test set
pneu_cases = test_df[test_df["Pneumothorax"] == 1]
n_show = min(5, len(pneu_cases))
samples = pneu_cases.sample(n_show, random_state=42).reset_index(drop=True)

# Grad-CAM setup on the last ResNet conv block
target_layer = model.layer4[-1]
gradcam = LayerGradCam(model, target_layer)
model.eval()

fig, axes = plt.subplots(1, n_show, figsize=(3.6*n_show, 4), dpi=150)

# Handle case when n_show == 1 (axes not iterable)
if n_show == 1:
    axes = [axes]

for i, row in samples.iterrows():
    img_path = row["image_path"]
    pil_img = Image.open(img_path).convert("L")
    rgb_for_overlay = pil_img.convert("RGB")
    x = transform(pil_img).unsqueeze(0).to(device)

    with torch.enable_grad():
        logits = model(x)
        prob = torch.sigmoid(logits)[0, cls_idx].item()
        attrs = gradcam.attribute(inputs=x, target=cls_idx)  # [1, C, h, w]
        cam = LayerAttribution.interpolate(attrs, x.shape[2:]).detach().cpu().numpy()
        cam = cam.mean(axis=1)[0]  # [H, W]

    # Normalize to [0,1]
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)

    # Binarize prediction using ROC-derived threshold
    pred_label = int(prob >= thr_pneu)
    true_label = int(row["Pneumothorax"])
    pred_text = "Positive" if pred_label == 1 else "Negative"
    label_text = "Positive" if true_label == 1 else "Negative"

    ax = axes[i]
    ax.imshow(rgb_for_overlay)
    im = ax.imshow(cam, cmap="jet", alpha=0.15)
    ax.axis("off")
    ax.set_title(f"Pneumothorax:\nModel: {pred_text}\nGround Truth: {label_text}", fontsize=9)

# Add a single colorbar for the heatmaps
cbar_ax = fig.add_axes([0.92, 0.15, 0.015, 0.7])  # [left, bottom, width, height]
cbar = fig.colorbar(im, cax=cbar_ax)
cbar.set_label("Grad-CAM importance\nLow → High", fontsize=9)

plt.tight_layout(rect=[0, 0, 0.9, 1])
out_path = "gradcam_pneumothorax_5_examples.png"
plt.savefig(out_path, bbox_inches="tight", pad_inches=0)  # save before show so the written file is not blank
plt.show()
plt.close()
print(f"Saved 5-panel Grad-CAM figure with categorical predictions: {out_path}")

Optional: Training a Larger Model#

If you still have time, define a larger model, e.g. a ResNet-50, and train it again. You can reuse the code from above. Observe whether the performance of the larger model changes. Do you know why?

model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
# Now train the model again and evaluate it afterwards. You can copy the code from above.
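
Before reusing the training loop, the classification head has to be replaced (14 outputs instead of 1000) and the optimizer and scheduler re-created so they point at the new model’s parameters. A minimal sketch, reusing the same hyperparameters as above (not a tuned choice for ResNet-50):

# Minimal sketch: adapt the ResNet-50 head and rebuild optimizer/scheduler before re-running the training loop
model.fc = nn.Linear(model.fc.in_features, len(all_labels))  # 14 abnormality logits
model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda step: min(1.0, float(step + 1) / warmup_steps),
)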