176 changes: 118 additions & 58 deletions colab_notebooks/Notebook06_Convolutions_PyTorch.ipynb
@@ -7,7 +7,7 @@
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/ssec/WAF_ML_Tutorial_Part2/blob/main/colab_notebooks/Notebook06_Convolutions_PyTorch.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
"<a href=\"https://colab.research.google.com/github/ssec/WAF_ML_Tutorial_Part2/blob/2-confusion-about-interpolating/colab_notebooks/Notebook06_Convolutions_PyTorch.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
@@ -51,9 +51,9 @@
"base_uri": "https://localhost:8080/"
},
"id": "x2m10lWrHr_R",
"outputId": "f74daa9b-3d41-43b1-bbd0-4aad76043487"
"outputId": "a0fcf4ba-0837-49a9-b7da-4c9870b31550"
},
"execution_count": 3,
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
@@ -65,7 +65,7 @@
"remote: Counting objects: 100% (212/212), done.\u001b[K\n",
"remote: Compressing objects: 100% (119/119), done.\u001b[K\n",
"remote: Total 368 (delta 140), reused 134 (delta 93), pack-reused 156 (from 1)\u001b[K\n",
"Receiving objects: 100% (368/368), 114.60 MiB | 26.65 MiB/s, done.\n",
"Receiving objects: 100% (368/368), 114.60 MiB | 15.55 MiB/s, done.\n",
"Resolving deltas: 100% (204/204), done.\n",
"Updating files: 100% (120/120), done.\n",
"done\n"
@@ -75,24 +75,24 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "PTq52qdZGnZu",
"outputId": "815ab8ed-641a-4dec-afd9-9ca04cdaa9c9"
"outputId": "b04f4060-fe1a-4143-e75a-07e1e0ceef40"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<torch._C.Generator at 0x7fce206d96d0>"
"<torch._C.Generator at 0x7fbf50775a10>"
]
},
"metadata": {},
"execution_count": 5
"execution_count": 2
}
],
"source": [
@@ -146,14 +146,14 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 144
},
"id": "LCBzb32YGnZv",
"outputId": "f6b663d5-4d08-4142-cb08-454dcbfe5b67"
"outputId": "e9ecf48f-5d2f-4c83-bade-53815cbd9ae8"
},
"outputs": [
{
@@ -651,11 +651,11 @@
" radar_field: 4)\n",
"Dimensions without coordinates: grid_row, grid_column, radar_height, radar_field\n",
"Data variables:\n",
" radar_image_matrix (grid_row, grid_column, radar_height, radar_field) float32 197kB ...</pre><div class='xr-wrap' style='display:none'><div class='xr-header'><div class='xr-obj-type'>xarray.Dataset</div></div><ul class='xr-sections'><li class='xr-section-item'><input id='section-a7755d39-2484-4e97-ae1d-454290a50143' class='xr-section-summary-in' type='checkbox' disabled ><label for='section-a7755d39-2484-4e97-ae1d-454290a50143' class='xr-section-summary' title='Expand/collapse section'>Dimensions:</label><div class='xr-section-inline-details'><ul class='xr-dim-list'><li><span>grid_row</span>: 32</li><li><span>grid_column</span>: 32</li><li><span>radar_height</span>: 12</li><li><span>radar_field</span>: 4</li></ul></div><div class='xr-section-details'></div></li><li class='xr-section-item'><input id='section-9a5d17a4-de47-407b-a07c-85494b34bafb' class='xr-section-summary-in' type='checkbox' checked><label for='section-9a5d17a4-de47-407b-a07c-85494b34bafb' class='xr-section-summary' >Data variables: <span>(1)</span></label><div class='xr-section-inline-details'></div><div class='xr-section-details'><ul class='xr-var-list'><li class='xr-var-item'><div class='xr-var-name'><span>radar_image_matrix</span></div><div class='xr-var-dims'>(grid_row, grid_column, radar_height, radar_field)</div><div class='xr-var-dtype'>float32</div><div class='xr-var-preview xr-preview'>...</div><input id='attrs-59812b5c-822f-469f-a22f-77ed4ca76955' class='xr-var-attrs-in' type='checkbox' disabled><label for='attrs-59812b5c-822f-469f-a22f-77ed4ca76955' title='Show/Hide attributes'><svg class='icon xr-icon-file-text2'><use xlink:href='#icon-file-text2'></use></svg></label><input id='data-f28124eb-5b12-4d34-ba0a-4485b53f44d0' class='xr-var-data-in' type='checkbox'><label for='data-f28124eb-5b12-4d34-ba0a-4485b53f44d0' title='Show/Hide data repr'><svg class='icon xr-icon-database'><use xlink:href='#icon-database'></use></svg></label><div class='xr-var-attrs'><dl class='xr-attrs'></dl></div><div class='xr-var-data'><pre>[49152 values with dtype=float32]</pre></div></li></ul></div></li></ul></div></div>"
" radar_image_matrix (grid_row, grid_column, radar_height, radar_field) float32 197kB ...</pre><div class='xr-wrap' style='display:none'><div class='xr-header'><div class='xr-obj-type'>xarray.Dataset</div></div><ul class='xr-sections'><li class='xr-section-item'><input id='section-4183d7c7-e4db-4160-9040-d00243e6ea7f' class='xr-section-summary-in' type='checkbox' disabled ><label for='section-4183d7c7-e4db-4160-9040-d00243e6ea7f' class='xr-section-summary' title='Expand/collapse section'>Dimensions:</label><div class='xr-section-inline-details'><ul class='xr-dim-list'><li><span>grid_row</span>: 32</li><li><span>grid_column</span>: 32</li><li><span>radar_height</span>: 12</li><li><span>radar_field</span>: 4</li></ul></div><div class='xr-section-details'></div></li><li class='xr-section-item'><input id='section-c8eb729d-53f9-4cb6-9fe8-715d4648f54f' class='xr-section-summary-in' type='checkbox' checked><label for='section-c8eb729d-53f9-4cb6-9fe8-715d4648f54f' class='xr-section-summary' >Data variables: <span>(1)</span></label><div class='xr-section-inline-details'></div><div class='xr-section-details'><ul class='xr-var-list'><li class='xr-var-item'><div class='xr-var-name'><span>radar_image_matrix</span></div><div class='xr-var-dims'>(grid_row, grid_column, radar_height, radar_field)</div><div class='xr-var-dtype'>float32</div><div class='xr-var-preview xr-preview'>...</div><input id='attrs-fc79c9eb-fac4-45ba-8a24-e0cf9f9b1289' class='xr-var-attrs-in' type='checkbox' disabled><label for='attrs-fc79c9eb-fac4-45ba-8a24-e0cf9f9b1289' title='Show/Hide attributes'><svg class='icon xr-icon-file-text2'><use xlink:href='#icon-file-text2'></use></svg></label><input id='data-11a08a46-3ed6-4e67-a766-665b61804627' class='xr-var-data-in' type='checkbox'><label for='data-11a08a46-3ed6-4e67-a766-665b61804627' title='Show/Hide data repr'><svg class='icon xr-icon-database'><use xlink:href='#icon-database'></use></svg></label><div class='xr-var-attrs'><dl class='xr-attrs'></dl></div><div class='xr-var-data'><pre>[49152 values with dtype=float32]</pre></div></li></ul></div></li></ul></div></div>"
]
},
"metadata": {},
"execution_count": 7
"execution_count": 3
}
],
"source": [
@@ -681,14 +681,14 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "GMjoJBblGnZw",
"outputId": "5f62052e-99b4-4cb0-b6db-f33a08a48276"
"outputId": "13120921-9e2b-4874-9c01-6cbd3793baaf"
},
"outputs": [
{
@@ -699,7 +699,7 @@
]
},
"metadata": {},
"execution_count": 9
"execution_count": 4
},
{
"output_type": "display_data",
@@ -901,21 +901,96 @@
},
{
"cell_type": "code",
"execution_count": 12,
"source": [
"# %%\n",
"# Convert xarray data to numpy, then to a torch tensor\n",
"input_data = ds_sample.radar_image_matrix.isel(radar_height=slice(0,1),radar_field=0).values\n",
"\n",
"# Step 1: Convert to Tensor\n",
"# The shape of input_data is likely (32, 32, 1) or (32, 32) depending on how xarray slices it.\n",
"# We ensure it starts as a float tensor.\n",
"print(f\"{input_data.shape}=\") # 32, 32, 1; (H, W, C)\n",
"input_tensor = torch.from_numpy(input_data).float()\n",
"\n",
"# Step 2: Ensure Correct Shape (Batch, Channel, Height, Width)\n",
"# If the input is (Height, Width, Channel), we need to permute it to (Channel, Height, Width)\n",
"if input_tensor.ndim == 3 and input_tensor.shape[-1] == 1:\n",
" input_tensor = input_tensor.permute(2, 0, 1) # (H, W, C) -> (C, H, W)\n",
"elif input_tensor.ndim == 2:\n",
" input_tensor = input_tensor.unsqueeze(0) # Add Channel dim if missing: (H, W) -> (C, H, W)\n",
"\n",
"# Add Batch Dimension: (C, H, W) -> (Batch, C, H, W)\n",
"input_tensor = input_tensor.unsqueeze(0)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "zo_JAULelGBC",
"outputId": "313fa14f-3be3-47e5-cbc5-48e964a80505"
},
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"(32, 32, 1)=\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"Now our data (`input_tensor`) is in the correct shape: Batch, Channel, Height, Width.\n",
"\n",
"One last step before applying the convolution - we will resize `input_tensor`. This is simply for the sake of demonstrating the sharpening convolution. You generally would not do this in a ML model setting. The interpolation will smooth the original image a little, and the final sharpened convolution with look a little nicer. (Feel free to re-run this interpolation later with size=(32, 32) to see the convolution result below without the interpolation.)\n",
"\n",
"Note that we use the pytorch functional interface to interpolate the pytorch tensor object `input_tensor`. Again, we aren't performing any machine learning here. We're simply demonstrating how convolutions work."
],
"metadata": {
"id": "cc5ABhj5nQ2e"
}
},
{
"cell_type": "code",
"source": [
"# interpolate from (1, 1, 32, 32), (1, 1, 36, 36)\n",
"more_points = F.interpolate(input_tensor, size=(36, 36), mode='bilinear', align_corners=False)\n"
],
"metadata": {
"id": "rnhhUlxnnLBT"
},
"execution_count": 15,
"outputs": []
},
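{
"cell_type": "markdown",
"source": [
"*(Aside added for illustration; this cell is not part of the original notebook.)* The minimal sketch below uses a small made-up tensor, `toy`, to show how `F.interpolate` changes the spatial size and how an unpadded 3x3 convolution then trims one pixel from each border. That trimming is why the sharpened result further down is 34x34 rather than 36x36."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Illustrative sketch only; `toy` is a random stand-in, not the radar data.\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"toy = torch.rand(1, 1, 4, 4)  # (batch, channel, height, width)\n",
"toy_up = F.interpolate(toy, size=(6, 6), mode='bilinear', align_corners=False)\n",
"toy_conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, bias=False)\n",
"print(toy.shape, toy_up.shape, toy_conv(toy_up).shape)\n",
"# Expected: (1, 1, 4, 4) -> (1, 1, 6, 6) -> (1, 1, 4, 4)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},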
{
"cell_type": "markdown",
"source": [
"Finally we can define, configure, and apply our pytorch convolution operation, then plot the results."
],
"metadata": {
"id": "nLzGMq4QlI4X"
}
},
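{
"cell_type": "markdown",
"source": [
"*(Aside added for illustration; this cell is not part of the original notebook.)* Before the full example, a tiny standalone check of what `nn.Conv2d` does once its weights are set by hand: on a single 3x3 patch it should reproduce the plain weighted sum of patch and kernel (PyTorch's `Conv2d` computes a cross-correlation, so the kernel is not flipped)."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Illustrative sanity check only; `patch` is a made-up 3x3 input.\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"sharpen = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]], dtype=np.float32)\n",
"patch = torch.arange(9, dtype=torch.float32).view(1, 1, 3, 3)\n",
"\n",
"check_conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, bias=False)\n",
"with torch.no_grad():\n",
" check_conv.weight.data = torch.from_numpy(sharpen).view(1, 1, 3, 3)\n",
"\n",
"by_conv = check_conv(patch).item()\n",
"by_hand = float((torch.from_numpy(sharpen) * patch.squeeze()).sum())\n",
"print(by_conv, by_hand)  # both should be 4.0 for this patch and kernel"
],
"metadata": {},
"execution_count": null,
"outputs": []
},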
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 430
"height": 450
},
"id": "7eM4lMhIGnZx",
"outputId": "d89b5397-f4fb-4991-c473-cba1b02123c7"
"outputId": "eb1fb700-7fc6-421b-a490-acb5a1355fed"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"(32, 32, 1)=\n"
"more_points.shape=torch.Size([1, 1, 36, 36]), convolution_result.shape=torch.Size([1, 1, 34, 34])\n"
]
},
{
@@ -935,44 +1010,21 @@
}
],
"source": [
"# %%\n",
"# Convert xarray data to numpy, then to a torch tensor\n",
"input_data = ds_sample.radar_image_matrix.isel(radar_height=slice(0,1),radar_field=0).values\n",
"\n",
"# Step 1: Convert to Tensor\n",
"# The shape of input_data is likely (32, 32, 1) or (32, 32) depending on how xarray slices it.\n",
"# We ensure it starts as a float tensor.\n",
"print(f\"{input_data.shape}=\") # 32, 32, 1; (H, W, C)\n",
"input_tensor = torch.from_numpy(input_data).float()\n",
"\n",
"# Step 2: Ensure Correct Shape (Batch, Channel, Height, Width)\n",
"# If the input is (Height, Width, Channel), we need to permute it to (Channel, Height, Width)\n",
"if input_tensor.ndim == 3 and input_tensor.shape[-1] == 1:\n",
" input_tensor = input_tensor.permute(2, 0, 1) # (H, W, C) -> (C, H, W)\n",
"elif input_tensor.ndim == 2:\n",
" input_tensor = input_tensor.unsqueeze(0) # Add Channel dim if missing: (H, W) -> (C, H, W)\n",
"\n",
"# Add Batch Dimension: (C, H, W) -> (Batch, C, H, W)\n",
"input_tensor = input_tensor.unsqueeze(0)\n",
"\n",
"# Step 3: Resize using interpolate\n",
"# Now input_tensor is (1, 1, 32, 32), which matches the 2D spatial size expected by interpolate\n",
"more_points = F.interpolate(input_tensor, size=(36, 36), mode='bilinear', align_corners=False)\n",
"\n",
"# Define the sharpen filter\n",
"kernel_weights = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])\n",
"\n",
"# Define conv with specific weights.\n",
"# Define conv\n",
"conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, bias=False)\n",
"\n",
"# Set the weights manually\n",
"# Set the conv weights manually\n",
"with torch.no_grad():\n",
" # We reshape our 3x3 kernel to (1, 1, 3, 3)\n",
" conv.weight.data = torch.from_numpy(kernel_weights).float().view(1, 1, 3, 3)\n",
"\n",
"# Run the data through\n",
"convolution_result = conv(more_points)\n",
"\n",
"print(f\"{more_points.shape=}, {convolution_result.shape=}\")\n",
"\n",
"# Plotting\n",
"fig,(ax1,ax2) = plt.subplots(1,2,figsize=(10,5))\n",
@@ -1008,14 +1060,14 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 7,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 420
},
"id": "oiUu9nDyGnZy",
"outputId": "f22aa7f2-97c9-4b3a-cd16-a40aec3c5ee6"
"outputId": "8999b7db-e88a-4b93-ccd9-c1d5e1bb3fd0"
},
"outputs": [
{
@@ -1026,7 +1078,7 @@
]
},
"metadata": {},
"execution_count": 15
"execution_count": 7
},
{
"output_type": "display_data",
@@ -1123,14 +1175,14 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 8,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 420
},
"id": "qnfWwa9MGnZz",
"outputId": "6a9b113a-ee16-4a5d-ce93-6c65fa3b696e"
"outputId": "913a204a-b0f4-447e-adb1-b4004e133eb2"
},
"outputs": [
{
@@ -1141,7 +1193,7 @@
]
},
"metadata": {},
"execution_count": 17
"execution_count": 8
},
{
"output_type": "display_data",
@@ -1247,14 +1299,14 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 9,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 430
},
"id": "qLPaA52tGnZz",
"outputId": "0882cead-a86f-4fe8-b4e1-7b16a2cc7818"
"outputId": "65bfa86b-8896-4d38-e8de-76ad52b93a3c"
},
"outputs": [
{
@@ -1265,7 +1317,7 @@
]
},
"metadata": {},
"execution_count": 18
"execution_count": 9
},
{
"output_type": "display_data",
@@ -1329,16 +1381,23 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 17,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 430
"height": 449
},
"id": "lPNUoEYiGnZ0",
"outputId": "ccd5c1c8-3f2d-45c3-b6cc-c0ee0bd05dbe"
"outputId": "98067555-f1e0-493b-84c7-01494f3db8d9"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"more_points.shape=torch.Size([1, 1, 36, 36]), res.shape=torch.Size([1, 1, 36, 36])\n"
]
},
{
"output_type": "execute_result",
"data": {
Expand All @@ -1347,7 +1406,7 @@
]
},
"metadata": {},
"execution_count": 20
"execution_count": 17
},
{
"output_type": "display_data",
@@ -1393,6 +1452,7 @@
"#run the data through\n",
"res = conv(more_points)\n",
"\n",
"print(f\"{more_points.shape=}, {res.shape=}\")\n",
"\n",
"fig,(ax1,ax2) = plt.subplots(1,2,figsize=(10,5))\n",
"ax1.imshow(more_points.squeeze(), vmin=0, vmax=60, cmap='Spectral_r')\n",