From f50cefa235e20e5b2990124a599787c0f74bfc4a Mon Sep 17 00:00:00 2001 From: Iain Date: Mon, 22 Dec 2025 15:20:31 -0500 Subject: [PATCH] Convolution interpolation explanation addition #2 --- .../Notebook06_Convolutions_PyTorch.ipynb | 176 ++++++++++++------ 1 file changed, 118 insertions(+), 58 deletions(-) diff --git a/colab_notebooks/Notebook06_Convolutions_PyTorch.ipynb b/colab_notebooks/Notebook06_Convolutions_PyTorch.ipynb index 82069d7..906eaf9 100644 --- a/colab_notebooks/Notebook06_Convolutions_PyTorch.ipynb +++ b/colab_notebooks/Notebook06_Convolutions_PyTorch.ipynb @@ -7,7 +7,7 @@ "colab_type": "text" }, "source": [ - "\"Open" + "\"Open" ] }, { @@ -51,9 +51,9 @@ "base_uri": "https://localhost:8080/" }, "id": "x2m10lWrHr_R", - "outputId": "f74daa9b-3d41-43b1-bbd0-4aad76043487" + "outputId": "a0fcf4ba-0837-49a9-b7da-4c9870b31550" }, - "execution_count": 3, + "execution_count": 1, "outputs": [ { "output_type": "stream", @@ -65,7 +65,7 @@ "remote: Counting objects: 100% (212/212), done.\u001b[K\n", "remote: Compressing objects: 100% (119/119), done.\u001b[K\n", "remote: Total 368 (delta 140), reused 134 (delta 93), pack-reused 156 (from 1)\u001b[K\n", - "Receiving objects: 100% (368/368), 114.60 MiB | 26.65 MiB/s, done.\n", + "Receiving objects: 100% (368/368), 114.60 MiB | 15.55 MiB/s, done.\n", "Resolving deltas: 100% (204/204), done.\n", "Updating files: 100% (120/120), done.\n", "done\n" @@ -75,24 +75,24 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "PTq52qdZGnZu", - "outputId": "815ab8ed-641a-4dec-afd9-9ca04cdaa9c9" + "outputId": "b04f4060-fe1a-4143-e75a-07e1e0ceef40" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ - "" + "" ] }, "metadata": {}, - "execution_count": 5 + "execution_count": 2 } ], "source": [ @@ -146,14 +146,14 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 144 }, "id": "LCBzb32YGnZv", - "outputId": "f6b663d5-4d08-4142-cb08-454dcbfe5b67" + "outputId": "e9ecf48f-5d2f-4c83-bade-53815cbd9ae8" }, "outputs": [ { @@ -651,11 +651,11 @@ " radar_field: 4)\n", "Dimensions without coordinates: grid_row, grid_column, radar_height, radar_field\n", "Data variables:\n", - " radar_image_matrix (grid_row, grid_column, radar_height, radar_field) float32 197kB ..." + " radar_image_matrix (grid_row, grid_column, radar_height, radar_field) float32 197kB ..." 
] }, "metadata": {}, - "execution_count": 7 + "execution_count": 3 } ], "source": [ @@ -681,14 +681,14 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "GMjoJBblGnZw", - "outputId": "5f62052e-99b4-4cb0-b6db-f33a08a48276" + "outputId": "13120921-9e2b-4874-9c01-6cbd3793baaf" }, "outputs": [ { @@ -699,7 +699,7 @@ ] }, "metadata": {}, - "execution_count": 9 + "execution_count": 4 }, { "output_type": "display_data", @@ -901,21 +901,96 @@ }, { "cell_type": "code", - "execution_count": 12, + "source": [ + "# %%\n", + "# Convert xarray data to numpy, then to a torch tensor\n", + "input_data = ds_sample.radar_image_matrix.isel(radar_height=slice(0,1),radar_field=0).values\n", + "\n", + "# Step 1: Convert to Tensor\n", + "# The shape of input_data is likely (32, 32, 1) or (32, 32) depending on how xarray slices it.\n", + "# We ensure it starts as a float tensor.\n", + "print(f\"{input_data.shape}=\") # 32, 32, 1; (H, W, C)\n", + "input_tensor = torch.from_numpy(input_data).float()\n", + "\n", + "# Step 2: Ensure Correct Shape (Batch, Channel, Height, Width)\n", + "# If the input is (Height, Width, Channel), we need to permute it to (Channel, Height, Width)\n", + "if input_tensor.ndim == 3 and input_tensor.shape[-1] == 1:\n", + " input_tensor = input_tensor.permute(2, 0, 1) # (H, W, C) -> (C, H, W)\n", + "elif input_tensor.ndim == 2:\n", + " input_tensor = input_tensor.unsqueeze(0) # Add Channel dim if missing: (H, W) -> (C, H, W)\n", + "\n", + "# Add Batch Dimension: (C, H, W) -> (Batch, C, H, W)\n", + "input_tensor = input_tensor.unsqueeze(0)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "zo_JAULelGBC", + "outputId": "313fa14f-3be3-47e5-cbc5-48e964a80505" + }, + "execution_count": 5, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "(32, 32, 1)=\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Now our data (`input_tensor`) is in the correct shape: Batch, Channel, Height, Width.\n", + "\n", + "One last step before applying the convolution - we will resize `input_tensor`. This is simply for the sake of demonstrating the sharpening convolution. You generally would not do this in a ML model setting. The interpolation will smooth the original image a little, and the final sharpened convolution with look a little nicer. (Feel free to re-run this interpolation later with size=(32, 32) to see the convolution result below without the interpolation.)\n", + "\n", + "Note that we use the pytorch functional interface to interpolate the pytorch tensor object `input_tensor`. Again, we aren't performing any machine learning here. We're simply demonstrating how convolutions work." + ], + "metadata": { + "id": "cc5ABhj5nQ2e" + } + }, + { + "cell_type": "code", + "source": [ + "# interpolate from (1, 1, 32, 32), (1, 1, 36, 36)\n", + "more_points = F.interpolate(input_tensor, size=(36, 36), mode='bilinear', align_corners=False)\n" + ], + "metadata": { + "id": "rnhhUlxnnLBT" + }, + "execution_count": 15, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Finally we can define, configure, and apply our pytorch convolution operation, then plot the results." 
+ ], + "metadata": { + "id": "nLzGMq4QlI4X" + } + }, + { + "cell_type": "code", + "execution_count": 16, "metadata": { "colab": { "base_uri": "https://localhost:8080/", - "height": 430 + "height": 450 }, "id": "7eM4lMhIGnZx", - "outputId": "d89b5397-f4fb-4991-c473-cba1b02123c7" + "outputId": "eb1fb700-7fc6-421b-a490-acb5a1355fed" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ - "(32, 32, 1)=\n" + "more_points.shape=torch.Size([1, 1, 36, 36]), convolution_result.shape=torch.Size([1, 1, 34, 34])\n" ] }, { @@ -935,44 +1010,21 @@ } ], "source": [ - "# %%\n", - "# Convert xarray data to numpy, then to a torch tensor\n", - "input_data = ds_sample.radar_image_matrix.isel(radar_height=slice(0,1),radar_field=0).values\n", - "\n", - "# Step 1: Convert to Tensor\n", - "# The shape of input_data is likely (32, 32, 1) or (32, 32) depending on how xarray slices it.\n", - "# We ensure it starts as a float tensor.\n", - "print(f\"{input_data.shape}=\") # 32, 32, 1; (H, W, C)\n", - "input_tensor = torch.from_numpy(input_data).float()\n", - "\n", - "# Step 2: Ensure Correct Shape (Batch, Channel, Height, Width)\n", - "# If the input is (Height, Width, Channel), we need to permute it to (Channel, Height, Width)\n", - "if input_tensor.ndim == 3 and input_tensor.shape[-1] == 1:\n", - " input_tensor = input_tensor.permute(2, 0, 1) # (H, W, C) -> (C, H, W)\n", - "elif input_tensor.ndim == 2:\n", - " input_tensor = input_tensor.unsqueeze(0) # Add Channel dim if missing: (H, W) -> (C, H, W)\n", - "\n", - "# Add Batch Dimension: (C, H, W) -> (Batch, C, H, W)\n", - "input_tensor = input_tensor.unsqueeze(0)\n", - "\n", - "# Step 3: Resize using interpolate\n", - "# Now input_tensor is (1, 1, 32, 32), which matches the 2D spatial size expected by interpolate\n", - "more_points = F.interpolate(input_tensor, size=(36, 36), mode='bilinear', align_corners=False)\n", "\n", "# Define the sharpen filter\n", "kernel_weights = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])\n", "\n", - "# Define conv with specific weights.\n", + "# Define the conv layer: 1 input channel, 1 output channel, 3x3 kernel, no bias\n", "conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, bias=False)\n", "\n", - "# Set the weights manually\n", + "# Set the conv weights manually\n", "with torch.no_grad():\n", " # We reshape our 3x3 kernel to (1, 1, 3, 3)\n", " conv.weight.data = torch.from_numpy(kernel_weights).float().view(1, 1, 3, 3)\n", "\n", "# Run the data through\n", "convolution_result = conv(more_points)\n", - "\n", + "print(f\"{more_points.shape=}, {convolution_result.shape=}\")\n", "\n", "# Plotting\n", "fig,(ax1,ax2) = plt.subplots(1,2,figsize=(10,5))\n", @@ -1008,14 +1060,14 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 420 }, "id": "oiUu9nDyGnZy", - "outputId": "f22aa7f2-97c9-4b3a-cd16-a40aec3c5ee6" + "outputId": "8999b7db-e88a-4b93-ccd9-c1d5e1bb3fd0" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": {}, - "execution_count": 15 + "execution_count": 7 }, { "output_type": "display_data", @@ -1123,14 +1175,14 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 420 }, "id": "qnfWwa9MGnZz", - "outputId": "6a9b113a-ee16-4a5d-ce93-6c65fa3b696e" + "outputId": "913a204a-b0f4-447e-adb1-b4004e133eb2" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": {}, - "execution_count": 17 + "execution_count": 8 }, { "output_type": "display_data", @@ -1247,14 +1299,14 @@ }, {
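+ "cell_type": "markdown", + "source": [ + "An aside: in these demonstration cells we build an `nn.Conv2d` layer and then overwrite its weights inside `torch.no_grad()`. Since nothing is being trained, the same operation can also be written as a single call to the functional interface. A minimal sketch, reusing `more_points` and `F` from above (an illustration only, not a step the rest of the notebook depends on):\n", + "\n", + "```python\n", + "# Sketch: the sharpen convolution as one functional call, with no layer object\n", + "kernel = torch.tensor([[0., -1., 0.], [-1., 5., -1.], [0., -1., 0.]]).view(1, 1, 3, 3)\n", + "sharpened = F.conv2d(more_points, kernel) # no bias, no padding -> shape (1, 1, 34, 34)\n", + "```" + ], + "metadata": {} + }, + {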
"cell_type": "code", - "execution_count": 18, + "execution_count": 9, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 430 }, "id": "qLPaA52tGnZz", - "outputId": "0882cead-a86f-4fe8-b4e1-7b16a2cc7818" + "outputId": "65bfa86b-8896-4d38-e8de-76ad52b93a3c" }, "outputs": [ { @@ -1265,7 +1317,7 @@ ] }, "metadata": {}, - "execution_count": 18 + "execution_count": 9 }, { "output_type": "display_data", @@ -1329,16 +1381,23 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 17, "metadata": { "colab": { "base_uri": "https://localhost:8080/", - "height": 430 + "height": 449 }, "id": "lPNUoEYiGnZ0", - "outputId": "ccd5c1c8-3f2d-45c3-b6cc-c0ee0bd05dbe" + "outputId": "98067555-f1e0-493b-84c7-01494f3db8d9" }, "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "more_points.shape=torch.Size([1, 1, 36, 36]), res.shape=torch.Size([1, 1, 36, 36])\n" + ] + }, { "output_type": "execute_result", "data": { @@ -1347,7 +1406,7 @@ ] }, "metadata": {}, - "execution_count": 20 + "execution_count": 17 }, { "output_type": "display_data", @@ -1393,6 +1452,7 @@ "#run the data through\n", "res = conv(more_points)\n", "\n", + "print(f\"{more_points.shape=}, {res.shape=}\")\n", "\n", "fig,(ax1,ax2) = plt.subplots(1,2,figsize=(10,5))\n", "ax1.imshow(more_points.squeeze(), vmin=0, vmax=60, cmap='Spectral_r')\n",