{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "0VDaCxqgOJ5r" }, "source": [ "# 7. 🧠 MISAW CNN-Temporal classification\n", "\n", "![Status](https://img.shields.io/static/v1.svg?label=Status&message=Finished&color=green)\n", "\n", "**Filled notebook:**\n", "[![View filled on GitHub](https://img.shields.io/static/v1.svg?logo=github&label=Repo&message=View%20On%20Github&color=lightgrey)](https://github.com/bfortuno/Surgical-Phase-Recognition/blob/main/docs/tutorial_notebooks/tutorial7/misaw_cnn_temporal.ipynb)\n", "[![Open filled in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bfortuno/Surgical-Phase-Recognition/blob/main/docs/tutorial_notebooks/tutorial7/misaw_cnn_temporal.ipynb)  \n", "**Author:** Benjamin I. Fortuno" ] }, { "cell_type": "markdown", "metadata": { "id": "c1807621" }, "source": [ "This notebook extends the CNN approach from the previous tutorial with temporal inference to classify phases, steps, and other workflow labels in the MISAW dataset, which contains videos and annotations for surgical workflow analysis. It covers:\n", "\n", "- Creating a PyTorch `Dataset`\n", "- Exploring options for CNN and temporal inference\n", "- Visualizing training and validation results" ] }, { "cell_type": "markdown", "metadata": { "id": "fcdce6df" }, "source": [ "### 🧰 Importing Required Libraries\n", "This section installs the extra packages and imports the libraries used throughout the notebook: `os`, `glob`, `cv2`, and `pandas` for handling files, images, and tabular data, plus PyTorch, PyTorch Lightning, and torchmetrics for modeling."
] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "WejVuxbi1kJB", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "8f82a567-55aa-43c5-b0bb-587ad32b8208" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m823.1/823.1 kB\u001b[0m \u001b[31m17.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m4.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m46.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m28.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m36.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m1.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m6.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m10.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m9.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m5.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K 
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m32.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m962.5/962.5 kB\u001b[0m \u001b[31m39.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h" ] } ], "source": [ "# 📦 Install dependencies\n", "!pip install pytorch-lightning -q\n", "!pip install torchmetrics -q\n", "!pip install pytorch-tcn -q" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "Mxn02EuSrmdh" }, "outputs": [], "source": [ "# Standard library\n", "import os # For file and directory operations\n", "import glob # For finding files using wildcard patterns\n", "import json # For saving annotations in JSON format\n", "import shutil # For copying images\n", "\n", "# Third-party libraries\n", "import cv2\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "\n", "# PyTorch\n", "import torch\n", "from torch import nn\n", "from torch.utils.data import DataLoader, random_split\n", "\n", "# Torchvision\n", "from torchvision import datasets, transforms\n", "\n", "# PyTorch Lightning\n", "import pytorch_lightning as pl\n", "from pytorch_lightning import LightningModule, Trainer\n", "from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint\n", "from pytorch_lightning.loggers import TensorBoardLogger\n", "\n", "# Torchmetrics\n", "import torchmetrics" ] }, { "cell_type": "markdown", "metadata": { "id": "840ba669" }, "source": [ "### 📥 Downloading the Dataset\n", "Here we download the MISAW dataset in `.zip` format from a Dropbox link. This dataset contains videos and annotations for surgical workflow analysis."
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "Pihxeq42rrJ4", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "0e9fccb5-a4cd-4573-d4a0-49ee48e95b61" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "--2025-06-03 13:04:18-- https://www.dropbox.com/scl/fi/psmlokrc5ms958ggqyv3u/MISAW.zip?rlkey=v91dz437npon5zz10olrbcqcd\n", "Resolving www.dropbox.com (www.dropbox.com)... 162.125.1.18, 2620:100:6016:18::a27d:112\n", "Connecting to www.dropbox.com (www.dropbox.com)|162.125.1.18|:443... connected.\n", "HTTP request sent, awaiting response... 302 Found\n", "Location: https://uc9947b1efa5bc8c7726c1d194d6.dl.dropboxusercontent.com/cd/0/inline/Cq7JOyLE9fpIYpamVYde6Q-VNjrNGMt_EeOcqzCuMzm9D1nUQ8N81p2meJXY-Y2LWLorNUUEQ0u44ki-V1MspFIjOLRVKXBC5fg1cCAoaiilT43d9d1B1yQFXTNAkR6O2f20vqrZxfG1K1zFAs8Pj6Ce/file# [following]\n", "--2025-06-03 13:04:19-- https://uc9947b1efa5bc8c7726c1d194d6.dl.dropboxusercontent.com/cd/0/inline/Cq7JOyLE9fpIYpamVYde6Q-VNjrNGMt_EeOcqzCuMzm9D1nUQ8N81p2meJXY-Y2LWLorNUUEQ0u44ki-V1MspFIjOLRVKXBC5fg1cCAoaiilT43d9d1B1yQFXTNAkR6O2f20vqrZxfG1K1zFAs8Pj6Ce/file\n", "Resolving uc9947b1efa5bc8c7726c1d194d6.dl.dropboxusercontent.com (uc9947b1efa5bc8c7726c1d194d6.dl.dropboxusercontent.com)... 162.125.1.15, 2620:100:6016:15::a27d:10f\n", "Connecting to uc9947b1efa5bc8c7726c1d194d6.dl.dropboxusercontent.com (uc9947b1efa5bc8c7726c1d194d6.dl.dropboxusercontent.com)|162.125.1.15|:443... connected.\n", "HTTP request sent, awaiting response... 
302 Found\n", "Location: /cd/0/inline2/Cq4N1ejIi9DSsvK0Pk8YN4NoMw9mWfkEyrqIqQHlXvH6juKwgI_5ofLj-sfvGEV5jKXK7kas-17IbEQWHxnPk4SEyMODvNMzSov4ikzJSNzBTKSSem09xDEq_hjf8qlVmb6OrED4dpTLZzPkTynt3R0sNqEjXVHd6s25XH35E8mEqweZ1SWWfF8YXl896zpFWFolkuZb1FSLZIXrhX4yexdh0QeJLIgON0EgfyUDUYQ_vRSowymlMtdsMuWik8ArkSTlsiTfKHO2oL0kv-fThrOOm1q124Ni8q46mU_gm_JOX70rEHYoA0TA4bpMq7NsLZWo7h72ihNqSkcB8MPl0LBQfMK6cT3VRB6zbUv9nELarlv8roo0dPjvpyIl_wNgl_E/file [following]\n", "--2025-06-03 13:04:19-- https://uc9947b1efa5bc8c7726c1d194d6.dl.dropboxusercontent.com/cd/0/inline2/Cq4N1ejIi9DSsvK0Pk8YN4NoMw9mWfkEyrqIqQHlXvH6juKwgI_5ofLj-sfvGEV5jKXK7kas-17IbEQWHxnPk4SEyMODvNMzSov4ikzJSNzBTKSSem09xDEq_hjf8qlVmb6OrED4dpTLZzPkTynt3R0sNqEjXVHd6s25XH35E8mEqweZ1SWWfF8YXl896zpFWFolkuZb1FSLZIXrhX4yexdh0QeJLIgON0EgfyUDUYQ_vRSowymlMtdsMuWik8ArkSTlsiTfKHO2oL0kv-fThrOOm1q124Ni8q46mU_gm_JOX70rEHYoA0TA4bpMq7NsLZWo7h72ihNqSkcB8MPl0LBQfMK6cT3VRB6zbUv9nELarlv8roo0dPjvpyIl_wNgl_E/file\n", "Reusing existing connection to uc9947b1efa5bc8c7726c1d194d6.dl.dropboxusercontent.com:443.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 965450623 (921M) [application/zip]\n", "Saving to: ‘MISAW.zip’\n", "\n", "MISAW.zip 100%[===================>] 920.72M 79.8MB/s in 13s \n", "\n", "2025-06-03 13:04:33 (70.2 MB/s) - ‘MISAW.zip’ saved [965450623/965450623]\n", "\n" ] } ], "source": [ "# Quote the URL so the shell does not treat '&' as a background operator.\n", "# The unquoted command only ever sent the part up to the rlkey parameter (see the log above),\n", "# so that is the URL kept here.\n", "!wget -O MISAW.zip \"https://www.dropbox.com/scl/fi/psmlokrc5ms958ggqyv3u/MISAW.zip?rlkey=v91dz437npon5zz10olrbcqcd\"" ] }, { "cell_type": "markdown", "metadata": { "id": "79da29e9" }, "source": [ "### 📦 Extracting the Dataset\n", "This cell unzips the downloaded file to access the raw data."
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "tfaLOhLOxqri" }, "outputs": [], "source": [ "!unzip -qq MISAW.zip" ] }, { "cell_type": "markdown", "metadata": { "id": "0b809ff4" }, "source": [ "## 🎞️ Frame Extraction from Videos\n", "This section reads the videos and extracts every N-th frame (controlled by `resample_rate`).\n", "Each video gets its own subdirectory of frames, organized as:\n", "```\n", "MISAW/\n", "  └── train/\n", "      └── Frames/\n", "          └── <video_name>/frame_0001.jpg\n", "```\n", "ffmpeg numbers the output files starting from 1. These images are later paired with annotations." ] }, { "cell_type": "code", "source": [ "!sudo apt update\n", "!sudo apt install -y ffmpeg\n" ], "metadata": { "id": "bbQyrGDm45VG", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "ebda07bc-46a7-4e8e-98f2-9b9d51f9a098" }, "execution_count": 5, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\u001b[33m\r0% [Working]\u001b[0m\r \rGet:1 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease [3,632 B]\n", "Get:2 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64 InRelease [1,581 B]\n", "Hit:3 http://archive.ubuntu.com/ubuntu jammy InRelease\n", "Get:4 http://archive.ubuntu.com/ubuntu jammy-updates InRelease [128 kB]\n", "Get:5 http://security.ubuntu.com/ubuntu jammy-security InRelease [129 kB]\n", "Get:6 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64 Packages [1,724 kB]\n", "Get:7 https://r2u.stat.illinois.edu/ubuntu jammy InRelease [6,555 B]\n", "Get:8 http://archive.ubuntu.com/ubuntu jammy-backports InRelease [127 kB]\n", "Get:9 https://r2u.stat.illinois.edu/ubuntu jammy/main amd64 Packages [2,735 kB]\n", "Get:10 http://archive.ubuntu.com/ubuntu jammy-updates/universe amd64 Packages [1,553 kB]\n", "Hit:11 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease\n", "Hit:12 https://ppa.launchpadcontent.net/graphics-drivers/ppa/ubuntu jammy InRelease\n", "Get:13 http://archive.ubuntu.com/ubuntu 
jammy-updates/restricted amd64 Packages [4,622 kB]\n", "Get:14 https://r2u.stat.illinois.edu/ubuntu jammy/main all Packages [8,992 kB]\n", "Hit:15 https://ppa.launchpadcontent.net/ubuntugis/ppa/ubuntu jammy InRelease\n", "Get:16 http://archive.ubuntu.com/ubuntu jammy-updates/main amd64 Packages [3,290 kB]\n", "Get:17 http://security.ubuntu.com/ubuntu jammy-security/restricted amd64 Packages [4,468 kB]\n", "Get:18 http://security.ubuntu.com/ubuntu jammy-security/main amd64 Packages [2,979 kB]\n", "Get:19 http://security.ubuntu.com/ubuntu jammy-security/universe amd64 Packages [1,245 kB]\n", "Fetched 32.0 MB in 3s (11.6 MB/s)\n", "Reading package lists... Done\n", "Building dependency tree... Done\n", "Reading state information... Done\n", "38 packages can be upgraded. Run 'apt list --upgradable' to see them.\n", "\u001b[1;33mW: \u001b[0mSkipping acquire of configured file 'main/source/Sources' as repository 'https://r2u.stat.illinois.edu/ubuntu jammy InRelease' does not seem to provide it (sources.list entry misspelt?)\u001b[0m\n", "Reading package lists... Done\n", "Building dependency tree... Done\n", "Reading state information... Done\n", "ffmpeg is already the newest version (7:4.4.2-0ubuntu0.22.04.1).\n", "0 upgraded, 0 newly installed, 0 to remove and 38 not upgraded.\n" ] } ] }, { "cell_type": "code", "source": [ "!nvidia-smi\n", "!ffmpeg -encoders | grep nvenc\n" ], "metadata": { "id": "ImHmRKsK5PkH", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "1c926453-cf73-4095-fa66-6b6be0abdf11" }, "execution_count": 6, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Tue Jun 3 13:04:52 2025 \n", "+-----------------------------------------------------------------------------------------+\n", "| NVIDIA-SMI 550.54.15 Driver Version: 550.54.15 CUDA Version: 12.4 |\n", "|-----------------------------------------+------------------------+----------------------+\n", "| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. 
ECC |\n", "| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n", "| | | MIG M. |\n", "|=========================================+========================+======================|\n", "| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n", "| N/A 36C P8 9W / 70W | 0MiB / 15360MiB | 0% Default |\n", "| | | N/A |\n", "+-----------------------------------------+------------------------+----------------------+\n", " \n", "+-----------------------------------------------------------------------------------------+\n", "| Processes: |\n", "| GPU GI CI PID Type Process name GPU Memory |\n", "| ID ID Usage |\n", "|=========================================================================================|\n", "| No running processes found |\n", "+-----------------------------------------------------------------------------------------+\n", "ffmpeg version 4.4.2-0ubuntu0.22.04.1 Copyright (c) 2000-2021 the FFmpeg developers\n", " built with gcc 11 (Ubuntu 11.2.0-19ubuntu1)\n", " configuration: --prefix=/usr --extra-version=0ubuntu0.22.04.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libdav1d --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librabbitmq --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libsrt --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzimg --enable-libzmq --enable-libzvbi --enable-lv2 --enable-omx 
--enable-openal --enable-opencl --enable-opengl --enable-sdl2 --enable-pocketsphinx --enable-librsvg --enable-libmfx --enable-libdc1394 --enable-libdrm --enable-libiec61883 --enable-chromaprint --enable-frei0r --enable-libx264 --enable-shared\n", " libavutil 56. 70.100 / 56. 70.100\n", " libavcodec 58.134.100 / 58.134.100\n", " libavformat 58. 76.100 / 58. 76.100\n", " libavdevice 58. 13.100 / 58. 13.100\n", " libavfilter 7.110.100 / 7.110.100\n", " libswscale 5. 9.100 / 5. 9.100\n", " libswresample 3. 9.100 / 3. 9.100\n", " libpostproc 55. 9.100 / 55. 9.100\n", " V....D h264_nvenc NVIDIA NVENC H.264 encoder (codec h264)\n", " V..... nvenc NVIDIA NVENC H.264 encoder (codec h264)\n", " V..... nvenc_h264 NVIDIA NVENC H.264 encoder (codec h264)\n", " V..... nvenc_hevc NVIDIA NVENC hevc encoder (codec hevc)\n", " V....D hevc_nvenc NVIDIA NVENC hevc encoder (codec hevc)\n" ] } ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HQ4dwnDizRPS", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "c8a6c6fc-7b9f-4953-f2bb-8e2da4709a04" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Extracting frames from 2_5.mp4...\n", "Frames saved to MISAW/train/Frames/2_5\n", "Extracting frames from 6_2.mp4...\n", "Frames saved to MISAW/train/Frames/6_2\n", "Extracting frames from 3_2.mp4...\n", "Frames saved to MISAW/train/Frames/3_2\n", "Extracting frames from 2_1.mp4...\n", "Frames saved to MISAW/train/Frames/2_1\n", "Extracting frames from 3_3.mp4...\n", "Frames saved to MISAW/train/Frames/3_3\n", "Extracting frames from 6_4.mp4...\n", "Frames saved to MISAW/train/Frames/6_4\n", "Extracting frames from 1_4.mp4...\n" ] } ], "source": [ "import os\n", "\n", "video_folder = 'MISAW/train/Video'\n", "frames_folder = 'MISAW/train/Frames'\n", "os.makedirs(frames_folder, exist_ok=True)\n", "\n", "resample_rate = 3 # Adjust as needed\n", "\n", "for video_file in os.listdir(video_folder):\n", " if video_file.endswith(('.mp4', '.avi', 
'.mov')):\n", "        video_path = os.path.join(video_folder, video_file)\n", "        video_name = os.path.splitext(video_file)[0]\n", "        output_folder = os.path.join(frames_folder, video_name)\n", "        os.makedirs(output_folder, exist_ok=True)\n", "\n", "        # Use ffmpeg with GPU support to extract frames\n", "        output_pattern = os.path.join(output_folder, \"frame_%04d.jpg\")\n", "        command = f\"ffmpeg -hwaccel cuda -i '{video_path}' -vf 'select=not(mod(n\\\\,{resample_rate}))' -vsync vfr '{output_pattern}'\"\n", "\n", "        print(f\"Extracting frames from {video_file}...\")\n", "        os.system(command)\n", "        print(f\"Frames saved to {output_folder}\")\n", "\n", "print(\"Done!\")\n" ] }, { "cell_type": "markdown", "metadata": { "id": "17133f20" }, "source": [ "## 🧾 Parsing Annotations and Building Frame-Label Pairs\n", "Here we parse the `.txt` annotation files (one per video), read the per-frame label from the `Step` column (called a \"phase\" throughout the code), and pair each extracted frame with its label.\n", "We also map the textual labels to numeric IDs, which the classifier needs as training targets." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1jMF8NXY0Z8F" }, "outputs": [], "source": [ "# Folder where the annotation .txt files are stored (one per video)\n", "annotation_folder = 'MISAW/train/Procedural decription'\n", "\n", "# Folder where extracted frames from videos are stored (in subfolders per video)\n", "frames_folder = 'MISAW/train/Frames'\n", "\n", "# `resample_rate` is reused from the frame-extraction cell (every 3rd frame there).\n", "# It must match, or your labels won't align with the frames!\n", "\n", "# --- Collect all unique phases first ---\n", "\n", "# We'll store all unique phase labels (like \"Idle\", \"Suturing\", etc.) 
in this set\n", "all_phases = set()\n", "\n", "# Find all annotation files in the folder (e.g., 1_1_annotation.txt, 1_2_annotation.txt, etc.)\n", "annotation_files = sorted(glob.glob(os.path.join(annotation_folder, '*_annotation.txt')))\n", "\n", "# Loop through each annotation file and collect unique phase names\n", "for anno_file in annotation_files:\n", "    df = pd.read_csv(anno_file, sep='\\t') # Load the .txt file into a DataFrame\n", "    all_phases.update(df['Step'].unique()) # Add unique phases to the set\n", "\n", "# Create a dictionary to map each phase string to a unique integer ID\n", "# Useful for training machine learning models\n", "phase_to_id = {name: i for i, name in enumerate(sorted(all_phases))}\n", "\n", "# Create the inverse mapping (ID to phase name) for visualization later\n", "id_to_phase = {i: name for name, i in phase_to_id.items()}\n", "\n", "\n", "# --- Build (frame_path, label) pairs ---\n", "\n", "# This list will hold tuples like: (\"path/to/frame.jpg\", 2) → (frame, label_id)\n", "all_data = []\n", "\n", "# Go through each annotation file (one per video)\n", "for anno_file in annotation_files:\n", "    # Extract the video ID from the filename (e.g., \"1_1_annotation.txt\" → \"1_1\")\n", "    video_id = os.path.basename(anno_file).replace('_annotation.txt', '')\n", "\n", "    # Construct the path to the corresponding frame folder\n", "    frame_dir = os.path.join(frames_folder, video_id)\n", "\n", "    # Read the annotation file into a DataFrame\n", "    df = pd.read_csv(anno_file, sep='\\t')\n", "\n", "    # Get the full list of phases (one per original video frame)\n", "    phases = df['Step'].tolist()\n", "\n", "    # Resample: only keep every N-th label, with N = resample_rate (every 3rd label here)\n", "    sampled_phases = phases[::resample_rate]\n", "\n", "    # Convert phase strings to numeric labels using our earlier mapping\n", "    sampled_ids = [phase_to_id[p] for p in sampled_phases]\n", "\n", "    # Get the list of frame image paths, sorted so they match the order of labels\n", "    
frame_paths = sorted(glob.glob(os.path.join(frame_dir, '*.jpg')))\n", "\n", "    # Sanity check: if the number of frames doesn't match the number of labels, skip this video\n", "    if len(frame_paths) != len(sampled_ids):\n", "        print(f\"⚠️ Mismatch for {video_id}: {len(frame_paths)} frames vs {len(sampled_ids)} labels\")\n", "        continue\n", "\n", "    # Add all (frame_path, label_id) pairs to our global list\n", "    all_data.extend(zip(frame_paths, sampled_ids))\n", "\n", "\n", "# Final print to confirm total number of samples loaded\n", "print(f\"✅ Total samples: {len(all_data)}\")\n" ] }, { "cell_type": "markdown", "metadata": { "id": "pexU3fsGGdeh" }, "source": [ "### Visualize images and phases" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7ijabIfAGcdj" }, "outputs": [], "source": [ "import random\n", "import matplotlib.pyplot as plt\n", "from PIL import Image\n", "from collections import defaultdict\n", "\n", "# CONFIGURATION\n", "\n", "# Number of sample images to display per phase (i.e., class/label)\n", "samples_per_phase = 3\n", "\n", "# Create a dictionary to group image paths by their label (phase ID)\n", "# defaultdict(list) automatically creates an empty list for new keys\n", "label_to_paths = defaultdict(list)\n", "\n", "# all_data is assumed to be a list of (frame_path, label_id) pairs\n", "# Here we organize all frame paths under their respective label IDs\n", "for path, label_id in all_data:\n", "    label_to_paths[label_id].append(path)\n", "\n", "# Determine the number of unique phases (rows in the final plot)\n", "n_rows = len(label_to_paths)\n", "\n", "# Number of columns equals the number of samples we want to show per phase\n", "n_cols = samples_per_phase\n", "\n", "# Create a grid of subplots (n_rows x n_cols)\n", "# figsize sets the overall size of the figure (in inches)\n", "fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows))\n", "\n", "# Add a main title above all the subplots\n", 
"fig.suptitle(\"Random Samples per Phase\", fontsize=18)\n", "\n", "# Get all label IDs in sorted order (for consistent row display)\n", "sorted_labels = sorted(label_to_paths.keys())\n", "\n", "# Loop over each row (i.e., each unique label/phase)\n", "for row, label_id in enumerate(sorted_labels):\n", "    # Convert label ID back to its readable name using id_to_phase mapping\n", "    label_name = id_to_phase[label_id]\n", "\n", "    # Randomly sample a few images from this label\n", "    # Ensure we don't try to sample more images than we actually have\n", "    samples = random.sample(label_to_paths[label_id], min(samples_per_phase, len(label_to_paths[label_id])))\n", "\n", "    # For each column (i.e., each sampled image for this label)\n", "    for col in range(samples_per_phase):\n", "        # Access the correct subplot (row x col)\n", "        # If there's only one row, axes may be 1D\n", "        ax = axes[row, col] if n_rows > 1 else axes[col]\n", "\n", "        # Only plot if we have enough samples for this column\n", "        if col < len(samples):\n", "            # Open the image file using PIL\n", "            img = Image.open(samples[col])\n", "\n", "            # Display the image in the subplot\n", "            ax.imshow(img)\n", "\n", "            # Set the subplot title to the phase name\n", "            ax.set_title(f\"{label_name}\", fontsize=10)\n", "\n", "        # Remove axis ticks and labels for a cleaner look\n", "        ax.axis('off')\n", "\n", "# Adjust layout to prevent overlaps\n", "plt.tight_layout()\n", "\n", "# Adjust top spacing to make room for the main title\n", "plt.subplots_adjust(top=0.95)\n", "\n", "# Display the final grid of images\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "0d376771" }, "source": [ "### Converting to COCO-style Format\n", "This section copies every frame into one flat folder and writes a JSON annotation file loosely modeled on COCO-style datasets. 
It includes:\n", "```\n", "MISAW_coco/\n", "  ├── Frames/\n", "  │   └── video1_frame_0000.jpg\n", "  ├── annotations.json\n", "  └── phase_to_id.json\n", "```\n", "A single folder of frames with one JSON annotation file is easy to load from a PyTorch `Dataset`; similar layouts are also used for multi-label and object detection tasks." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "doH6AaZW_LQa" }, "outputs": [], "source": [ "# --- CONFIGURATION ---\n", "\n", "# Folder where annotation .txt files are stored\n", "annotation_folder = 'MISAW/train/Procedural decription'\n", "\n", "# Folder where frames are stored (e.g., MISAW/train/Frames/1_1/frame_0001.jpg)\n", "frames_folder = 'MISAW/train/Frames'\n", "\n", "# Output root folder for the new COCO-style dataset\n", "output_root = 'MISAW_coco'\n", "\n", "# Knot-related phases to group into a single \"knot\" phase\n", "knot_phases = {'1 knot', '2 knot', '3 knot'}\n", "# Idle labels to group into a single \"idle\" phase\n", "idle_phases = {'Idle', 'Idle Step'}\n", "\n", "# Inside this, all frames will be copied to one flat \"Frames\" folder\n", "frames_out_root = os.path.join(output_root, 'Frames')\n", "os.makedirs(frames_out_root, exist_ok=True) # Create the folder if it doesn't exist\n", "\n", "# --- STEP 1: COLLECT UNIQUE PHASES FROM ALL ANNOTATIONS ---\n", "\n", "# Set to store all phase names (e.g., Idle, Suturing, etc.)\n", "all_phases = set()\n", "\n", "# Loop through each annotation file and collect phase names\n", "for anno_file in glob.glob(f'{annotation_folder}/*_annotation.txt'):\n", "    df = pd.read_csv(anno_file, sep='\\t') # Load tab-separated .txt file as a DataFrame\n", "    df['Step'] = df['Step'].replace({p: \"knot\" for p in knot_phases})\n", "    df['Step'] = df['Step'].replace({p: \"idle\" for p in idle_phases})\n", "    all_phases.update(df['Step'].unique()) # Add unique phases to the set\n", "\n", "# Create dictionaries to map phases to integer IDs and back\n", "phase_to_id = {p: i for i, p in enumerate(sorted(all_phases))} # e.g., \"Suturing\" → 2\n", "id_to_phase = {i: p for p, i in phase_to_id.items()} # e.g., 2 → \"Suturing\"\n", "\n", "# --- STEP 2: BUILD JSON ENTRIES AND COPY FRAMES ---\n", "\n", "entries = [] # This will hold all annotation entries\n", "\n", "# Process each video annotation\n", "for anno_file in glob.glob(f'{annotation_folder}/*_annotation.txt'):\n", "    video_id = os.path.basename(anno_file).replace('_annotation.txt', '') # e.g., \"1_1\"\n", "\n", "    # Read annotation file\n", "    df = pd.read_csv(anno_file, sep='\\t')\n", "    df['Step'] = df['Step'].replace({p: \"knot\" for p in knot_phases})\n", "    df['Step'] = df['Step'].replace({p: \"idle\" for p in idle_phases})\n", "\n", "    # Resample phase labels to match saved frames (every resample_rate-th label)\n", "    phases = df['Step'].tolist()[::resample_rate]\n", "\n", "    # Get list of resampled frame image paths\n", "    frame_dir = os.path.join(frames_folder, video_id)\n", "    frame_paths = sorted(glob.glob(os.path.join(frame_dir, '*.jpg')))\n", "\n", "    # Ensure the number of frames and labels match\n", "    if len(phases) != len(frame_paths):\n", "        print(f\"⚠️ Skipping {video_id}: mismatched {len(phases)} labels vs {len(frame_paths)} frames\")\n", "        continue\n", "\n", "    # Loop over each frame-label pair\n", "    for i, (frame_path, phase) in enumerate(zip(frame_paths, phases)):\n", "        # Create a new frame name (e.g., \"1_1_frame_0003.jpg\")\n", "        new_frame_name = f\"{video_id}_frame_{i:04d}.jpg\"\n", "\n", "        # Full destination path for the copied frame\n", "        new_frame_path = os.path.join(frames_out_root, new_frame_name)\n", "\n", "        # Copy frame to the output folder\n", "        shutil.copy(frame_path, new_frame_path)\n", "\n", "        # Add a dictionary entry for this frame in COCO-style format\n", "        entries.append({\n", "            \"video\": video_id,\n", "            \"frame\": new_frame_name,\n", "            \"path\": new_frame_path,\n", "            \"label\": phase,\n", "            \"label_id\": phase_to_id[phase]\n", "        })\n", "\n", "# --- STEP 3: SAVE TO DISK ---\n", "\n", "# Save all frame annotations to a JSON file\n", "with open(os.path.join(output_root, 'annotations.json'), 
'w') as f:\n", "    json.dump(entries, f, indent=2)\n", "\n", "# Save the phase-to-ID mapping for future use\n", "with open(os.path.join(output_root, 'phase_to_id.json'), 'w') as f:\n", "    json.dump(phase_to_id, f, indent=2)\n", "\n", "# Final confirmation\n", "print(f\"✅ COCO-style structure created in {output_root}/ with {len(entries)} annotated frames.\")" ] }, { "cell_type": "markdown", "metadata": { "id": "XPQP1C8JGkCw" }, "source": [ "## Dataset and Dataloader" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_u-X4zOx0V_K" }, "outputs": [], "source": [ "# from torch.utils.data import Dataset\n", "# from PIL import Image\n", "# import torch\n", "\n", "# class MISAWDataset(Dataset):\n", "#     def __init__(self, annotations, sequence_length=15, transform=None):\n", "#         self.annotations = annotations\n", "#         self.sequence_length = sequence_length\n", "#         self.transform = transform\n", "\n", "#         # Ensure only complete sequences\n", "#         self.num_sequences = len(self.annotations) // self.sequence_length\n", "\n", "#     def __len__(self):\n", "#         return self.num_sequences\n", "\n", "#     def __getitem__(self, idx):\n", "#         # Calculate the start and end index for this sequence\n", "#         sequence_start = idx * self.sequence_length\n", "#         sequence_end = sequence_start + self.sequence_length\n", "\n", "#         # Collect the sequence of images\n", "#         images = []\n", "#         for i in range(sequence_start, sequence_end):\n", "#             item = self.annotations[i]\n", "#             image = Image.open(item['path']).convert('L') # Convert to grayscale\n", "#             if self.transform:\n", "#                 image = self.transform(image)\n", "#             images.append(image)\n", "\n", "#         # Stack the images along the sequence dimension\n", "#         images = torch.stack(images) # Shape: (sequence_length, C, H, W)\n", "\n", "#         # Use the label of the last frame in the sequence as the target\n", "#         label = self.annotations[sequence_end - 1]['label_id']\n", "\n", "#         return images, label\n", "\n", "from torch.utils.data import Dataset\n", "from PIL 
import Image\n", "import torch\n", "\n", "class MISAWDataset(Dataset):\n", "    def __init__(self, annotations, sequence_length=8, stride=1, transform=None):\n", "        self.annotations = annotations\n", "        self.sequence_length = sequence_length\n", "        self.stride = stride\n", "        self.transform = transform\n", "\n", "        # Start indices of all overlapping windows of `sequence_length` frames.\n", "        # Note: the annotations form one concatenated list, so a window can span\n", "        # the boundary between two consecutive videos.\n", "        self.valid_indices = list(range(0, len(self.annotations) - self.sequence_length + 1, self.stride))\n", "\n", "    def __len__(self):\n", "        # Return the number of possible overlapping sequences\n", "        return len(self.valid_indices)\n", "\n", "    def __getitem__(self, idx):\n", "        # Calculate the start index for this sequence\n", "        start_idx = self.valid_indices[idx]\n", "        end_idx = start_idx + self.sequence_length\n", "\n", "        # Collect the sequence of images\n", "        images = []\n", "        for i in range(start_idx, end_idx):\n", "            item = self.annotations[i]\n", "            image = Image.open(item['path']).convert('L') # Convert to grayscale\n", "            if self.transform:\n", "                image = self.transform(image)\n", "            images.append(image)\n", "\n", "        # Stack the images along the sequence dimension\n", "        images = torch.stack(images) # Shape: (sequence_length, C, H, W)\n", "\n", "        # Use the label of the last frame in the sequence as the target\n", "        label = self.annotations[end_idx - 1]['label_id']\n", "\n", "        return images, label" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aryRVmIJ0YxK" }, "outputs": [], "source": [ "import json\n", "import torch\n", "from torch.utils.data import DataLoader\n", "from torchvision import transforms\n", "from collections import defaultdict\n", "\n", "# Load annotations from JSON\n", "with open('MISAW_coco/annotations.json') as f:\n", "    annotations = json.load(f)\n", "\n", "# Group the frame entries by video\n", "video_groups = defaultdict(list)\n", "for image_info in annotations:\n", "    video_name = image_info[\"video\"]\n", "    video_groups[video_name].append(image_info)\n", "\n", 
"# Split by video, deterministically (80% train, 20% val), so no video contributes frames to both sets\n", "all_videos = sorted(video_groups.keys())\n", "split_idx = int(len(all_videos) * 0.8)\n", "train_videos = all_videos[:split_idx]\n", "val_videos = all_videos[split_idx:]\n", "\n", "# Collect dataset annotations\n", "train_annotations = [item for vid in train_videos for item in video_groups[vid]]\n", "val_annotations = [item for vid in val_videos for item in video_groups[vid]]\n", "\n", "print(f\"Train videos: {len(train_videos)}, Val videos: {len(val_videos)}\")\n", "print(f\"Train frames: {len(train_annotations)}, Val frames: {len(val_annotations)}\")" ] }, { "cell_type": "code", "source": [ "from torchvision import transforms\n", "\n", "# Image transformations for training (stronger augmentation)\n", "train_transform = transforms.Compose([\n", "    transforms.Grayscale(), # Ensures image is 1-channel\n", "    transforms.RandomResizedCrop((128, 128), scale=(0.8, 1.0)),\n", "    transforms.RandomHorizontalFlip(),\n", "    transforms.RandomVerticalFlip(),\n", "    transforms.RandomRotation(15),\n", "    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),\n", "    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),\n", "    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),\n", "    transforms.ToTensor(),\n", "    transforms.Normalize((0.5,), (0.5,))\n", "])\n", "\n", "# Image transformations for validation (no augmentation, only resizing and normalization)\n", "val_transform = transforms.Compose([\n", "    transforms.Grayscale(),\n", "    transforms.Resize((128, 128)),\n", "    transforms.ToTensor(),\n", "    transforms.Normalize((0.5,), (0.5,))\n", "])\n", "\n", "# Create the train and validation datasets\n", "SEQUENCE_LENGTH = 15\n", "train_dataset = MISAWDataset(train_annotations, sequence_length=SEQUENCE_LENGTH, transform=train_transform)\n", "val_dataset = MISAWDataset(val_annotations, sequence_length=SEQUENCE_LENGTH, transform=val_transform)\n", "\n", "print(f\"Train sequences: {len(train_dataset)}, Val sequences: 
{len(val_dataset)}\")\n" ], "metadata": { "id": "IMfg-0Yiu9LB" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "from torch.utils.data import DataLoader\n", "\n", "BATCH_SIZE = 16\n", "train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)\n", "val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)\n", "\n", "# Quick sanity check: images should be (batch, sequence_length, 1, 128, 128)\n", "for images, labels in train_loader:\n", "    print(f\"Images shape: {images.shape}, Labels shape: {labels.shape}\")\n", "    break" ], "metadata": { "id": "myb8YNCQvKKW" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "TYCDSoHq0ZfQ" }, "outputs": [], "source": [ "# import matplotlib.pyplot as plt\n", "# import torch\n", "\n", "# # 🔍 Visualize one example sequence of each class from the training dataset\n", "# classes = list(range(8))\n", "# samples_per_class = {c: None for c in classes}\n", "\n", "# # Extract one full sequence for each class from the training set\n", "# for images, label in train_dataset:\n", "#     # Convert the label to an integer if it is a tensor\n", "#     if isinstance(label, torch.Tensor):\n", "#         label = label.item()\n", "\n", "#     # Use the label as a scalar, not a tensor\n", "#     if samples_per_class[label] is None:\n", "#         samples_per_class[label] = images  # Shape: (sequence_length, 1, H, W)\n", "#     if all(v is not None for v in samples_per_class.values()):\n", "#         break\n", "\n", "# # Plot the sequences\n", "# fig, axes = plt.subplots(len(classes), SEQUENCE_LENGTH, figsize=(20, 6))  # one row per class, sequence_length frames each\n", "# for class_id in classes:\n", "#     sequence = samples_per_class[class_id]  # Shape: (sequence_length, 1, H, W)\n", "#     for frame_idx in range(sequence.shape[0]):\n", "#         ax = axes[class_id, frame_idx]\n", "#         # Remove the channel dimension for grayscale images\n", "#         ax.imshow(sequence[frame_idx].squeeze(0).cpu().numpy(), cmap=\"gray\")\n", "#         ax.axis(\"off\")\n", 
"#         if frame_idx == 0:\n", "#             ax.set_title(f\"Class {class_id}\")\n", "\n", "# plt.show()\n" ] }, { "cell_type": "markdown", "source": [ "### ResNet18 + TCN" ], "metadata": { "id": "1ANPRINjyWNF" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RELwc50t1osH" }, "outputs": [], "source": [ "import torch.nn as nn\n", "import torch.optim as optim\n", "import torchmetrics\n", "from pytorch_lightning import LightningModule\n", "from pytorch_tcn import TCN\n", "from torchvision.models import resnet18, ResNet18_Weights\n", "\n", "class LitResNet18TCN(LightningModule):\n", "    def __init__(self, num_classes=5, sequence_length=15, tcn_channels=[256, 256, 256], kernel_size=4, dropout=0.5, lr=1e-5, pooling=\"mean\", freeze_layers=True):\n", "        super().__init__()\n", "        self.save_hyperparameters()\n", "\n", "        # Load pretrained ResNet-18\n", "        weights = ResNet18_Weights.DEFAULT\n", "        self.model = resnet18(weights=weights)\n", "\n", "        # Replace the input layer to accept 1-channel (grayscale) input.\n", "        # The new conv1 is randomly initialized, not pretrained.\n", "        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)\n", "\n", "        # Optionally freeze the backbone. conv1 stays trainable because it was\n", "        # just re-initialized; layer4 and fc are fine-tuned as well.\n", "        if freeze_layers:\n", "            for name, param in self.model.named_parameters():\n", "                if not name.startswith((\"conv1\", \"layer4\", \"fc\")):\n", "                    param.requires_grad = False\n", "\n", "        # Extract feature dimension before FC\n", "        self.feature_dim = self.model.fc.in_features\n", "        self.model.fc = nn.Sequential(\n", "            nn.Dropout(p=dropout),\n", "            nn.Linear(self.feature_dim, self.feature_dim)\n", "        )\n", "\n", "        # TCN module\n", "        self.tcn = TCN(\n", "            num_inputs=self.feature_dim,\n", "            num_channels=tcn_channels,\n", "            kernel_size=kernel_size,\n", "            dropout=dropout,\n", "            input_shape='NCL',\n", "            output_projection=num_classes,\n", "            output_activation=None\n", "        )\n", "\n", "        self.accuracy = torchmetrics.Accuracy(task=\"multiclass\", num_classes=num_classes)\n", "        self.lr = lr\n", "        self.pooling = pooling\n", "\n", "    def forward(self, x):\n", "        B, T, C, H, W = x.shape\n", "        x = x.view(B * T, C, H, W)\n", "\n", "        features = self.model(x)  # (B * T, feature_dim)\n", "        features = features.view(B, T, -1).permute(0, 2, 1)  # (B, feature_dim, T)\n", "\n", "        logits = 
self.tcn(features)  # (B, num_classes, T)\n", "\n", "        if self.pooling == \"mean\":\n", "            logits = logits.mean(dim=-1)  # (B, num_classes)\n", "        elif self.pooling == \"max\":\n", "            logits, _ = logits.max(dim=-1)  # (B, num_classes)\n", "        else:\n", "            raise ValueError(f\"Invalid pooling method: {self.pooling}\")\n", "\n", "        return logits\n", "\n", "    def configure_optimizers(self):\n", "        return optim.Adam(self.parameters(), lr=self.lr, weight_decay=1e-4)\n", "\n", "    def step(self, batch, stage=\"train\"):\n", "        x, y = batch\n", "        y_hat = self(x)\n", "        loss = nn.CrossEntropyLoss()(y_hat, y)\n", "        acc = self.accuracy(y_hat, y)\n", "        self.log(f\"{stage}_loss\", loss, on_step=False, on_epoch=True)\n", "        self.log(f\"{stage}_acc\", acc, on_step=False, on_epoch=True)\n", "        return loss\n", "\n", "    def training_step(self, batch, batch_idx):\n", "        return self.step(batch, stage=\"train\")\n", "\n", "    def validation_step(self, batch, batch_idx):\n", "        return self.step(batch, stage=\"val\")\n", "\n", "    def test_step(self, batch, batch_idx):\n", "        return self.step(batch, stage=\"test\")\n" ] }, { "cell_type": "markdown", "source": [ "### ResNet18 + LSTM" ], "metadata": { "id": "QxNCnvZuybRI" } }, { "cell_type": "code", "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "import torchmetrics\n", "from pytorch_lightning import LightningModule\n", "from torchvision.models import resnet18, ResNet18_Weights\n", "\n", "class LitResNet18LSTM(LightningModule):\n", "    def __init__(self, num_classes=5, sequence_length=15, lstm_hidden_size=512, lstm_layers=2, dropout=0.5, lr=1e-4, pooling=\"mean\", freeze_layers=True):\n", "        super().__init__()\n", "        self.save_hyperparameters()\n", "\n", "        # Load pretrained ResNet-18\n", "        weights = ResNet18_Weights.DEFAULT\n", "        self.model = resnet18(weights=weights)\n", "\n", "        # Modify input layer to accept 1-channel (grayscale) input\n", "        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, 
bias=False)\n", "\n", "        # Optionally freeze the backbone. conv1 stays trainable because it was\n", "        # just re-initialized; layer4 and fc are fine-tuned as well.\n", "        if freeze_layers:\n", "            for name, param in self.model.named_parameters():\n", "                if not name.startswith((\"conv1\", \"layer4\", \"fc\")):\n", "                    param.requires_grad = False\n", "\n", "        # Extract feature dimension before FC\n", "        self.feature_dim = self.model.fc.in_features\n", "        self.model.fc = nn.Sequential(\n", "            nn.Dropout(p=dropout),\n", "            nn.Linear(self.feature_dim, self.feature_dim)\n", "        )\n", "\n", "        # LSTM module\n", "        self.lstm = nn.LSTM(\n", "            input_size=self.feature_dim,\n", "            hidden_size=lstm_hidden_size,\n", "            num_layers=lstm_layers,\n", "            batch_first=True,\n", "            dropout=dropout,\n", "            bidirectional=False\n", "        )\n", "\n", "        # Final classification layer\n", "        self.fc = nn.Sequential(\n", "            nn.Dropout(p=dropout),\n", "            nn.Linear(lstm_hidden_size, num_classes)\n", "        )\n", "\n", "        self.accuracy = torchmetrics.Accuracy(task=\"multiclass\", num_classes=num_classes)\n", "        self.lr = lr\n", "        self.pooling = pooling\n", "\n", "    def forward(self, x):\n", "        B, T, C, H, W = x.shape\n", "        x = x.view(B * T, C, H, W)\n", "\n", "        features = self.model(x)  # (B * T, feature_dim)\n", "        features = features.view(B, T, -1)  # (B, T, feature_dim)\n", "\n", "        lstm_out, _ = self.lstm(features)  # (B, T, lstm_hidden_size)\n", "\n", "        # Pool the LSTM outputs over the time dimension\n", "        if self.pooling == \"mean\":\n", "            logits = lstm_out.mean(dim=1)\n", "        elif self.pooling == \"max\":\n", "            logits, _ = lstm_out.max(dim=1)\n", "        else:\n", "            raise ValueError(f\"Invalid pooling method: {self.pooling}\")\n", "\n", "        logits = self.fc(logits)  # (B, num_classes)\n", "        return logits\n", "\n", "    def configure_optimizers(self):\n", "        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=1e-4)\n", "        scheduler = optim.lr_scheduler.ReduceLROnPlateau(\n", "            optimizer, mode='min', factor=0.5, patience=3, verbose=True\n", "        )\n", "        return {\n", "            \"optimizer\": optimizer,\n", "            \"lr_scheduler\": {\n", "                \"scheduler\": scheduler,\n", "                \"monitor\": \"val_loss\",\n", "                \"interval\": \"epoch\",\n", "                \"frequency\": 1\n", "            }\n", "        }\n", 
"\n", "    def step(self, batch, stage=\"train\"):\n", "        x, y = batch\n", "        y_hat = self(x)\n", "        loss = nn.CrossEntropyLoss()(y_hat, y)\n", "        acc = self.accuracy(y_hat, y)\n", "        self.log(f\"{stage}_loss\", loss, on_step=False, on_epoch=True)\n", "        self.log(f\"{stage}_acc\", acc, on_step=False, on_epoch=True)\n", "        return loss\n", "\n", "    def training_step(self, batch, batch_idx):\n", "        return self.step(batch, stage=\"train\")\n", "\n", "    def validation_step(self, batch, batch_idx):\n", "        return self.step(batch, stage=\"val\")\n", "\n", "    def test_step(self, batch, batch_idx):\n", "        return self.step(batch, stage=\"test\")\n" ], "metadata": { "id": "kxvsqhENk1K3" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "### Training" ], "metadata": { "id": "oaBeNEm-ygou" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SnSzZgANNr4R" }, "outputs": [], "source": [ "# Create the tb_logs directory\n", "!mkdir -p tb_logs/\n", "\n", "%load_ext tensorboard\n", "%tensorboard --logdir tb_logs/" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nwC-GWvZ10pH" }, "outputs": [], "source": [ "from pytorch_lightning import Trainer\n", "from pytorch_lightning.loggers import TensorBoardLogger\n", "from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint\n", "from pytorch_tcn import TCN\n", "\n", "# 🚂 Training Configuration\n", "MODEL_NAME = \"LitResNet18TCN\"  # Change to LitResNet18LSTM if needed\n", "EPOCHS = 50\n", "CHECKPOINT_DIR = \"checkpoints/\"\n", "LOG_DIR = \"tb_logs\"\n", "\n", "# Select the model\n", "model = LitResNet18TCN(sequence_length=15)  # or LitResNet18LSTM(sequence_length=15)\n", "\n", "# Initialize TensorBoard Logger\n", "logger = TensorBoardLogger(LOG_DIR, name=MODEL_NAME)\n", "\n", "# Initialize EarlyStopping and ModelCheckpoint callbacks\n", "early_stop_callback = EarlyStopping(\n", "    monitor=\"val_acc\",\n", "    patience=10,\n", "    mode=\"max\",\n", "    verbose=True\n", ")\n", 
"\n", "checkpoint_callback = ModelCheckpoint(\n", "    monitor=\"val_acc\",\n", "    mode=\"max\",\n", "    save_top_k=1,\n", "    verbose=True,\n", "    dirpath=CHECKPOINT_DIR,\n", "    filename=MODEL_NAME + \"-{epoch:02d}-{val_acc:.4f}\"\n", ")\n", "\n", "# Initialize the Trainer\n", "trainer = Trainer(\n", "    max_epochs=EPOCHS,\n", "    accelerator=\"auto\",\n", "    devices=\"auto\",\n", "    log_every_n_steps=1,  # ✅ log at every step\n", "    logger=logger,  # Log to TensorBoard\n", "    callbacks=[early_stop_callback, checkpoint_callback]\n", ")\n", "\n", "# 🚂 Start Training\n", "trainer.fit(model, train_loader, val_loader)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Z7HZToP7Mzx7" }, "outputs": [], "source": [ "from pytorch_lightning import Trainer\n", "from pytorch_lightning.loggers import TensorBoardLogger\n", "from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint\n", "\n", "# 🚂 Training\n", "model = LitResNet18LSTM(\n", "    num_classes=6,\n", "    sequence_length=15,\n", "    lstm_hidden_size=512,\n", "    lstm_layers=2,\n", "    dropout=0.2,\n", "    lr=1e-3,\n", "    pooling=\"mean\"  # or \"max\"\n", ")\n", "\n", "# Initialize TensorBoardLogger\n", "logger = TensorBoardLogger(\"tb_logs\", name=\"LitResNet18LSTM\")\n", "\n", "# Initialize EarlyStopping and ModelCheckpoint callbacks\n", "early_stop_callback = EarlyStopping(\n", "    monitor=\"val_acc\",\n", "    patience=20,\n", "    mode=\"max\",\n", "    verbose=True\n", ")\n", "\n", "checkpoint_callback = ModelCheckpoint(\n", "    monitor=\"val_acc\",\n", "    mode=\"max\",\n", "    save_top_k=1,\n", "    verbose=True,\n", "    dirpath=\"checkpoints/\",\n", "    filename=\"LitResNet18LSTM-{epoch:02d}-{val_loss:.4f}\"\n", ")\n", "\n", "# Initialize Trainer\n", "trainer = Trainer(\n", "    max_epochs=EPOCHS,\n", "    accelerator=\"gpu\",\n", "    devices=1,\n", "    precision=16,  # Use mixed precision for T4\n", "    gradient_clip_val=0.5,\n", "    log_every_n_steps=1,\n", "    logger=logger,\n", "    callbacks=[early_stop_callback, 
checkpoint_callback]\n", ")\n", "\n", "# Start Training\n", "trainer.fit(model, train_loader, val_loader)\n" ] }, { "cell_type": "code", "source": [], "metadata": { "id": "9inLpOEcTK7P" }, "execution_count": null, "outputs": [] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "provenance": [], "include_colab_link": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }