
7. 🧠 MISAW CNN-Temporal classification


Filled notebook: View filled on GitHub | Open filled in Colab
Author: Benjamin I. Fortuno

This notebook extends the previous CNN-based approach with temporal modeling to classify surgical workflow labels (here, the per-frame steps) in videos from the MISAW dataset, a collection of videos annotated for surgical workflow recognition. It covers:

  • Creating a PyTorch Dataset

  • Exploring combinations of a CNN backbone with temporal models (TCN and LSTM)

  • Visualizing training and validation results

🧰 Importing Required Libraries

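This section installs the extra packages (PyTorch Lightning, torchmetrics, and pytorch-tcn) and imports the libraries used throughout: os, glob, json, and shutil for file handling; cv2, NumPy, and pandas for images and data manipulation; and PyTorch, torchvision, and PyTorch Lightning for modeling.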

[1]:
# 📦 Install dependencies
!pip install pytorch-lightning -q
!pip install torchmetrics -q
!pip install pytorch-tcn -q

[2]:
# Standard library
import os           # For file and directory operations
import glob         # For finding files using wildcard patterns
import json         # For saving annotations in JSON format
import shutil       # For copying images

# Third-party libraries
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# PyTorch
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split

# Torchvision
from torchvision import datasets, transforms

# PyTorch Lightning
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

# Torchmetrics
import torchmetrics

📥 Downloading the Dataset

Here we download the MISAW dataset in .zip format from a Dropbox link. This dataset contains videos and annotations for surgical workflow analysis.

[3]:
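# Quote the URL: an unquoted "&" makes the shell cut the URL short and run wget in the background
!wget -O MISAW.zip "https://www.dropbox.com/scl/fi/psmlokrc5ms958ggqyv3u/MISAW.zip?rlkey=v91dz437npon5zz10olrbcqcd&st=54qvf31m&dl=0"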
--2025-06-03 13:04:18--  https://www.dropbox.com/scl/fi/psmlokrc5ms958ggqyv3u/MISAW.zip?rlkey=v91dz437npon5zz10olrbcqcd
Resolving www.dropbox.com (www.dropbox.com)... 162.125.1.18, 2620:100:6016:18::a27d:112
Connecting to www.dropbox.com (www.dropbox.com)|162.125.1.18|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://uc9947b1efa5bc8c7726c1d194d6.dl.dropboxusercontent.com/cd/0/inline/Cq7JOyLE9fpIYpamVYde6Q-VNjrNGMt_EeOcqzCuMzm9D1nUQ8N81p2meJXY-Y2LWLorNUUEQ0u44ki-V1MspFIjOLRVKXBC5fg1cCAoaiilT43d9d1B1yQFXTNAkR6O2f20vqrZxfG1K1zFAs8Pj6Ce/file# [following]
--2025-06-03 13:04:19--  https://uc9947b1efa5bc8c7726c1d194d6.dl.dropboxusercontent.com/cd/0/inline/Cq7JOyLE9fpIYpamVYde6Q-VNjrNGMt_EeOcqzCuMzm9D1nUQ8N81p2meJXY-Y2LWLorNUUEQ0u44ki-V1MspFIjOLRVKXBC5fg1cCAoaiilT43d9d1B1yQFXTNAkR6O2f20vqrZxfG1K1zFAs8Pj6Ce/file
Resolving uc9947b1efa5bc8c7726c1d194d6.dl.dropboxusercontent.com (uc9947b1efa5bc8c7726c1d194d6.dl.dropboxusercontent.com)... 162.125.1.15, 2620:100:6016:15::a27d:10f
Connecting to uc9947b1efa5bc8c7726c1d194d6.dl.dropboxusercontent.com (uc9947b1efa5bc8c7726c1d194d6.dl.dropboxusercontent.com)|162.125.1.15|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: /cd/0/inline2/Cq4N1ejIi9DSsvK0Pk8YN4NoMw9mWfkEyrqIqQHlXvH6juKwgI_5ofLj-sfvGEV5jKXK7kas-17IbEQWHxnPk4SEyMODvNMzSov4ikzJSNzBTKSSem09xDEq_hjf8qlVmb6OrED4dpTLZzPkTynt3R0sNqEjXVHd6s25XH35E8mEqweZ1SWWfF8YXl896zpFWFolkuZb1FSLZIXrhX4yexdh0QeJLIgON0EgfyUDUYQ_vRSowymlMtdsMuWik8ArkSTlsiTfKHO2oL0kv-fThrOOm1q124Ni8q46mU_gm_JOX70rEHYoA0TA4bpMq7NsLZWo7h72ihNqSkcB8MPl0LBQfMK6cT3VRB6zbUv9nELarlv8roo0dPjvpyIl_wNgl_E/file [following]
--2025-06-03 13:04:19--  https://uc9947b1efa5bc8c7726c1d194d6.dl.dropboxusercontent.com/cd/0/inline2/Cq4N1ejIi9DSsvK0Pk8YN4NoMw9mWfkEyrqIqQHlXvH6juKwgI_5ofLj-sfvGEV5jKXK7kas-17IbEQWHxnPk4SEyMODvNMzSov4ikzJSNzBTKSSem09xDEq_hjf8qlVmb6OrED4dpTLZzPkTynt3R0sNqEjXVHd6s25XH35E8mEqweZ1SWWfF8YXl896zpFWFolkuZb1FSLZIXrhX4yexdh0QeJLIgON0EgfyUDUYQ_vRSowymlMtdsMuWik8ArkSTlsiTfKHO2oL0kv-fThrOOm1q124Ni8q46mU_gm_JOX70rEHYoA0TA4bpMq7NsLZWo7h72ihNqSkcB8MPl0LBQfMK6cT3VRB6zbUv9nELarlv8roo0dPjvpyIl_wNgl_E/file
Reusing existing connection to uc9947b1efa5bc8c7726c1d194d6.dl.dropboxusercontent.com:443.
HTTP request sent, awaiting response... 200 OK
Length: 965450623 (921M) [application/zip]
Saving to: ‘MISAW.zip’

MISAW.zip           100%[===================>] 920.72M  79.8MB/s    in 13s

2025-06-03 13:04:33 (70.2 MB/s) - ‘MISAW.zip’ saved [965450623/965450623]

📦 Extracting the Dataset

This cell unzips the downloaded file to access the raw data.

[4]:
!unzip -qq MISAW.zip

🎞️ Frame Extraction from Videos

This section installs ffmpeg, checks that a GPU is available, and then extracts every N-th frame of each video (N is set by resample_rate). Each video gets its own subdirectory of frames, organized as:

MISAW/
  └── train/
        └── Frames/
              └── <video_id>/frame_0001.jpg

These images will later be paired with annotations. Note that ffmpeg numbers output images starting from 1.

[5]:
!sudo apt update
!sudo apt install -y ffmpeg

Get:1 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease [3,632 B]
Get:2 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease [1,581 B]
Hit:3 http://archive.ubuntu.com/ubuntu jammy InRelease
Get:4 http://archive.ubuntu.com/ubuntu jammy-updates InRelease [128 kB]
Get:5 http://security.ubuntu.com/ubuntu jammy-security InRelease [129 kB]
Get:6 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  Packages [1,724 kB]
Get:7 https://r2u.stat.illinois.edu/ubuntu jammy InRelease [6,555 B]
Get:8 http://archive.ubuntu.com/ubuntu jammy-backports InRelease [127 kB]
Get:9 https://r2u.stat.illinois.edu/ubuntu jammy/main amd64 Packages [2,735 kB]
Get:10 http://archive.ubuntu.com/ubuntu jammy-updates/universe amd64 Packages [1,553 kB]
Hit:11 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease
Hit:12 https://ppa.launchpadcontent.net/graphics-drivers/ppa/ubuntu jammy InRelease
Get:13 http://archive.ubuntu.com/ubuntu jammy-updates/restricted amd64 Packages [4,622 kB]
Get:14 https://r2u.stat.illinois.edu/ubuntu jammy/main all Packages [8,992 kB]
Hit:15 https://ppa.launchpadcontent.net/ubuntugis/ppa/ubuntu jammy InRelease
Get:16 http://archive.ubuntu.com/ubuntu jammy-updates/main amd64 Packages [3,290 kB]
Get:17 http://security.ubuntu.com/ubuntu jammy-security/restricted amd64 Packages [4,468 kB]
Get:18 http://security.ubuntu.com/ubuntu jammy-security/main amd64 Packages [2,979 kB]
Get:19 http://security.ubuntu.com/ubuntu jammy-security/universe amd64 Packages [1,245 kB]
Fetched 32.0 MB in 3s (11.6 MB/s)
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
38 packages can be upgraded. Run 'apt list --upgradable' to see them.
W: Skipping acquire of configured file 'main/source/Sources' as repository 'https://r2u.stat.illinois.edu/ubuntu jammy InRelease' does not seem to provide it (sources.list entry misspelt?)
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
ffmpeg is already the newest version (7:4.4.2-0ubuntu0.22.04.1).
0 upgraded, 0 newly installed, 0 to remove and 38 not upgraded.
[6]:
!nvidia-smi
!ffmpeg -encoders | grep nvenc

Tue Jun  3 13:04:52 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   36C    P8              9W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
ffmpeg version 4.4.2-0ubuntu0.22.04.1 Copyright (c) 2000-2021 the FFmpeg developers
  built with gcc 11 (Ubuntu 11.2.0-19ubuntu1)
  configuration: --prefix=/usr --extra-version=0ubuntu0.22.04.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libdav1d --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librabbitmq --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libsrt --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzimg --enable-libzmq --enable-libzvbi --enable-lv2 --enable-omx --enable-openal --enable-opencl --enable-opengl --enable-sdl2 --enable-pocketsphinx --enable-librsvg --enable-libmfx --enable-libdc1394 --enable-libdrm --enable-libiec61883 --enable-chromaprint --enable-frei0r --enable-libx264 --enable-shared
  libavutil      56. 70.100 / 56. 70.100
  libavcodec     58.134.100 / 58.134.100
  libavformat    58. 76.100 / 58. 76.100
  libavdevice    58. 13.100 / 58. 13.100
  libavfilter     7.110.100 /  7.110.100
  libswscale      5.  9.100 /  5.  9.100
  libswresample   3.  9.100 /  3.  9.100
  libpostproc    55.  9.100 / 55.  9.100
 V....D h264_nvenc           NVIDIA NVENC H.264 encoder (codec h264)
 V..... nvenc                NVIDIA NVENC H.264 encoder (codec h264)
 V..... nvenc_h264           NVIDIA NVENC H.264 encoder (codec h264)
 V..... nvenc_hevc           NVIDIA NVENC hevc encoder (codec hevc)
 V....D hevc_nvenc           NVIDIA NVENC hevc encoder (codec hevc)
[ ]:
import os

video_folder = 'MISAW/train/Video'
frames_folder = 'MISAW/train/Frames'
os.makedirs(frames_folder, exist_ok=True)

resample_rate = 3  # Adjust as needed

for video_file in os.listdir(video_folder):
    if video_file.endswith(('.mp4', '.avi', '.mov')):
        video_path = os.path.join(video_folder, video_file)
        video_name = os.path.splitext(video_file)[0]
        output_folder = os.path.join(frames_folder, video_name)
        os.makedirs(output_folder, exist_ok=True)

        # Use ffmpeg with CUDA-accelerated decoding and keep every resample_rate-th frame
        output_pattern = os.path.join(output_folder, "frame_%04d.jpg")
        command = f"ffmpeg -hwaccel cuda -i '{video_path}' -vf 'select=not(mod(n\\,{resample_rate}))' -vsync vfr '{output_pattern}'"

        print(f"Extracting frames from {video_file}...")
        os.system(command)
        print(f"Frames saved to {output_folder}")

print("Done!")

Extracting frames from 2_5.mp4...
Frames saved to MISAW/train/Frames/2_5
Extracting frames from 6_2.mp4...
Frames saved to MISAW/train/Frames/6_2
Extracting frames from 3_2.mp4...
Frames saved to MISAW/train/Frames/3_2
Extracting frames from 2_1.mp4...
Frames saved to MISAW/train/Frames/2_1
Extracting frames from 3_3.mp4...
Frames saved to MISAW/train/Frames/3_3
Extracting frames from 6_4.mp4...
Frames saved to MISAW/train/Frames/6_4
Extracting frames from 1_4.mp4...
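The ffmpeg call above keeps the frames whose index n satisfies n mod resample_rate == 0, so a video with N frames yields ceil(N / resample_rate) images, which is exactly the length of its per-frame label list sliced with [::resample_rate]. The sketch below, using hypothetical helper names, builds the same command as an argument list for subprocess.run (avoiding shell quoting and, with check=True, surfacing ffmpeg failures that os.system silently ignores) and checks that arithmetic. Running run_extract assumes ffmpeg with CUDA support is on the PATH.

```python
import math
import subprocess

def build_extract_cmd(video_path, output_pattern, resample_rate):
    """Hypothetical helper: the ffmpeg options used above, as an argv list."""
    return [
        "ffmpeg", "-hwaccel", "cuda",
        "-i", video_path,
        # "\\," escapes the comma so ffmpeg does not read it as a filter separator
        "-vf", f"select=not(mod(n\\,{resample_rate}))",
        "-vsync", "vfr",
        output_pattern,
    ]

def expected_frame_count(n_video_frames, resample_rate):
    """Frames kept by select=not(mod(n,r)) for frame indices 0..N-1."""
    return math.ceil(n_video_frames / resample_rate)

def run_extract(video_path, output_pattern, resample_rate):
    # check=True raises CalledProcessError if ffmpeg exits with an error
    subprocess.run(build_extract_cmd(video_path, output_pattern, resample_rate), check=True)

# Toy check: a 10-frame video sampled every 3rd frame keeps frames 0, 3, 6, 9
print(expected_frame_count(10, 3))   # 4
print(len(list(range(10))[::3]))     # 4
```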

🧾 Parsing Annotations and Building Frame-Label Pairs

Here we parse the tab-separated .txt annotation files (one per video, one row per original video frame) to extract the surgical step labels (the Step column), then pair each extracted frame with its label. We also map the textual labels to integer IDs, since the loss function expects numeric class indices.

[ ]:
# Folder where the annotation .txt files are stored (one per video)
annotation_folder = 'MISAW/train/Procedural decription'

# Folder where extracted frames from videos are stored (in subfolders per video)
frames_folder = 'MISAW/train/Frames'

# This must be the same rate used to extract frames from the videos (every 3rd frame above).
# If it does not match, the labels won't align with the frames!
resample_rate = 3

# --- Collect all unique phases first ---

# We'll store all unique phase labels (like "Idle", "Suturing", etc.) in this set
all_phases = set()

# Find all annotation files in the folder (e.g., 1_1_annotation.txt, 1_2_annotation.txt, etc.)
annotation_files = sorted(glob.glob(os.path.join(annotation_folder, '*_annotation.txt')))

# Loop through each annotation file and collect unique phase names
for anno_file in annotation_files:
    df = pd.read_csv(anno_file, sep='\t')  # Load the .txt file into a DataFrame
    all_phases.update(df['Step'].unique())  # Add unique phases to the set

# Create a dictionary to map each phase string to a unique integer ID
# Useful for training machine learning models
phase_to_id = {name: i for i, name in enumerate(sorted(all_phases))}

# Create the inverse mapping (ID to phase name) for visualization later
id_to_phase = {i: name for name, i in phase_to_id.items()}


# --- Build (frame_path, label) pairs ---

# This list will hold tuples like: ("path/to/frame.jpg", 2) → (frame, label_id)
all_data = []

# Go through each annotation file (one per video)
for anno_file in annotation_files:
    # Extract the video ID from the filename (e.g., "1_1_annotation.txt" → "1_1")
    video_id = os.path.basename(anno_file).replace('_annotation.txt', '')

    # Construct the path to the corresponding frame folder
    frame_dir = os.path.join(frames_folder, video_id)

    # Read the annotation file into a DataFrame
    df = pd.read_csv(anno_file, sep='\t')

    # Get the full list of phases (one per original video frame)
    phases = df['Step'].tolist()

    # Resample: keep only every N-th label (N = resample_rate) so labels match the extracted frames
    sampled_phases = phases[::resample_rate]

    # Convert phase strings to numeric labels using our earlier mapping
    sampled_ids = [phase_to_id[p] for p in sampled_phases]

    # Get the list of frame image paths, sorted so they match the order of labels
    frame_paths = sorted(glob.glob(os.path.join(frame_dir, '*.jpg')))

    # Sanity check: if the number of frames doesn’t match the number of labels, skip this video
    if len(frame_paths) != len(sampled_ids):
        print(f"⚠️ Mismatch for {video_id}: {len(frame_paths)} frames vs {len(sampled_ids)} labels")
        continue

    # Add all (frame_path, label_id) pairs to our global list
    all_data.extend(zip(frame_paths, sampled_ids))


# Final print to confirm total number of samples loaded
print(f"✅ Total samples: {len(all_data)}")
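The label side of this pairing can be tried on a small in-memory example. The sketch below uses a hypothetical annotation table (only the Step column matters here; the Frame and Phase columns and their values are made up for illustration) and a hypothetical helper, resample_and_encode, that mirrors the cell above: it reads the tab-separated text, keeps every resample_rate-th row, and maps step names to integer IDs in sorted order. Unlike the cell above, it builds the mapping from a single file rather than from all annotation files.

```python
import io
import pandas as pd

# Hypothetical annotation text: one tab-separated row per video frame
steps = ["Idle"] * 4 + ["1 knot"] * 5 + ["Idle"] * 2
toy_txt = "Frame\tPhase\tStep\n" + "\n".join(
    f"{i}\tSuturing\t{step}" for i, step in enumerate(steps)
)

def resample_and_encode(txt, resample_rate):
    df = pd.read_csv(io.StringIO(txt), sep="\t")
    # Sorted names give a stable name -> ID mapping
    phase_to_id = {name: i for i, name in enumerate(sorted(df["Step"].unique()))}
    sampled = df["Step"].tolist()[::resample_rate]   # rows 0, r, 2r, ...
    return [phase_to_id[p] for p in sampled], phase_to_id

ids, mapping = resample_and_encode(toy_txt, resample_rate=3)
print(mapping)  # {'1 knot': 0, 'Idle': 1}
print(ids)      # [1, 1, 0, 1]
```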

Visualize images and phases

[ ]:
import random
import matplotlib.pyplot as plt
from PIL import Image
from collections import defaultdict

# CONFIGURATION

# Number of sample images to display per phase (i.e., class/label)
samples_per_phase = 3

# Create a dictionary to group image paths by their label (phase ID)
# defaultdict(list) automatically creates an empty list for new keys
label_to_paths = defaultdict(list)

# all_data is assumed to be a list of (frame_path, label_id) pairs
# Here we organize all frame paths under their respective label IDs
for path, label_id in all_data:
    label_to_paths[label_id].append(path)

# Determine the number of unique phases (rows in the final plot)
n_rows = len(label_to_paths)

# Number of columns equals the number of samples we want to show per phase
n_cols = samples_per_phase

# Create a grid of subplots (n_rows x n_cols)
# figsize sets the overall size of the figure (in inches)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows))

# Add a main title above all the subplots
fig.suptitle("Random Samples per Phase", fontsize=18)

# Get all label IDs in sorted order (for consistent row display)
sorted_labels = sorted(label_to_paths.keys())

# Loop over each row (i.e., each unique label/phase)
for row, label_id in enumerate(sorted_labels):
    # Convert label ID back to its readable name using id_to_phase mapping
    label_name = id_to_phase[label_id]

    # Randomly sample a few images from this label
    # Ensure we don't try to sample more images than we actually have
    samples = random.sample(label_to_paths[label_id], min(samples_per_phase, len(label_to_paths[label_id])))

    # For each column (i.e., each sampled image for this label)
    for col in range(samples_per_phase):
        # Access the correct subplot (row x col)
        # If there's only one row, axes may be 1D
        ax = axes[row, col] if n_rows > 1 else axes[col]

        # Only plot if we have enough samples for this column
        if col < len(samples):
            # Open the image file using PIL
            img = Image.open(samples[col])

            # Display the image in the subplot
            ax.imshow(img)

            # Set the subplot title to the phase name
            ax.set_title(f"{label_name}", fontsize=10)

        # Remove axis ticks and labels for a cleaner look
        ax.axis('off')

# Adjust layout to prevent overlaps
plt.tight_layout()

# Adjust top spacing to make room for the main title
plt.subplots_adjust(top=0.95)

# Display the final grid of images
plt.show()

📁 Converting to COCO-style Format

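This section merges related labels ("1 knot", "2 knot", and "3 knot" become "knot"; "Idle" and "Idle Step" become "idle"), copies every extracted frame into one flat folder, and writes a COCO-style JSON index (a single annotation file for the whole dataset). The resulting layout is:

MISAW_coco/
  ├── Frames/
  │     └── 1_1_frame_0000.jpg
  ├── annotations.json
  └── phase_to_id.json

Keeping every frame in one folder with a single JSON index makes it easy to load the data into a PyTorch Dataset and to split it by video later on.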

[ ]:
# --- CONFIGURATION ---

# Folder where annotation .txt files are stored
annotation_folder = 'MISAW/train/Procedural decription'

# Folder where frames are stored (e.g., MISAW/train/Frames/1_1/frame_0000.jpg)
frames_folder = 'MISAW/train/Frames'

# Output root folder for the new COCO-style dataset
output_root = 'MISAW_coco'

# Knot-related steps to merge into a single "knot" label, and idle steps to merge into "idle"
knot_phases = {'1 knot', '2 knot', '3 knot'}
idle_phases = {'Idle', 'Idle Step'}

# Inside this, all frames will be copied to one flat "Frames" folder
frames_out_root = os.path.join(output_root, 'Frames')
os.makedirs(frames_out_root, exist_ok=True)  # Create the folder if it doesn't exist

# --- STEP 1: COLLECT UNIQUE PHASES FROM ALL ANNOTATIONS ---

# Set to store all phase names (e.g., Idle, Suturing, etc.)
all_phases = set()

# Loop through each annotation file and collect phase names
for anno_file in glob.glob(f'{annotation_folder}/*_annotation.txt'):
    df = pd.read_csv(anno_file, sep='\t')  # Load tab-separated .txt file as a DataFrame
    df['Step'] = df['Step'].replace({p: "knot" for p in knot_phases})
    df['Step'] = df['Step'].replace({p: "idle" for p in idle_phases})
    all_phases.update(df['Step'].unique())  # Add unique phases to the set

# Create dictionaries to map phases to integer IDs and back
phase_to_id = {p: i for i, p in enumerate(sorted(all_phases))}  # e.g., "Suturing" → 2
id_to_phase = {i: p for p, i in phase_to_id.items()}            # e.g., 2 → "Suturing"

# --- STEP 2: BUILD JSON ENTRIES AND COPY FRAMES ---

entries = []  # This will hold all annotation entries

# Process each video annotation
for anno_file in glob.glob(f'{annotation_folder}/*_annotation.txt'):
    video_id = os.path.basename(anno_file).replace('_annotation.txt', '')  # e.g., "1_1"

    # Read annotation file
    df = pd.read_csv(anno_file, sep='\t')
    df['Step'] = df['Step'].replace({p: "knot" for p in knot_phases})
    df['Step'] = df['Step'].replace({p: "idle" for p in idle_phases})

    # Resample phase labels to match the saved frames (every resample_rate-th label)
    phases = df['Step'].tolist()[::resample_rate]

    # Get list of resampled frame image paths
    frame_dir = os.path.join(frames_folder, video_id)
    frame_paths = sorted(glob.glob(os.path.join(frame_dir, '*.jpg')))

    # Ensure the number of frames and labels match
    if len(phases) != len(frame_paths):
        print(f"⚠️ Skipping {video_id}: mismatched {len(phases)} labels vs {len(frame_paths)} frames")
        continue

    # Loop over each frame-label pair
    for i, (frame_path, phase) in enumerate(zip(frame_paths, phases)):
        # Create a new frame name (e.g., "1_1_frame_0003.jpg")
        new_frame_name = f"{video_id}_frame_{i:04d}.jpg"

        # Full destination path for the copied frame
        new_frame_path = os.path.join(frames_out_root, new_frame_name)

        # Copy frame to the output folder
        shutil.copy(frame_path, new_frame_path)

        # Add a dictionary entry for this frame in COCO-style format
        entries.append({
            "video": video_id,
            "frame": new_frame_name,
            "path": new_frame_path,
            "label": phase,
            "label_id": phase_to_id[phase]
        })

# --- STEP 3: SAVE TO DISK ---

# Save all frame annotations to a JSON file
with open(os.path.join(output_root, 'annotations.json'), 'w') as f:
    json.dump(entries, f, indent=2)

# Save the phase-to-ID mapping for future use
with open(os.path.join(output_root, 'phase_to_id.json'), 'w') as f:
    json.dump(phase_to_id, f, indent=2)

# Final confirmation
print(f"✅ COCO-style structure created in {output_root}/ with {len(entries)} annotated frames.")
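The two consecutive replace calls that merge labels can also be written as a single mapping. The sketch below applies one merge dictionary to a toy pandas Series (hypothetical values taken from the label names in the cell above) and counts frames per merged label with collections.Counter, a quick way to check class balance before training.

```python
from collections import Counter
import pandas as pd

knot_phases = {'1 knot', '2 knot', '3 knot'}
idle_phases = {'Idle', 'Idle Step'}

# One dictionary covering both merges
merge_map = {**{p: "knot" for p in knot_phases}, **{p: "idle" for p in idle_phases}}

steps = pd.Series(["Idle", "1 knot", "2 knot", "Idle Step", "3 knot", "Idle"])
merged = steps.replace(merge_map)   # exact-value replacement

phase_to_id = {p: i for i, p in enumerate(sorted(merged.unique()))}
print(phase_to_id)                  # {'idle': 0, 'knot': 1}
print(Counter(merged.tolist()))     # frames per merged label
```

On the real data, the same Counter applied to the "label" field of each entry in annotations.json gives the per-class frame counts.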

Dataset and DataLoader

[ ]:
# from torch.utils.data import Dataset
# from PIL import Image
# import torch

# class MISAWDataset(Dataset):
#     def __init__(self, annotations, sequence_length=15, transform=None):
#         self.annotations = annotations
#         self.sequence_length = sequence_length
#         self.transform = transform

#         # Ensure only complete sequences
#         self.num_sequences = len(self.annotations) // self.sequence_length

#     def __len__(self):
#         return self.num_sequences

#     def __getitem__(self, idx):
#         # Calculate the start and end index for this sequence
#         sequence_start = idx * self.sequence_length
#         sequence_end = sequence_start + self.sequence_length

#         # Collect the sequence of images
#         images = []
#         for i in range(sequence_start, sequence_end):
#             item = self.annotations[i]
#             image = Image.open(item['path']).convert('L')  # Convert to grayscale
#             if self.transform:
#                 image = self.transform(image)
#             images.append(image)

#         # Stack the images along the sequence dimension
#         images = torch.stack(images)  # Shape: (sequence_length, C, H, W)

#         # Use the label of the last frame in the sequence as the target
#         label = self.annotations[sequence_end - 1]['label_id']

#         return images, label

from torch.utils.data import Dataset
from PIL import Image
import torch

class MISAWDataset(Dataset):
    def __init__(self, annotations, sequence_length=8, stride=1, transform=None):
        self.annotations = annotations
        self.sequence_length = sequence_length
        self.stride = stride
        self.transform = transform

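        # Start indices of all (possibly overlapping) windows. A window is kept only if
        # its first and last frames come from the same video; this assumes each video's
        # frames are contiguous in `annotations`, so no window spans two videos.
        self.valid_indices = [
            start
            for start in range(0, len(self.annotations) - self.sequence_length + 1, self.stride)
            if self.annotations[start]['video']
            == self.annotations[start + self.sequence_length - 1]['video']
        ]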

    def __len__(self):
        # Return the number of possible overlapping sequences
        return len(self.valid_indices)

    def __getitem__(self, idx):
        # Calculate the start index for this sequence
        start_idx = self.valid_indices[idx]
        end_idx = start_idx + self.sequence_length

        # Collect the sequence of images
        images = []
        for i in range(start_idx, end_idx):
            item = self.annotations[i]
            image = Image.open(item['path']).convert('L')  # Convert to grayscale
            if self.transform:
                image = self.transform(image)
            images.append(image)

        # Stack the images along the sequence dimension
        images = torch.stack(images)  # Shape: (sequence_length, C, H, W)

        # Use the label of the last frame in the sequence as the target
        label = self.annotations[end_idx - 1]['label_id']

        return images, label
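The sliding-window indexing is easier to reason about in isolation. The sketch below is a hypothetical standalone function, window_starts, that lists the start index of every window of sequence_length frames (with a given stride) over a flat list of per-frame video IDs and keeps only windows whose first and last frames belong to the same video. Like this notebook's annotation lists, it assumes each video's frames are stored contiguously. Each kept window is then labeled with its last frame's label.

```python
def window_starts(video_ids, sequence_length, stride=1):
    """Start indices of windows that stay inside a single video.

    Assumes the frames of each video are contiguous in `video_ids`, so comparing
    the first and last frame of a window is enough to detect a video boundary.
    """
    last_start = len(video_ids) - sequence_length
    return [
        start
        for start in range(0, last_start + 1, stride)
        if video_ids[start] == video_ids[start + sequence_length - 1]
    ]

# Two toy videos: "a" has 5 frames, "b" has 3 frames
vids = ["a"] * 5 + ["b"] * 3
print(window_starts(vids, sequence_length=3))            # [0, 1, 2, 5]
print(window_starts(vids, sequence_length=3, stride=2))  # [0, 2]
```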
[ ]:
import json
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from collections import defaultdict

# Load annotations from JSON
with open('MISAW_coco/annotations.json') as f:
    annotations = json.load(f)

# Create a list of video groups
video_groups = defaultdict(list)
for image_info in annotations:
    video_name = image_info["video"]
    video_groups[video_name].append(image_info)

# Split videos deterministically (80% train, 20% val)
all_videos = sorted(video_groups.keys())
split_idx = int(len(all_videos) * 0.8)
train_videos = all_videos[:split_idx]
val_videos = all_videos[split_idx:]

# Collect dataset annotations
train_annotations = [item for vid in train_videos for item in video_groups[vid]]
val_annotations = [item for vid in val_videos for item in video_groups[vid]]

print(f"Train videos: {len(train_videos)}, Val videos: {len(val_videos)}")
print(f"Train frames: {len(train_annotations)}, Val frames: {len(val_annotations)}")
[ ]:
from torchvision import transforms

# Image transformations for training (stronger augmentation)
train_transform = transforms.Compose([
    transforms.Grayscale(),                         # Ensures image is 1-channel
    transforms.RandomResizedCrop((128, 128), scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Image transformations for validation (no augmentation: resize and normalize only)
val_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Create the train and validation datasets
SEQUENCE_LENGTH = 15
train_dataset = MISAWDataset(train_annotations, sequence_length=SEQUENCE_LENGTH, transform=train_transform)
val_dataset = MISAWDataset(val_annotations, sequence_length=SEQUENCE_LENGTH, transform=val_transform)

print(f"Train sequences: {len(train_dataset)}, Val sequences: {len(val_dataset)}")

[ ]:
from torch.utils.data import DataLoader

BATCH_SIZE = 16
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# Quick sanity check
for images, labels in train_loader:
    print(f"Images shape: {images.shape}, Labels shape: {labels.shape}")
    break
[ ]:
# import matplotlib.pyplot as plt
# import torch

# # 🔍 Visualize one example sequence per phase from the training dataset
# classes = list(range(len(phase_to_id)))
# samples_per_class = {c: None for c in classes}

# # Extract one full sequence for each class from the training set
# for images, label in train_dataset:
#     # Convert the label to an integer if it is a tensor
#     if isinstance(label, torch.Tensor):
#         label = label.item()

#     # Use the label as a scalar, not a tensor
#     if samples_per_class[label] is None:
#         samples_per_class[label] = images  # Shape: (sequence_length, 1, H, W)
#     if all(v is not None for v in samples_per_class.values()):
#         break

# # Plot the sequences
# fig, axes = plt.subplots(len(classes), SEQUENCE_LENGTH, figsize=(20, 2 * len(classes)))  # one row per class, sequence_length frames each
# for class_id in classes:
#     sequence = samples_per_class[class_id]  # Shape: (sequence_length, 1, H, W)
#     for frame_idx in range(sequence.shape[0]):
#         ax = axes[class_id, frame_idx]
#         # Remove the channel dimension for grayscale images
#         ax.imshow(sequence[frame_idx].squeeze(0).cpu().numpy(), cmap="gray")
#         ax.axis("off")
#         if frame_idx == 0:
#             ax.set_title(f"Class {class_id}")

# plt.show()

ResNet18 + TCN

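[ ]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchmetrics
from pytorch_lightning import LightningModule
from pytorch_tcn import TCN
from torchvision.models import resnet18, ResNet18_Weights

class LitResNet18TCN(LightningModule):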
    def __init__(self, num_classes=5, sequence_length=15, tcn_channels=[256, 256, 256], kernel_size=4, dropout=0.5, lr=1e-5, pooling="mean", freeze_layers=True):
        super().__init__()
        self.save_hyperparameters()

        # Load pretrained ResNet-18
        weights = ResNet18_Weights.DEFAULT
        self.model = resnet18(weights=weights)

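        # Replace the input layer to accept 1-channel (grayscale) input. Initialize it by
        # summing the pretrained RGB filters (equivalent to feeding the grayscale image
        # on all three channels), so it is not left random and then frozen below.
        rgb_conv1 = self.model.conv1
        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        with torch.no_grad():
            self.model.conv1.weight.copy_(rgb_conv1.weight.sum(dim=1, keepdim=True))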

        # Optionally freeze layers
        if freeze_layers:
            for name, param in self.model.named_parameters():
                if not name.startswith("layer4") and not name.startswith("fc"):
                    param.requires_grad = False

        # Extract feature dimension before FC
        self.feature_dim = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(self.feature_dim, self.feature_dim)
        )

        # TCN module
        self.tcn = TCN(
            num_inputs=self.feature_dim,
            num_channels=tcn_channels,
            kernel_size=kernel_size,
            dropout=dropout,
            input_shape='NCL',
            output_projection=num_classes,
            output_activation=None
        )

        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.lr = lr
        self.pooling = pooling

    def forward(self, x):
        B, T, C, H, W = x.shape
        x = x.view(B * T, C, H, W)

        features = self.model(x)  # (B * T, feature_dim)
        features = features.view(B, T, -1).permute(0, 2, 1)  # (B, feature_dim, T)

        logits = self.tcn(features)  # (B, num_classes, T)

        if self.pooling == "mean":
            logits = logits.mean(dim=-1)  # (B, num_classes)
        elif self.pooling == "max":
            logits, _ = logits.max(dim=-1)  # (B, num_classes)
        else:
            raise ValueError(f"Invalid pooling method: {self.pooling}")

        return logits

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr, weight_decay=1e-4)

    def step(self, batch, stage="train"):
        x, y = batch
        y_hat = self(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        acc = self.accuracy(y_hat, y)
        self.log(f"{stage}_loss", loss, on_step=False, on_epoch=True)
        self.log(f"{stage}_acc", acc, on_step=False, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self.step(batch, stage="train")

    def validation_step(self, batch, batch_idx):
        return self.step(batch, stage="val")

    def test_step(self, batch, batch_idx):
        return self.step(batch, stage="test")
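A fresh `nn.Conv2d(1, 64, ...)` does not inherit the ImageNet stem weights. One common way to keep them is to sum the pretrained RGB filters over the input-channel axis. Convolution is linear, so for a grayscale frame copied into three channels the summed filter gives the same response as the original RGB filter. The check below makes one assumption: it uses a randomly initialised `rgb_conv1` as a stand-in for the pretrained layer, so nothing has to be downloaded.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for the pretrained RGB stem (random weights, so no download is needed).
rgb_conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

# Grayscale stem: sum the RGB filters over the input-channel axis.
gray_conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
gray_conv1.weight = nn.Parameter(rgb_conv1.weight.detach().sum(dim=1, keepdim=True))

frames = torch.rand(4, 1, 64, 64)        # (B*T, 1, H, W) grayscale frames
frames_rgb = frames.repeat(1, 3, 1, 1)   # the same frames copied into 3 channels

with torch.no_grad():
    out_gray = gray_conv1(frames)
    out_rgb = rgb_conv1(frames_rgb)

print(out_gray.shape)                                # torch.Size([4, 64, 32, 32])
print(torch.allclose(out_gray, out_rgb, atol=1e-5))  # True
```

When `freeze_layers=True`, the freezing loop leaves only `layer4` and `fc` trainable. That makes the initialisation of `conv1` important, because a frozen random stem cannot adapt during training.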

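The temporal context behind each TCN output depends on `kernel_size` and on how many entries `tcn_channels` has. The helper below (`tcn_receptive_field`, hypothetical and not part of pytorch-tcn) computes that receptive field under two assumptions:

- the library uses a default dilation of `2**i` at level `i`;
- each residual block applies two dilated convolutions, as in the generic TCN of Bai et al.

Check the pytorch-tcn documentation for the version you install before relying on these numbers.

```python
def tcn_receptive_field(kernel_size, num_levels, dilations=None, convs_per_block=2):
    """Number of input frames that can influence one output frame of a TCN.

    Assumes dilation 2**i at level i (unless `dilations` is given, in which
    case `num_levels` is ignored) and `convs_per_block` dilated convolutions
    in every residual block.
    """
    if dilations is None:
        dilations = [2 ** i for i in range(num_levels)]
    # Each dilated convolution widens the receptive field by (kernel_size - 1) * dilation.
    return 1 + convs_per_block * (kernel_size - 1) * sum(dilations)


SEQUENCE_LENGTH = 15  # clip length used when instantiating the models in this notebook
for kernel_size in (2, 3):
    for num_levels in (2, 3, 4):
        rf = tcn_receptive_field(kernel_size, num_levels)
        covers = "covers" if rf >= SEQUENCE_LENGTH else "does not cover"
        print(f"kernel_size={kernel_size}, {num_levels} levels: {rf} frames ({covers} a {SEQUENCE_LENGTH}-frame clip)")
```

Suppose the convolutions are causal. Then frame `t` only sees frames up to and including `t`, so earlier frames in a clip always get less context than later ones. Mean or max pooling over time in `forward` combines these unequally informed per-frame logits.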
ResNet18 + LSTM

[ ]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchmetrics
from pytorch_lightning import LightningModule
from torchvision.models import resnet18, ResNet18_Weights

class LitResNet18LSTM(LightningModule):
    def __init__(self, num_classes=5, sequence_length=15, lstm_hidden_size=512, lstm_layers=2, dropout=0.5, lr=1e-4, pooling="mean", freeze_layers=True):
        super().__init__()
        self.save_hyperparameters()

        # Load pretrained ResNet-18
        weights = ResNet18_Weights.DEFAULT
        self.model = resnet18(weights=weights)

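        # Modify input layer to accept 1-channel (grayscale) input.
        # Initialise it by summing the pretrained RGB filters over the channel axis
        # so the stem keeps its ImageNet features; a randomly initialised conv1
        # would otherwise be frozen by the loop below.
        pretrained_conv1 = self.model.conv1
        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.model.conv1.weight = nn.Parameter(pretrained_conv1.weight.detach().sum(dim=1, keepdim=True))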

        # Optionally freeze layers
        if freeze_layers:
            for name, param in self.model.named_parameters():
                if not name.startswith("layer4") and not name.startswith("fc"):
                    param.requires_grad = False

        # Extract feature dimension before FC
        self.feature_dim = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(self.feature_dim, self.feature_dim)
        )

        # LSTM module
        self.lstm = nn.LSTM(
            input_size=self.feature_dim,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=False
        )

        # Final classification layer
        self.fc = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(lstm_hidden_size, num_classes)
        )

        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.lr = lr
        self.pooling = pooling

    def forward(self, x):
        B, T, C, H, W = x.shape
        x = x.view(B * T, C, H, W)

        features = self.model(x)  # (B * T, feature_dim)
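        features = features.view(B, T, -1)  # (B, T, feature_dim)

        lstm_out, _ = self.lstm(features)  # (B, T, lstm_hidden_size)

        # Pool the per-frame LSTM outputs over the time axis
        if self.pooling == "mean":
            pooled = lstm_out.mean(dim=1)  # (B, lstm_hidden_size)
        elif self.pooling == "max":
            pooled, _ = lstm_out.max(dim=1)  # (B, lstm_hidden_size)
        else:
            raise ValueError(f"Invalid pooling method: {self.pooling}")

        logits = self.fc(pooled)  # (B, num_classes)
        return logits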

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=1e-4)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=3
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
                "interval": "epoch",
                "frequency": 1
            }
        }

    def step(self, batch, stage="train"):
        x, y = batch
        y_hat = self(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        acc = self.accuracy(y_hat, y)
        self.log(f"{stage}_loss", loss, on_step=False, on_epoch=True)
        self.log(f"{stage}_acc", acc, on_step=False, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self.step(batch, stage="train")

    def validation_step(self, batch, batch_idx):
        return self.step(batch, stage="val")

    def test_step(self, batch, batch_idx):
        return self.step(batch, stage="test")
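`ReduceLROnPlateau(mode='min', factor=0.5, patience=3)` halves the learning rate when `val_loss` stops improving. The sketch below is a simplified replay of that rule. For the arguments the notebook does not set, it assumes PyTorch's defaults: `threshold=1e-4` in relative mode, no cooldown, and `min_lr=0`. `plateau_schedule` is a hypothetical helper, not a PyTorch function.

```python
def plateau_schedule(val_losses, lr=1e-3, factor=0.5, patience=3, threshold=1e-4):
    """Simplified replay of ReduceLROnPlateau(mode='min', threshold_mode='rel').

    Returns the learning rate in effect after each epoch's scheduler step.
    Ignores cooldown, min_lr and eps (left at their defaults in the notebook).
    """
    best = float("inf")
    num_bad_epochs = 0
    lrs = []
    for loss in val_losses:
        # Relative mode: an epoch only counts as an improvement if it beats best by a small margin.
        if loss < best * (1 - threshold):
            best = loss
            num_bad_epochs = 0
        else:
            num_bad_epochs += 1
        # Reduce once the number of non-improving epochs exceeds `patience`.
        if num_bad_epochs > patience:
            lr *= factor
            num_bad_epochs = 0
        lrs.append(lr)
    return lrs


val_losses = [1.0, 0.9] + [0.95] * 8   # improves twice, then plateaus
lrs = plateau_schedule(val_losses)
print(lrs)
```

With `patience=3`, the learning rate drops on the fourth consecutive epoch without improvement, here at epochs 5 and 9 (0-based). The scheduler watches `val_loss`, while the `EarlyStopping` and `ModelCheckpoint` callbacks in the training cells watch `val_acc`, so the two signals can disagree.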

Training

[ ]:
# Create the TensorBoard log directory
!mkdir -p tb_logs/

%load_ext tensorboard
%tensorboard --logdir tb_logs/
[ ]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_tcn import TCN

# 🚂 Training configuration
MODEL_NAME = "LitResNet18TCN"  # Change to "LitResNet18LSTM" when training the LSTM model
EPOCHS = 50
CHECKPOINT_DIR = "checkpoints/"
LOG_DIR = "tb_logs"

# Select the model
model = LitResNet18TCN(sequence_length=15)  # or LitResNet18LSTM(sequence_length=15)

# Initialize TensorBoard Logger
logger = TensorBoardLogger(LOG_DIR, name=MODEL_NAME)

# Initialize EarlyStopping and ModelCheckpoint callbacks
early_stop_callback = EarlyStopping(
    monitor="val_acc",
    patience=10,
    mode="max",
    verbose=True
)

checkpoint_callback = ModelCheckpoint(
    monitor="val_acc",
    mode="max",
    save_top_k=1,
    verbose=True,
    dirpath=CHECKPOINT_DIR,
    filename=MODEL_NAME + "-{epoch:02d}-{val_acc:.4f}"
)

# Initialize the Trainer
trainer = Trainer(
    max_epochs=EPOCHS,
    accelerator="auto",
    devices="auto",
    log_every_n_steps=1,  # ✅ log at every step
    logger=logger,  # Log to TensorBoard
    callbacks=[early_stop_callback, checkpoint_callback]
)

# 🚂 Start training
trainer.fit(model, train_loader, val_loader)
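`EarlyStopping(monitor="val_acc", mode="max", patience=10)` ends training once `val_acc` has not improved for `patience` consecutive validation runs. The sketch below replays that rule on a list of per-epoch accuracies. `epochs_until_early_stop` is a hypothetical helper. It assumes the default `min_delta=0`, so any strictly higher accuracy counts as an improvement.

```python
def epochs_until_early_stop(val_accs, patience):
    """Return the 0-based epoch at which EarlyStopping(mode="max") would stop, or None."""
    best = float("-inf")
    wait = 0
    for epoch, acc in enumerate(val_accs):
        if acc > best:          # min_delta=0: any strictly higher accuracy resets the counter
            best = acc
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


val_accs = [0.50, 0.60, 0.70, 0.80, 0.85] + [0.84] * 12   # best accuracy at epoch 4
print(epochs_until_early_stop(val_accs, patience=10))  # 14
print(epochs_until_early_stop(val_accs, patience=20))  # None
```

With `patience=10`, training stops ten epochs after the last improvement. Because `ModelCheckpoint` also monitors `val_acc` with `save_top_k=1`, the saved checkpoint is still the one from the best epoch, not the last one.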

[ ]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# 🚂 Training
EPOCHS = 50  # redefined so this cell also runs without the previous one
model = LitResNet18LSTM(
    num_classes=6,
    sequence_length=15,
    lstm_hidden_size=512,
    lstm_layers=2,
    dropout=0.2,
    lr=1e-3,
    pooling="mean"  # or "max"
)

# Initialize TensorBoardLogger
logger = TensorBoardLogger("tb_logs", name="LitResNet18LSTM")

# Initialize EarlyStopping and ModelCheckpoint callbacks
early_stop_callback = EarlyStopping(
    monitor="val_acc",
    patience=20,
    mode="max",
    verbose=True
)

checkpoint_callback = ModelCheckpoint(
    monitor="val_acc",
    mode="max",
    save_top_k=1,
    verbose=True,
    dirpath="checkpoints/",
    filename="LitResNet18LSTM-{epoch:02d}-{val_acc:.4f}"
)

# Initialize Trainer
trainer = Trainer(
    max_epochs=EPOCHS,
    accelerator="gpu",
    devices=1,
    precision="16-mixed",  # Use mixed precision (e.g. on a Colab T4 GPU)
    gradient_clip_val=0.5,
    log_every_n_steps=1,
    logger=logger,
    callbacks=[early_stop_callback, checkpoint_callback]
)

# Start Training
trainer.fit(model, train_loader, val_loader)
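`gradient_clip_val=0.5` limits how large the gradients can get before each optimizer step. As far as I know, Lightning's default `gradient_clip_algorithm` is `"norm"`. Under that algorithm, all gradients are rescaled by one common factor whenever their global L2 norm exceeds the threshold, as `torch.nn.utils.clip_grad_norm_` does. The plain-Python sketch below applies that rule to a flat list of numbers. `clip_by_global_norm` is a hypothetical helper. The `1e-6` epsilon mirrors the one I believe PyTorch adds to the denominator.

```python
import math

def clip_by_global_norm(grads, max_norm=0.5, eps=1e-6):
    """Scale all gradients by one factor so their global L2 norm is at most ~max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    # The factor is capped at 1.0, so gradients that are already small are left unchanged.
    scale = min(1.0, max_norm / (total_norm + eps))
    return [g * scale for g in grads], total_norm


clipped, norm_before = clip_by_global_norm([3.0, 4.0])
norm_after = math.sqrt(sum(g * g for g in clipped))
print(norm_before, round(norm_after, 4))  # 5.0 0.5

small, _ = clip_by_global_norm([0.1, 0.2])
print(small)  # [0.1, 0.2]
```

Because every gradient is scaled by the same factor, clipping by the global norm keeps the direction of the update and only shrinks its length.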

[ ]: