7. 🧠 MISAW CNN-Temporal classification
This notebook extends the previous CNN approach with temporal inference to classify phases, steps, and related labels in the MISAW dataset, which contains videos and annotations for surgical workflow analysis. It covers:
Creating a PyTorch Dataset
Exploring CNN and temporal-inference options (TCN and LSTM)
Visualizing training and validation results
🧰 Importing Required Libraries
This section installs the extra dependencies and loads the Python libraries (such as os, cv2, glob, and pandas) needed for handling files, images, and data manipulation.
[1]:
# 📦 Install dependencies
!pip install pytorch-lightning -q
!pip install torchmetrics -q
!pip install pytorch-tcn -q
[2]:
# Standard library
import os # For file and directory operations
import glob # For finding files using wildcard patterns
import json # For saving annotations in JSON format
import shutil # For copying images
# Third-party libraries
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# PyTorch
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
# Torchvision
from torchvision import datasets, transforms
# PyTorch Lightning
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
# Torchmetrics
import torchmetrics
📥 Downloading the Dataset
Here we download the MISAW dataset in .zip format from a Dropbox link. This dataset contains videos and annotations for surgical workflow analysis.
[3]:
# Quote the URL so the shell does not treat '&' as a background operator
!wget -O MISAW.zip "https://www.dropbox.com/scl/fi/psmlokrc5ms958ggqyv3u/MISAW.zip?rlkey=v91dz437npon5zz10olrbcqcd&st=54qvf31m&dl=0"
Length: 965450623 (921M) [application/zip]
Saving to: ‘MISAW.zip’
MISAW.zip 100%[===================>] 920.72M 79.8MB/s in 13s
2025-06-03 13:04:33 (70.2 MB/s) - ‘MISAW.zip’ saved [965450623/965450623]
📦 Extracting the Dataset
This cell unzips the downloaded file to access the raw data.
[4]:
!unzip -qq MISAW.zip
🎞️ Frame Extraction from Videos
This section reads the videos and extracts every N-th frame (controlled by resample_rate). Each video gets its own subdirectory of frames, organized as:
MISAW/
└── train/
    └── Frames/
        └── <video_id>/frame_0001.jpg
ffmpeg numbers its output images from 1 by default, so the first file is frame_0001.jpg. These images will later be paired with annotations.
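To see why the labels can later be subsampled with the same rate, it helps to spell out which frames the select filter keeps. The following is a minimal sketch in plain Python, assuming frames are indexed from 0 as ffmpeg's n variable is. kept_frame_indices is a hypothetical helper and the video length is made up.

```python
def kept_frame_indices(total_frames: int, resample_rate: int) -> list[int]:
    """Indices n for which not(mod(n, resample_rate)) holds, i.e. n % rate == 0."""
    return [n for n in range(total_frames) if n % resample_rate == 0]


total_frames = 10   # hypothetical video length
resample_rate = 3
kept = kept_frame_indices(total_frames, resample_rate)
print(kept)  # [0, 3, 6, 9]

# Slicing a per-frame label list with the same step selects the same positions.
labels = [f"label_{n}" for n in range(total_frames)]
assert labels[::resample_rate] == [labels[n] for n in kept]
```

If the two rates differ, the number of labels no longer matches the number of extracted frames, which is what the sanity check in the pairing step catches.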
[5]:
!sudo apt update
!sudo apt install -y ffmpeg
ffmpeg is already the newest version (7:4.4.2-0ubuntu0.22.04.1).
0 upgraded, 0 newly installed, 0 to remove and 38 not upgraded.
[6]:
!nvidia-smi
!ffmpeg -encoders | grep nvenc
Tue Jun 3 13:04:52 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15 Driver Version: 550.54.15 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |
| N/A 36C P8 9W / 70W | 0MiB / 15360MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| No running processes found |
+-----------------------------------------------------------------------------------------+
ffmpeg version 4.4.2-0ubuntu0.22.04.1 Copyright (c) 2000-2021 the FFmpeg developers
V....D h264_nvenc NVIDIA NVENC H.264 encoder (codec h264)
V..... nvenc NVIDIA NVENC H.264 encoder (codec h264)
V..... nvenc_h264 NVIDIA NVENC H.264 encoder (codec h264)
V..... nvenc_hevc NVIDIA NVENC hevc encoder (codec hevc)
V....D hevc_nvenc NVIDIA NVENC hevc encoder (codec hevc)
[ ]:
import os
import subprocess

video_folder = 'MISAW/train/Video'
frames_folder = 'MISAW/train/Frames'
os.makedirs(frames_folder, exist_ok=True)

resample_rate = 3  # Keep every 3rd frame; adjust as needed

for video_file in os.listdir(video_folder):
    if not video_file.endswith(('.mp4', '.avi', '.mov')):
        continue
    video_path = os.path.join(video_folder, video_file)
    video_name = os.path.splitext(video_file)[0]
    output_folder = os.path.join(frames_folder, video_name)
    os.makedirs(output_folder, exist_ok=True)
    # Use ffmpeg with CUDA-accelerated decoding to keep every N-th frame
    output_pattern = os.path.join(output_folder, "frame_%04d.jpg")
    command = [
        "ffmpeg", "-hwaccel", "cuda", "-i", video_path,
        "-vf", f"select=not(mod(n\\,{resample_rate}))",
        "-vsync", "vfr", output_pattern,
    ]
    print(f"Extracting frames from {video_file}...")
    # Passing a list avoids shell quoting issues; check=True surfaces ffmpeg failures
    subprocess.run(command, check=True)
    print(f"Frames saved to {output_folder}")
print("Done!")
Extracting frames from 2_5.mp4...
Frames saved to MISAW/train/Frames/2_5
Extracting frames from 6_2.mp4...
Frames saved to MISAW/train/Frames/6_2
Extracting frames from 3_2.mp4...
Frames saved to MISAW/train/Frames/3_2
Extracting frames from 2_1.mp4...
Frames saved to MISAW/train/Frames/2_1
Extracting frames from 3_3.mp4...
Frames saved to MISAW/train/Frames/3_3
Extracting frames from 6_4.mp4...
Frames saved to MISAW/train/Frames/6_4
Extracting frames from 1_4.mp4...
🧾 Parsing Annotations and Building Frame-Label Pairs
Here we parse the .txt annotation files (one per video), read the per-frame label from the Step column, and pair each extracted frame with its label. The code refers to these labels as phases. We also map the textual labels to numeric IDs, which the models need as classification targets.
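The label-to-ID mapping can be illustrated on a tiny in-memory annotation table. This sketch uses only the standard library. The two-column layout and the label values are invented for illustration and are not taken from the MISAW files.

```python
import csv
import io

# Hypothetical tab-separated annotation content (one row per video frame).
raw = "Frame\tStep\n0\tIdle\n1\tIdle\n2\tSuturing\n3\tKnot tying\n4\tSuturing\n"

rows = list(csv.DictReader(io.StringIO(raw), delimiter="\t"))
steps = [row["Step"] for row in rows]

# Sorting the unique names makes the mapping reproducible across runs.
label_to_id = {name: i for i, name in enumerate(sorted(set(steps)))}
id_to_label = {i: name for name, i in label_to_id.items()}

encoded = [label_to_id[s] for s in steps]
print(label_to_id)  # {'Idle': 0, 'Knot tying': 1, 'Suturing': 2}
print(encoded)      # [0, 0, 2, 1, 2]
```

Because the IDs depend on the sorted set of names, the mapping should be built once over all videos and reused, rather than rebuilt per file.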
[ ]:
# Folder where the annotation .txt files are stored (one per video)
annotation_folder = 'MISAW/train/Procedural decription'
# Folder where extracted frames from videos are stored (in subfolders per video)
frames_folder = 'MISAW/train/Frames'
# resample_rate comes from the frame-extraction cell (here 3, i.e. every 3rd frame).
# It must match the rate used there, or the labels won't align with the frames!

# --- Collect all unique phases first ---
# We'll store all unique phase labels (like "Idle", "Suturing", etc.) in this set
all_phases = set()
# Find all annotation files in the folder (e.g., 1_1_annotation.txt, 1_2_annotation.txt, etc.)
annotation_files = sorted(glob.glob(os.path.join(annotation_folder, '*_annotation.txt')))
# Loop through each annotation file and collect unique phase names
for anno_file in annotation_files:
    df = pd.read_csv(anno_file, sep='\t')  # Load the .txt file into a DataFrame
    all_phases.update(df['Step'].unique())  # Add unique phases to the set
# Create a dictionary to map each phase string to a unique integer ID
# Useful for training machine learning models
phase_to_id = {name: i for i, name in enumerate(sorted(all_phases))}
# Create the inverse mapping (ID to phase name) for visualization later
id_to_phase = {i: name for name, i in phase_to_id.items()}

# --- Build (frame_path, label) pairs ---
# This list will hold tuples like: ("path/to/frame.jpg", 2) → (frame, label_id)
all_data = []
# Go through each annotation file (one per video)
for anno_file in annotation_files:
    # Extract the video ID from the filename (e.g., "1_1_annotation.txt" → "1_1")
    video_id = os.path.basename(anno_file).replace('_annotation.txt', '')
    # Construct the path to the corresponding frame folder
    frame_dir = os.path.join(frames_folder, video_id)
    # Read the annotation file into a DataFrame
    df = pd.read_csv(anno_file, sep='\t')
    # Get the full list of phases (one per original video frame)
    phases = df['Step'].tolist()
    # Resample: only keep every N-th label (N = resample_rate)
    sampled_phases = phases[::resample_rate]
    # Convert phase strings to numeric labels using our earlier mapping
    sampled_ids = [phase_to_id[p] for p in sampled_phases]
    # Get the list of frame image paths, sorted so they match the order of labels
    frame_paths = sorted(glob.glob(os.path.join(frame_dir, '*.jpg')))
    # Sanity check: if the number of frames doesn't match the number of labels, skip this video
    if len(frame_paths) != len(sampled_ids):
        print(f"⚠️ Mismatch for {video_id}: {len(frame_paths)} frames vs {len(sampled_ids)} labels")
        continue
    # Add all (frame_path, label_id) pairs to our global list
    all_data.extend(zip(frame_paths, sampled_ids))
# Final print to confirm total number of samples loaded
print(f"✅ Total samples: {len(all_data)}")
Visualize images and phases
[ ]:
import random
import matplotlib.pyplot as plt
from PIL import Image
from collections import defaultdict

# CONFIGURATION
# Number of sample images to display per phase (i.e., class/label)
samples_per_phase = 3

# Group image paths by their label (phase ID)
# defaultdict(list) automatically creates an empty list for new keys
label_to_paths = defaultdict(list)
# all_data is a list of (frame_path, label_id) pairs
for path, label_id in all_data:
    label_to_paths[label_id].append(path)

# One row per unique phase, one column per sample
n_rows = len(label_to_paths)
n_cols = samples_per_phase

# Create a grid of subplots (n_rows x n_cols); figsize is in inches.
# squeeze=False keeps axes 2D even when there is only one row or column.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
# Add a main title above all the subplots
fig.suptitle("Random Samples per Phase", fontsize=18)

# Get all label IDs in sorted order (for consistent row display)
sorted_labels = sorted(label_to_paths.keys())
# Loop over each row (i.e., each unique label/phase)
for row, label_id in enumerate(sorted_labels):
    # Convert label ID back to its readable name using id_to_phase mapping
    label_name = id_to_phase[label_id]
    # Randomly sample a few images from this label,
    # never more than are actually available
    paths = label_to_paths[label_id]
    samples = random.sample(paths, min(samples_per_phase, len(paths)))
    # For each column (i.e., each sampled image for this label)
    for col in range(samples_per_phase):
        ax = axes[row, col]
        # Only plot if we have enough samples for this column
        if col < len(samples):
            img = Image.open(samples[col])
            ax.imshow(img)
            # Set the subplot title to the phase name
            ax.set_title(f"{label_name}", fontsize=10)
        # Remove axis ticks and labels for a cleaner look (also hides empty cells)
        ax.axis('off')

# Adjust layout to prevent overlaps
plt.tight_layout()
# Adjust top spacing to make room for the main title
plt.subplots_adjust(top=0.95)
# Display the final grid of images
plt.show()
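Converting to COCO-style Format
This section copies all frames into one flat folder and writes a single JSON annotation file, loosely following the COCO idea of keeping images and annotations separate. It also merges the knot-related labels into "knot" and the idle labels into "idle". The output layout is:
MISAW_coco/
├── Frames/
│   └── video1_frame_0000.jpg
├── annotations.json
└── phase_to_id.json
Each entry in annotations.json records the video, frame name, path, label, and label ID of one frame. This is a simplified layout rather than the full COCO schema.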
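Each record in annotations.json is a small dictionary. The sketch below builds two such records by hand, round-trips them through JSON, and counts frames per label. The field names follow the conversion cell in this notebook, while the values (including the label IDs) are invented for illustration.

```python
import json
from collections import Counter

# Hypothetical records; real ones are produced by the conversion cell.
records = [
    {"video": "1_1", "frame": "1_1_frame_0000.jpg",
     "path": "MISAW_coco/Frames/1_1_frame_0000.jpg", "label": "idle", "label_id": 0},
    {"video": "1_1", "frame": "1_1_frame_0001.jpg",
     "path": "MISAW_coco/Frames/1_1_frame_0001.jpg", "label": "knot", "label_id": 1},
]

text = json.dumps(records, indent=2)   # what would be written to annotations.json
loaded = json.loads(text)              # what the training cells read back

label_counts = Counter(record["label"] for record in loaded)
print(label_counts)
```

A per-label count like this is also a quick way to spot class imbalance before training.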
[ ]:
# --- CONFIGURATION ---
# Folder where annotation .txt files are stored
annotation_folder = 'MISAW/train/Procedural decription'
# Folder where frames are stored (e.g., MISAW/train/Frames/1_1/frame_0001.jpg)
frames_folder = 'MISAW/train/Frames'
# Output root folder for the new COCO-style dataset
output_root = 'MISAW_coco'
# Knot-related phases to group into a single "knot" phase
knot_phases = {'1 knot', '2 knot', '3 knot'}
# Idle-related phases to group into a single "idle" phase
idle_phases = {'Idle', 'Idle Step'}
# Inside the output root, all frames are copied to one flat "Frames" folder
frames_out_root = os.path.join(output_root, 'Frames')
os.makedirs(frames_out_root, exist_ok=True)  # Create the folder if it doesn't exist

# Sorted so that the processing order is reproducible
annotation_files = sorted(glob.glob(f'{annotation_folder}/*_annotation.txt'))

# --- STEP 1: COLLECT UNIQUE PHASES FROM ALL ANNOTATIONS ---
# Set to store all phase names (e.g., Idle, Suturing, etc.)
all_phases = set()
# Loop through each annotation file and collect phase names
for anno_file in annotation_files:
    df = pd.read_csv(anno_file, sep='\t')  # Load tab-separated .txt file as a DataFrame
    df['Step'] = df['Step'].replace({p: "knot" for p in knot_phases})
    df['Step'] = df['Step'].replace({p: "idle" for p in idle_phases})
    all_phases.update(df['Step'].unique())  # Add unique phases to the set
# Create dictionaries to map phases to integer IDs and back
phase_to_id = {p: i for i, p in enumerate(sorted(all_phases))}  # e.g., "Suturing" → 2
id_to_phase = {i: p for p, i in phase_to_id.items()}  # e.g., 2 → "Suturing"

# --- STEP 2: BUILD JSON ENTRIES AND COPY FRAMES ---
entries = []  # This will hold all annotation entries
# Process each video annotation
for anno_file in annotation_files:
    video_id = os.path.basename(anno_file).replace('_annotation.txt', '')  # e.g., "1_1"
    # Read annotation file
    df = pd.read_csv(anno_file, sep='\t')
    df['Step'] = df['Step'].replace({p: "knot" for p in knot_phases})
    df['Step'] = df['Step'].replace({p: "idle" for p in idle_phases})
    # Resample phase labels to match saved frames (every N-th label, N = resample_rate)
    phases = df['Step'].tolist()[::resample_rate]
    # Get list of resampled frame image paths
    frame_dir = os.path.join(frames_folder, video_id)
    frame_paths = sorted(glob.glob(os.path.join(frame_dir, '*.jpg')))
    # Ensure the number of frames and labels match
    if len(phases) != len(frame_paths):
        print(f"⚠️ Skipping {video_id}: mismatched {len(phases)} labels vs {len(frame_paths)} frames")
        continue
    # Loop over each frame-label pair
    for i, (frame_path, phase) in enumerate(zip(frame_paths, phases)):
        # Create a new frame name (e.g., "1_1_frame_0003.jpg")
        new_frame_name = f"{video_id}_frame_{i:04d}.jpg"
        # Full destination path for the copied frame
        new_frame_path = os.path.join(frames_out_root, new_frame_name)
        # Copy frame to the output folder
        shutil.copy(frame_path, new_frame_path)
        # Add a dictionary entry for this frame
        entries.append({
            "video": video_id,
            "frame": new_frame_name,
            "path": new_frame_path,
            "label": phase,
            "label_id": phase_to_id[phase]
        })

# --- STEP 3: SAVE TO DISK ---
# Save all frame annotations to a JSON file
with open(os.path.join(output_root, 'annotations.json'), 'w') as f:
    json.dump(entries, f, indent=2)
# Save the phase-to-ID mapping for future use
with open(os.path.join(output_root, 'phase_to_id.json'), 'w') as f:
    json.dump(phase_to_id, f, indent=2)
# Final confirmation
print(f"✅ COCO-style structure created in {output_root}/ with {len(entries)} annotated frames.")
Dataset and Dataloader
[ ]:
# Earlier variant with non-overlapping sequences, kept for reference:
# from torch.utils.data import Dataset
# from PIL import Image
# import torch
# class MISAWDataset(Dataset):
#     def __init__(self, annotations, sequence_length=15, transform=None):
#         self.annotations = annotations
#         self.sequence_length = sequence_length
#         self.transform = transform
#         # Ensure only complete sequences
#         self.num_sequences = len(self.annotations) // self.sequence_length
#     def __len__(self):
#         return self.num_sequences
#     def __getitem__(self, idx):
#         # Calculate the start and end index for this sequence
#         sequence_start = idx * self.sequence_length
#         sequence_end = sequence_start + self.sequence_length
#         # Collect the sequence of images
#         images = []
#         for i in range(sequence_start, sequence_end):
#             item = self.annotations[i]
#             image = Image.open(item['path']).convert('L')  # Convert to grayscale
#             if self.transform:
#                 image = self.transform(image)
#             images.append(image)
#         # Stack the images along the sequence dimension
#         images = torch.stack(images)  # Shape: (sequence_length, C, H, W)
#         # Use the label of the last frame in the sequence as the target
#         label = self.annotations[sequence_end - 1]['label_id']
#         return images, label

from torch.utils.data import Dataset
from PIL import Image
import torch


class MISAWDataset(Dataset):
    def __init__(self, annotations, sequence_length=8, stride=1, transform=None):
        self.annotations = annotations
        self.sequence_length = sequence_length
        self.stride = stride
        self.transform = transform
        # Start indices of all overlapping windows. Windows that would span two
        # videos are dropped (annotations are expected to be grouped by video).
        last_start = len(self.annotations) - self.sequence_length
        self.valid_indices = [
            start for start in range(0, last_start + 1, self.stride)
            if self.annotations[start]['video']
            == self.annotations[start + self.sequence_length - 1]['video']
        ]

    def __len__(self):
        # Number of usable overlapping sequences
        return len(self.valid_indices)

    def __getitem__(self, idx):
        # Start and end index for this sequence
        start_idx = self.valid_indices[idx]
        end_idx = start_idx + self.sequence_length
        # Collect the sequence of images
        images = []
        for i in range(start_idx, end_idx):
            item = self.annotations[i]
            image = Image.open(item['path']).convert('L')  # Convert to grayscale
            if self.transform:
                image = self.transform(image)
            images.append(image)
        # Stack the images along the sequence dimension
        images = torch.stack(images)  # Shape: (sequence_length, C, H, W)
        # Use the label of the last frame in the sequence as the target
        label = self.annotations[end_idx - 1]['label_id']
        return images, label
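The indexing used for overlapping sequences can be checked without loading any images. The following is a minimal sketch in which window_starts is a hypothetical helper that mirrors the range(...) expression in the class. It ignores video boundaries.

```python
def window_starts(num_frames: int, sequence_length: int, stride: int = 1) -> list[int]:
    """Start indices of all full windows of `sequence_length` frames."""
    return list(range(0, num_frames - sequence_length + 1, stride))


print(window_starts(10, 8))            # [0, 1, 2]
print(window_starts(10, 8, stride=2))  # [0, 2]
print(window_starts(5, 8))             # [] because no full window fits
```

With stride=1, consecutive samples share sequence_length - 1 frames, so the dataset grows to roughly one sample per frame. A larger stride trades that redundancy for faster epochs.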
[ ]:
import json
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from collections import defaultdict
# Load annotations from JSON
with open('MISAW_coco/annotations.json') as f:
    annotations = json.load(f)
# Group the frame records by video
video_groups = defaultdict(list)
for image_info in annotations:
    video_name = image_info["video"]
    video_groups[video_name].append(image_info)
# Split videos deterministically (80% train, 20% val)
all_videos = sorted(video_groups.keys())
split_idx = int(len(all_videos) * 0.8)
train_videos = all_videos[:split_idx]
val_videos = all_videos[split_idx:]
# Collect dataset annotations
train_annotations = [item for vid in train_videos for item in video_groups[vid]]
val_annotations = [item for vid in val_videos for item in video_groups[vid]]
print(f"Train videos: {len(train_videos)}, Val videos: {len(val_videos)}")
print(f"Train frames: {len(train_annotations)}, Val frames: {len(val_annotations)}")
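Splitting by video rather than by frame keeps near-identical neighbouring frames from landing on both sides of the split. The following is a small sketch of the same 80/20 rule on made-up video IDs, with split_videos as a hypothetical helper.

```python
def split_videos(video_ids, train_fraction=0.8):
    """Deterministic split: sort the IDs, then cut at the given fraction."""
    ordered = sorted(video_ids)
    split_idx = int(len(ordered) * train_fraction)
    return ordered[:split_idx], ordered[split_idx:]


train_ids, val_ids = split_videos(["2_1", "1_1", "3_2", "1_2", "2_5"])
print(train_ids)  # ['1_1', '1_2', '2_1', '2_5']
print(val_ids)    # ['3_2']

# No video appears in both sets, so no frame can leak across the split.
assert not set(train_ids) & set(val_ids)
```

Because the IDs are sorted before cutting, the validation set always consists of the last videos in sort order. Shuffling with a fixed seed would be an alternative if that ordering correlates with operator or session.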
[ ]:
from torchvision import transforms
# Image transformations for training (stronger augmentation).
# Note: random transforms are re-sampled on every call, so the frames within
# one sequence are augmented independently of each other.
train_transform = transforms.Compose([
    transforms.Grayscale(),  # Ensures image is 1-channel
    transforms.RandomResizedCrop((128, 128), scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
# Image transformations for validation (no augmentation, only resizing and normalization)
val_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
# Create the train and validation datasets
SEQUENCE_LENGTH = 15
train_dataset = MISAWDataset(train_annotations, sequence_length=SEQUENCE_LENGTH, transform=train_transform)
val_dataset = MISAWDataset(val_annotations, sequence_length=SEQUENCE_LENGTH, transform=val_transform)
print(f"Train sequences: {len(train_dataset)}, Val sequences: {len(val_dataset)}")
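Normalize((0.5,), (0.5,)) computes (x - mean) / std per channel, which maps the [0, 1] range produced by ToTensor to [-1, 1]. The arithmetic, written out in plain Python as an illustration:

```python
def normalize(x: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Same formula that Normalize applies to every pixel of a channel."""
    return (x - mean) / std


for pixel in (0.0, 0.5, 1.0):
    print(pixel, "->", normalize(pixel))
# 0.0 -> -1.0
# 0.5 -> 0.0
# 1.0 -> 1.0
```

To display a normalized frame, the inverse x * std + mean brings the values back to [0, 1].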
[ ]:
from torch.utils.data import DataLoader
BATCH_SIZE = 16
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
# Quick sanity check
# images: (batch, sequence_length, channels, height, width); labels: (batch,)
for images, labels in train_loader:
    print(f"Images shape: {images.shape}, Labels shape: {labels.shape}")
    break
[ ]:
# import matplotlib.pyplot as plt
# import torch
# # Visualize one example sequence per class from the training dataset
# classes = sorted(id_to_phase)
# samples_per_class = {c: None for c in classes}
# # Extract one full sequence for each class from the training set
# for images, label in train_dataset:
#     # Convert the label to an integer if it is a tensor
#     if isinstance(label, torch.Tensor):
#         label = label.item()
#     if samples_per_class[label] is None:
#         samples_per_class[label] = images  # Shape: (sequence_length, 1, H, W)
#     if all(v is not None for v in samples_per_class.values()):
#         break
# # Plot the sequences: one row per class, one column per frame
# fig, axes = plt.subplots(len(classes), SEQUENCE_LENGTH, figsize=(20, 2 * len(classes)))
# for class_id in classes:
#     sequence = samples_per_class[class_id]  # Shape: (sequence_length, 1, H, W)
#     for frame_idx in range(sequence.shape[0]):
#         ax = axes[class_id, frame_idx]
#         # Remove the channel dimension for grayscale images
#         ax.imshow(sequence[frame_idx].squeeze(0).cpu().numpy(), cmap="gray")
#         ax.axis("off")
#         if frame_idx == 0:
#             ax.set_title(f"Class {class_id}")
# plt.show()
ResNet18 + TCN
[ ]:
import torch.optim as optim
from torchvision.models import resnet18, ResNet18_Weights
from pytorch_tcn import TCN


class LitResNet18TCN(LightningModule):
    def __init__(self, num_classes=5, sequence_length=15, tcn_channels=[256, 256, 256], kernel_size=4, dropout=0.5, lr=1e-5, pooling="mean", freeze_layers=True):
        super().__init__()
        self.save_hyperparameters()
        # Load pretrained ResNet-18
        weights = ResNet18_Weights.DEFAULT
        self.model = resnet18(weights=weights)
        # Replace the input layer to accept 1-channel (grayscale) input.
        # This layer is newly initialised, so it is kept trainable below.
        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Optionally freeze the pretrained layers (everything except conv1, layer4 and fc)
        if freeze_layers:
            for name, param in self.model.named_parameters():
                if not name.startswith(("conv1", "layer4", "fc")):
                    param.requires_grad = False
        # Feature dimension before the FC layer
        self.feature_dim = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(self.feature_dim, self.feature_dim)
        )
        # TCN module
        self.tcn = TCN(
            num_inputs=self.feature_dim,
            num_channels=tcn_channels,
            kernel_size=kernel_size,
            dropout=dropout,
            input_shape='NCL',
            output_projection=num_classes,
            output_activation=None
        )
        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.lr = lr
        self.pooling = pooling

    def forward(self, x):
        B, T, C, H, W = x.shape
        # Fold time into the batch dimension so the CNN sees individual frames
        x = x.view(B * T, C, H, W)
        features = self.model(x)  # (B * T, feature_dim)
        features = features.view(B, T, -1).permute(0, 2, 1)  # (B, feature_dim, T)
        logits = self.tcn(features)  # (B, num_classes, T)
        # Pool the per-time-step logits into one prediction per sequence
        if self.pooling == "mean":
            logits = logits.mean(dim=-1)  # (B, num_classes)
        elif self.pooling == "max":
            logits, _ = logits.max(dim=-1)  # (B, num_classes)
        else:
            raise ValueError(f"Invalid pooling method: {self.pooling}")
        return logits

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr, weight_decay=1e-4)

    def step(self, batch, stage="train"):
        x, y = batch
        y_hat = self(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        acc = self.accuracy(y_hat, y)
        self.log(f"{stage}_loss", loss, on_step=False, on_epoch=True)
        self.log(f"{stage}_acc", acc, on_step=False, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self.step(batch, stage="train")

    def validation_step(self, batch, batch_idx):
        return self.step(batch, stage="val")

    def test_step(self, batch, batch_idx):
        return self.step(batch, stage="test")
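A useful sanity check for the temporal model is whether its receptive field covers the whole input window. The sketch below uses the standard TCN formula under two assumptions: each residual block contains two dilated convolutions, and dilations double at each level (1, 2, 4, ...). Whether the installed pytorch-tcn version uses exactly these defaults is an assumption worth verifying against its documentation. tcn_receptive_field is a hypothetical helper.

```python
def tcn_receptive_field(kernel_size: int, num_levels: int, convs_per_block: int = 2) -> int:
    """Number of time steps one output step can see.

    Assumes dilation 2**i at level i and `convs_per_block` convolutions per level.
    """
    dilations = [2 ** i for i in range(num_levels)]
    return 1 + convs_per_block * (kernel_size - 1) * sum(dilations)


rf = tcn_receptive_field(kernel_size=4, num_levels=3)
print(rf)  # 43
```

Under these assumptions, three levels with kernel size 4 already see 43 time steps, more than the 15-frame sequences used in this notebook, so deeper stacks would not add temporal context for this window length.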
ResNet18 + LSTM
[ ]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchmetrics
from pytorch_lightning import LightningModule
from torchvision.models import resnet18, ResNet18_Weights


class LitResNet18LSTM(LightningModule):
    def __init__(self, num_classes=5, sequence_length=15, lstm_hidden_size=512, lstm_layers=2, dropout=0.5, lr=1e-4, pooling="mean", freeze_layers=True):
        super().__init__()
        self.save_hyperparameters()
        # Load pretrained ResNet-18
        weights = ResNet18_Weights.DEFAULT
        self.model = resnet18(weights=weights)
        # Replace the input layer to accept 1-channel (grayscale) input.
        # This layer is newly initialised, so it is kept trainable below.
        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Optionally freeze the pretrained layers (everything except conv1, layer4 and fc)
        if freeze_layers:
            for name, param in self.model.named_parameters():
                if not name.startswith(("conv1", "layer4", "fc")):
                    param.requires_grad = False
        # Feature dimension before the FC layer
        self.feature_dim = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(self.feature_dim, self.feature_dim)
        )
        # LSTM module
        self.lstm = nn.LSTM(
            input_size=self.feature_dim,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=False
        )
        # Final classification layer
        self.fc = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(lstm_hidden_size, num_classes)
        )
        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.lr = lr
        self.pooling = pooling

    def forward(self, x):
        B, T, C, H, W = x.shape
        # Fold time into the batch dimension so the CNN sees individual frames
        x = x.view(B * T, C, H, W)
        features = self.model(x)  # (B * T, feature_dim)
        features = features.view(B, T, -1)  # (B, T, feature_dim)
        lstm_out, _ = self.lstm(features)  # (B, T, lstm_hidden_size)
        # Pool the LSTM outputs over time before classification
        if self.pooling == "mean":
            pooled = lstm_out.mean(dim=1)
        elif self.pooling == "max":
            pooled, _ = lstm_out.max(dim=1)
        else:
            raise ValueError(f"Invalid pooling method: {self.pooling}")
        logits = self.fc(pooled)  # (B, num_classes)
        return logits

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=1e-4)
        # Halve the learning rate when val_loss stops improving
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=3
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
                "interval": "epoch",
                "frequency": 1
            }
        }

    def step(self, batch, stage="train"):
        x, y = batch
        y_hat = self(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        acc = self.accuracy(y_hat, y)
        self.log(f"{stage}_loss", loss, on_step=False, on_epoch=True)
        self.log(f"{stage}_acc", acc, on_step=False, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self.step(batch, stage="train")

    def validation_step(self, batch, batch_idx):
        return self.step(batch, stage="val")

    def test_step(self, batch, batch_idx):
        return self.step(batch, stage="test")
Training
[ ]:
# Create the tb_logs directory
!mkdir -p tb_logs/
%load_ext tensorboard
%tensorboard --logdir tb_logs/
[ ]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_tcn import TCN

# Training configuration
MODEL_NAME = "LitResNet18TCN"  # Change to "LitResNet18LSTM" if needed
EPOCHS = 50
CHECKPOINT_DIR = "checkpoints/"
LOG_DIR = "tb_logs"

# Select the model. The number of classes follows the label mapping built earlier,
# and the sequence length matches the one used to build the datasets.
model = LitResNet18TCN(num_classes=len(phase_to_id), sequence_length=SEQUENCE_LENGTH)
# or: LitResNet18LSTM(num_classes=len(phase_to_id), sequence_length=SEQUENCE_LENGTH)

# Initialize TensorBoard Logger
logger = TensorBoardLogger(LOG_DIR, name=MODEL_NAME)

# Initialize EarlyStopping and ModelCheckpoint callbacks
early_stop_callback = EarlyStopping(
    monitor="val_acc",
    patience=10,
    mode="max",
    verbose=True
)
checkpoint_callback = ModelCheckpoint(
    monitor="val_acc",
    mode="max",
    save_top_k=1,
    verbose=True,
    dirpath=CHECKPOINT_DIR,
    filename=MODEL_NAME + "-{epoch:02d}-{val_acc:.4f}"
)

# Initialize the Trainer
trainer = Trainer(
    max_epochs=EPOCHS,
    accelerator="auto",
    devices="auto",
    log_every_n_steps=1,  # ✅ log at every step
    logger=logger,  # Log to TensorBoard
    callbacks=[early_stop_callback, checkpoint_callback]
)

# Start Training
trainer.fit(model, train_loader, val_loader)
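The early-stopping rule can be reasoned about with a simplified model: training stops once the monitored metric has failed to improve for `patience` consecutive validation checks. The sketch below approximates that logic for mode="max" and ignores options such as min_delta. stopping_epoch is a hypothetical helper and the accuracy values are invented.

```python
def stopping_epoch(val_accs, patience):
    """Return the epoch index at which training would stop, or None if it never does."""
    best = float("-inf")
    bad_epochs = 0
    for epoch, acc in enumerate(val_accs):
        if acc > best:
            best, bad_epochs = acc, 0      # improvement resets the counter
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return epoch
    return None


history = [0.50, 0.62, 0.61, 0.60, 0.63, 0.62, 0.62, 0.61]
print(stopping_epoch(history, patience=3))   # 7
print(stopping_epoch(history, patience=10))  # None
```

A larger patience tolerates noisy validation curves at the cost of extra epochs. The checkpoint callback still keeps the best-scoring weights either way.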
[ ]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# Training
model = LitResNet18LSTM(
    num_classes=len(phase_to_id),  # follows the label mapping built earlier
    sequence_length=15,
    lstm_hidden_size=512,
    lstm_layers=2,
    dropout=0.2,
    lr=1e-3,
    pooling="mean"  # or "max"
)

# Initialize TensorBoardLogger
logger = TensorBoardLogger("tb_logs", name="LitResNet18LSTM")

# Initialize EarlyStopping and ModelCheckpoint callbacks
early_stop_callback = EarlyStopping(
    monitor="val_acc",
    patience=20,
    mode="max",
    verbose=True
)
checkpoint_callback = ModelCheckpoint(
    monitor="val_acc",
    mode="max",
    save_top_k=1,
    verbose=True,
    dirpath="checkpoints/",
    filename="LitResNet18LSTM-{epoch:02d}-{val_acc:.4f}"
)

# Initialize Trainer
trainer = Trainer(
    max_epochs=EPOCHS,
    accelerator="gpu",
    devices=1,
    precision=16,  # Use mixed precision for T4
    gradient_clip_val=0.5,
    log_every_n_steps=1,
    logger=logger,
    callbacks=[early_stop_callback, checkpoint_callback]
)

# Start Training
trainer.fit(model, train_loader, val_loader)