From 268ca35e776d840e67e54bf82fbe0d48a5266bc1 Mon Sep 17 00:00:00 2001 From: Jatin Mathur Date: Sat, 26 Aug 2023 19:39:56 -0700 Subject: [PATCH] done --- best-submission/train_unet.ipynb | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/best-submission/train_unet.ipynb b/best-submission/train_unet.ipynb index e623c19..1921ddd 100644 --- a/best-submission/train_unet.ipynb +++ b/best-submission/train_unet.ipynb @@ -12,8 +12,8 @@ "import xarray as xr\n", "\n", "# custom\n", - "import common.loss_utils as loss_utils\n", - "import common.climatehack_dataset as climatehack_dataset" + "from utils.loss import MS_SSIMLoss\n", + "from utils.data import ClimatehackDataset, CustomDataset" ] }, { @@ -74,8 +74,8 @@ "outputs": [], "source": [ "BATCH_SIZE = 32\n", - "train_ds = climatehack_dataset.ClimatehackDataset(dataset, random_state=7)\n", - "valid_ds = climatehack_dataset.ClimatehackDataset(dataset, random_state=3)\n", + "train_ds = ClimatehackDataset(dataset, random_state=7)\n", + "valid_ds = ClimatehackDataset(dataset, random_state=3)\n", "\n", "train_dl = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE)\n", "valid_dl = torch.utils.data.DataLoader(valid_ds, batch_size=BATCH_SIZE)" @@ -739,7 +739,7 @@ "outputs": [], "source": [ "FORECAST = 24\n", - "criterion = loss_utils.MS_SSIMLoss(channels=FORECAST)" + "criterion = MS_SSIMLoss(channels=FORECAST)" ] }, {