diff --git a/slide2vec/main.py b/slide2vec/main.py index 1a15914..70e539f 100644 --- a/slide2vec/main.py +++ b/slide2vec/main.py @@ -50,6 +50,15 @@ def log_progress(features_dir: Path, stop_event: threading.Event, log_interval: def run_tiling(root_dir, config_file, output_dir): print(f"Running tiling.py from {root_dir}...") + + # add root_dir to PYTHONPATH so hs2p module can be found + env = os.environ.copy() + root_dir_abs = os.path.abspath(root_dir) + if "PYTHONPATH" in env: + env["PYTHONPATH"] = f"{root_dir_abs}:{env['PYTHONPATH']}" + else: + env["PYTHONPATH"] = root_dir_abs + cmd = [ sys.executable, "hs2p/tiling.py", @@ -61,7 +70,7 @@ def run_tiling(root_dir, config_file, output_dir): "--skip-logging", "wandb.enable=false", # disable wandb to avoid dupliacte logging ] - proc = subprocess.run(cmd, cwd=root_dir) + proc = subprocess.run(cmd, cwd=root_dir, env=env) if proc.returncode != 0: print("Slide tiling failed. Exiting.") sys.exit(proc.returncode) diff --git a/slide2vec/models/models.py b/slide2vec/models/models.py index bb94ba5..ae4be7a 100644 --- a/slide2vec/models/models.py +++ b/slide2vec/models/models.py @@ -776,9 +776,9 @@ def forward(self, x): # x = [B, num_tiles, 3, 224, 224] B = x.size(0) x = rearrange(x, "b p c w h -> (b p) c w h") # [B*num_tiles, 3, 224, 224] - output = self.tile_encoder(x) # [B*num_tiles, features_dim] + tile_embedding = self.tile_encoder(x)["embedding"] # [B*num_tiles, features_dim] embedding = rearrange( - output, "(b p) f -> b p f", b=B + tile_embedding, "(b p) f -> b p f", b=B ) # [B, num_tiles, features_dim] output = {"embedding": embedding} return output