2 changes: 2 additions & 0 deletions .gitignore
@@ -184,4 +184,6 @@ __pycache__/

src/astropt/_version.py

# checkpoints
*.pt
wandb/
2 changes: 1 addition & 1 deletion README.md
@@ -8,7 +8,7 @@
[![PyPI Downloads](https://static.pepy.tech/badge/astropt)](https://pepy.tech/projects/astropt)
[![docs](https://app.readthedocs.org/projects/astropt/badge/)](https://astropt.readthedocs.io/)
[![License: AGPL-v3](https://img.shields.io/badge/License-AGPLv3-green.svg)](https://www.gnu.org/licenses/agpl-3.0.html)
[![All Contributors](https://img.shields.io/badge/all_contributors-7-orange.svg?style=flat-square)](#contributors-)
[![All Contributors](https://img.shields.io/badge/all_contributors-8-orange.svg?style=flat-square)](#contributors-)

[![ICML](https://img.shields.io/badge/AI4Science@ICML-2024---?logo=https%3A%2F%2Fneurips.cc%2Fstatic%2Fcore%2Fimg%2FNeurIPS-logo.svg&labelColor=68448B&color=b3b3b3)](https://openreview.net/forum?id=aOLuuLxqav)
[![arXiv](https://img.shields.io/badge/arXiv-2405.14930---?logo=arXiv&labelColor=b31b1b&color=grey)](https://arxiv.org/abs/2405.14930)
32 changes: 32 additions & 0 deletions affine_run.job
@@ -0,0 +1,32 @@
#!/bin/bash
# ----------------Parameters---------------------- #
#$ -S /bin/bash
#$ -pe mthread 32
#$ -q lTgpu.q
#$ -cwd
#$ -j y
#$ -N affine_job
#$ -o affine_job.log
#$ -m bea
#$ -M sogolsanjaripour@gmail.com
#$ -l gpu,ngpus=1
#
# ----------------Modules------------------------- #
#

module load tools/conda
start-conda
conda activate astropt

# ----------------Your Commands------------------- #
#
echo + `date` job $JOB_NAME started in $QUEUE with jobID=$JOB_ID on $HOSTNAME
echo + NSLOTS = $NSLOTS
#

cd astroPT
uv sync
uv run scripts/train.py --batch_size=32 --compile=False

#
echo = `date` job $JOB_NAME done
32 changes: 32 additions & 0 deletions aim_run.job
@@ -0,0 +1,32 @@
#!/bin/bash
# ----------------Parameters---------------------- #
#$ -S /bin/bash
#$ -pe mthread 32
#$ -q lTgpu.q
#$ -cwd
#$ -j y
#$ -N aim_job
#$ -o aim_job.log
#$ -m bea
#$ -M sogolsanjaripour@gmail.com
#$ -l gpu,ngpus=1
#
# ----------------Modules------------------------- #
#

module load tools/conda
start-conda
conda activate astropt

# ----------------Your Commands------------------- #
#
echo + `date` job $JOB_NAME started in $QUEUE with jobID=$JOB_ID on $HOSTNAME
echo + NSLOTS = $NSLOTS
#

cd astroPT
uv sync
uv run scripts/train.py --batch_size=32 --compile=False

#
echo = `date` job $JOB_NAME done
33 changes: 33 additions & 0 deletions jetformer_run.job
@@ -0,0 +1,33 @@
#!/bin/bash
# ----------------Parameters---------------------- #
#$ -S /bin/bash
#$ -pe mthread 32
#$ -q lTgpu.q
#$ -cwd
#$ -j y
#$ -N jetformer_job
#$ -o jetformer_job.log
#$ -m bea
#$ -M sogolsanjaripour@gmail.com
#$ -l gpu,ngpus=1
#
# ----------------Modules------------------------- #
#

module load tools/conda
start-conda
conda activate astropt

# ----------------Your Commands------------------- #
#
echo + `date` job $JOB_NAME started in $QUEUE with jobID=$JOB_ID on $HOSTNAME
echo + NSLOTS = $NSLOTS
#

cd astroPT
uv sync
#uv run scripts/train_jetformer.py --batch_size=32 --compile=False
#uv run torchrun --standalone --nproc_per_node=2 scripts/train_jetformer.py --batch_size=32 --compile=False
PYTHONUNBUFFERED=1 uv run scripts/train_jetformer.py --init_from=resume --batch_size=32 --compile=False
#
echo = `date` job $JOB_NAME done
63 changes: 63 additions & 0 deletions scripts/jetformer/download_galaxy_256x256.py
@@ -0,0 +1,63 @@
# make_galaxy_pt_shards.py
# pip install datasets pillow torch

import os
import torch
from datasets import load_dataset
from PIL import Image

OUTDIR = "galaxy_pt_256x256_test" # folder for shards
TOTAL = 10_000 # how many samples
SHARD = 1024 # images per shard (last shard may be smaller)
SEED = 42
BUFFER = 50_000 # shuffle buffer

IMG_SIZE = 256 # match your model

def pil_to_chw_uint8(img: Image.Image):
    img = img.convert("RGB")
    t = torch.tensor(bytearray(img.tobytes()), dtype=torch.uint8)
    t = t.view(IMG_SIZE, IMG_SIZE, 3).permute(2, 0, 1).contiguous()  # HWC -> CHW
    return t
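
# NOTE: pil_to_chw_uint8 assumes the "image_crop" field is already IMG_SIZE x IMG_SIZE;
# crops of any other size make the .view() call above raise, and main() below skips
# them via its try/except.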

def main():
    os.makedirs(OUTDIR, exist_ok=True)
    ds = load_dataset("Smith42/galaxies", split="test", streaming=True)
    ds = ds.shuffle(buffer_size=BUFFER, seed=SEED)

    buf = []
    saved = 0
    shard_idx = 0

    for ex in ds:
        # strictly use image_crop
        if "image_crop" not in ex or ex["image_crop"] is None:
            continue
        try:
            t = pil_to_chw_uint8(ex["image_crop"])
        except Exception:
            continue

        buf.append(t)
        saved += 1

        if len(buf) == SHARD:
            x = torch.stack(buf, dim=0)  # [N,3,256,256] uint8
            torch.save({"images": x}, os.path.join(OUTDIR, f"shard_{shard_idx:04d}_test.pt"))
            print(f"wrote shard {shard_idx:04d} with {x.size(0)} samples")
            shard_idx += 1
            buf.clear()

        if saved >= TOTAL:
            break

    # flush remainder
    if buf:
        x = torch.stack(buf, dim=0)
        torch.save({"images": x}, os.path.join(OUTDIR, f"shard_{shard_idx:04d}_test.pt"))
        print(f"wrote shard {shard_idx:04d} with {x.size(0)} samples")

    print(f"Done. Total saved: {saved} samples into {OUTDIR}/shard_*_test.pt")

if __name__ == "__main__":
    main()
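
For reference, a minimal sketch of how the saved shards could be read back for inspection or training. The shard directory matches OUTDIR above; the filename and the uint8-to-float scaling are assumptions for illustration, not something this PR prescribes:

# load_galaxy_pt_shards.py (hypothetical helper; filename and normalisation are assumptions)
import glob
import os

import torch

SHARD_DIR = "galaxy_pt_256x256_test"  # matches OUTDIR in the download script

shard_paths = sorted(glob.glob(os.path.join(SHARD_DIR, "shard_*_test.pt")))
batch = torch.load(shard_paths[0])["images"]  # uint8 tensor of shape [N, 3, 256, 256]
print(batch.shape, batch.dtype)
images = batch.float() / 255.0  # scale to [0, 1]; adjust if the model expects a different normalisation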