Open
89 commits
ddea611
Adding AR1 and VAR1 Data
HariDaCoder Sep 24, 2025
39b86fe
change init_state
HariDaCoder Oct 1, 2025
c581610
Add Ridge and other relevant functions
HoangTimothy Oct 4, 2025
3c427cf
Add ridge and other relevant functions
HoangTimothy Oct 4, 2025
767436f
Update fetures
HariDaCoder Oct 6, 2025
0b54ebe
Add ar2 and vr2, draw figure 3
HoangTimothy Oct 9, 2025
c247857
adding littles
HariDaCoder Oct 13, 2025
4216fb8
Add ridge and other relevant functions
HoangTimothy Oct 4, 2025
ab49e97
NST
HoangTimothy Oct 13, 2025
4288f99
NST1
HoangTimothy Oct 13, 2025
ded21a2
figure_3_and_4
HoangTimothy Oct 13, 2025
86f52fc
chore(eval): add numpy import for function visualizations
HoangTimothy Oct 13, 2025
76d8a32
update 14/10
HariDaCoder Oct 14, 2025
2d654be
exponential error
HoangTimothy Oct 22, 2025
3317f13
update main
HoangTimothy Oct 22, 2025
e4a8357
train_steps = 50k1
HoangTimothy Oct 22, 2025
953aaf0
Update bla bla bla
HariDaCoder Oct 22, 2025
e03541c
Update tasks.py
HariDaCoder Oct 28, 2025
70e269c
current status
HoangTimothy Oct 31, 2025
319e9e9
ad vr1
HoangTimothy Oct 31, 2025
6334a5d
Plot
HoangTimothy Nov 1, 2025
ba71f75
models update metrics
HoangTimothy Nov 3, 2025
433b64f
return original code
HoangTimothy Nov 3, 2025
0d6accd
run_all
HoangTimothy Nov 3, 2025
d0733e3
Change toy
HoangTimothy Nov 4, 2025
0a619c1
run
HoangTimothy Nov 4, 2025
f286bbd
add rayleigh distribution
HoangTimothy Nov 4, 2025
0b9cde8
ridge 0.5
HoangTimothy Nov 5, 2025
e1a6bde
ridge 0.5 for sparse
HoangTimothy Nov 5, 2025
c4f0e1b
sparse
HoangTimothy Nov 6, 2025
5ce5e8a
sampler sparse
HoangTimothy Nov 6, 2025
6e953e3
eror
HoangTimothy Nov 6, 2025
2668a30
update total
HoangTimothy Nov 7, 2025
4ea0c67
Add run_all.py
HoangTimothy Nov 9, 2025
7ee9d40
fix toy
HoangTimothy Nov 10, 2025
a79c96e
add information toy
HoangTimothy Nov 11, 2025
7eccd55
fix to run
HoangTimothy Nov 11, 2025
48f5119
fix code
HoangTimothy Nov 11, 2025
1644465
figure 3 and 4
HoangTimothy Nov 11, 2025
9cfcdf3
update all
HoangTimothy Nov 11, 2025
f991000
update train.py
HoangTimothy Nov 11, 2025
77d5df2
update task.py
HoangTimothy Nov 11, 2025
3492610
update allll
HoangTimothy Nov 11, 2025
914d5fd
update alllll
HoangTimothy Nov 11, 2025
6b60d3b
ExponentialWeighted
HoangTimothy Nov 12, 2025
3ff16c4
expw
HoangTimothy Nov 14, 2025
b594228
laplace w
HoangTimothy Nov 15, 2025
49eea98
add laplace
HoangTimothy Nov 15, 2025
960c730
500k ready
HoangTimothy Nov 17, 2025
ff29deb
500k readyy
HoangTimothy Nov 17, 2025
fc1a423
caa
HoangTimothy Nov 17, 2025
fbb966c
runall
HoangTimothy Nov 18, 2025
8346ed6
add exponential, laplace, beta, gamma place
HoangTimothy Nov 21, 2025
cf10486
add exponential, laplace, beta, gamma place..
HoangTimothy Nov 21, 2025
c798ab5
wlaplacexexponentialzpoisson
HoangTimothy Nov 21, 2025
c4e9293
fix code w
HoangTimothy Nov 21, 2025
dac99b0
cauchy fix
HoangTimothy Nov 23, 2025
20c27b6
fix cauchy
HoangTimothy Nov 23, 2025
eac7d5e
vectorize
HoangTimothy Nov 23, 2025
484c4ff
changedau
HoangTimothy Nov 26, 2025
59b8f36
changed
HoangTimothy Nov 26, 2025
89fba38
fuck
HoangTimothy Nov 26, 2025
44d059a
glob
HoangTimothy Nov 26, 2025
2197d67
ab
HoangTimothy Nov 26, 2025
098b2fd
ab
HoangTimothy Nov 26, 2025
f1358f5
ab
HoangTimothy Nov 26, 2025
708ebb6
ab
HoangTimothy Nov 26, 2025
77f3bc1
ab
HoangTimothy Nov 26, 2025
c879888
add sampler full
HoangTimothy Nov 29, 2025
e8ad91b
add sampler full
HoangTimothy Nov 29, 2025
e978a52
add sampler full
HoangTimothy Nov 29, 2025
8e9011f
add sampler full
HoangTimothy Nov 29, 2025
b508b0c
add sampler full
HoangTimothy Nov 29, 2025
2f874fb
update noiselr
HoangTimothy Dec 3, 2025
aae530f
update wxe
HoangTimothy Dec 3, 2025
211cfa3
update
HoangTimothy Dec 3, 2025
511f4cd
update wx
HoangTimothy Dec 4, 2025
dd18079
Merge branch 'main' into hoang-update
HariDaCoder Dec 6, 2025
9dc1ee0
add all x, y to GPU
HoangTimothy Dec 11, 2025
c6b519e
fix error
HoangTimothy Dec 11, 2025
a176237
fix error
HoangTimothy Dec 11, 2025
6311528
fix conflict device
HoangTimothy Dec 15, 2025
0d66e48
FIX conflict device
HoangTimothy Dec 15, 2025
5bda600
FIX conflict device
HoangTimothy Dec 15, 2025
c0edda9
FIX: Remove stdict wrapper from task_kwargs and data_kwargs to preven…
HoangTimothy Dec 15, 2025
7d271f9
FIX: Merge stdict with nullable for task_kwargs and data_kwargs
HoangTimothy Dec 15, 2025
a9f6d50
DEBUG: Add print statements to track where training gets stuck
HoangTimothy Dec 15, 2025
51d3c25
Add cpu_only option to schema
HoangTimothy Dec 15, 2025
ef3488b
fix error
HoangTimothy Dec 15, 2025
2 changes: 1 addition & 1 deletion environment.yml
@@ -4,7 +4,7 @@ channels:
   - defaults
 dependencies:
   - pip=21.2.4
-  - python=3.8.12
+  - python=3.9
   - pytorch=1.11.0
   - pip:
     - jupyter==1.0.0
15 changes: 15 additions & 0 deletions src/conf/ar2.yaml
@@ -0,0 +1,15 @@
+defaults:
+    - base
+
+training:
+    data: ar2
+    data_kwargs:
+        rho1: 0.5
+        rho2: 0.3
+        noise_std: 0.1
+
+curriculum:
+    points:
+        start: 5
+        end: 40
+        step: 5
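The `rho1`, `rho2` and `noise_std` keys suggest a second-order autoregressive input process. The sampler itself is not part of this diff, so the following is only a minimal sketch under that reading; `sample_ar2` and its zero initial state are hypothetical, and the PR's own `init_state` handling may differ.

```python
import random


def sample_ar2(n_points, rho1=0.5, rho2=0.3, noise_std=0.1, seed=None):
    """Generate one AR(2) series: x_t = rho1*x_{t-1} + rho2*x_{t-2} + eps_t."""
    rng = random.Random(seed)
    # Assumption: the series starts from a zero state (not confirmed by the diff).
    history = [0.0, 0.0]
    for _ in range(n_points):
        eps = rng.gauss(0.0, noise_std)  # Gaussian innovation
        history.append(rho1 * history[-1] + rho2 * history[-2] + eps)
    return history[2:]  # drop the two initial-state entries


series = sample_ar2(40, seed=0)
```

With the values in this config the process is stationary, since `rho1 + rho2 < 1`, `rho2 - rho1 < 1` and `|rho2| < 1`.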
2 changes: 1 addition & 1 deletion src/conf/base.yaml
@@ -9,7 +9,7 @@ model:
 training:
     data: gaussian
     task_kwargs: {}
-    batch_size: 64
+    batch_size: 256
     learning_rate: 0.0001
     save_every_steps: 1000
     keep_every_steps: 100000
37 changes: 37 additions & 0 deletions src/conf/case1_w_sparse_uniform_x.yaml
@@ -0,0 +1,37 @@
+inherit:
+    - base.yaml
+
+model:
+    n_dims: 20
+    n_positions: 101
+
+training:
+    task: sparse_regression_killer
+    task_kwargs:
+        k_sparse: 2
+        scale: 1.0
+    data: uniform
+    data_kwargs: {}
+    curriculum:
+        dims:
+            start: 5
+            end: 20
+            inc: 1
+            interval: 2000
+        points:
+            start: 11
+            end: 41
+            inc: 2
+            interval: 2000
+    batch_size: 64
+    learning_rate: 0.0001
+    train_steps: 500001
+
+out_dir: ../models/sparse_regression_killer
+
+wandb:
+    project: "in-context-training"
+    entity: "hai-trinh220970-ho-chi-minh-city-university-of-technology"
+    name: "case1_sparse_regression"
+    notes: "Case 1: Sparse Regression - only k=2 dims non-zero - Ridge Trap"
+    log_every_steps: 100
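The `k_sparse: 2` setting and the run notes describe a weight vector with only two non-zero coordinates out of `n_dims`. The task class is not shown in this diff, so the sketch below is an assumption about what such a sampler could look like: `sample_sparse_weights` is a hypothetical name, the active coordinates are assumed to be standard Gaussian, and `scale` is left out on the assumption that it multiplies the labels rather than the weights.

```python
import random


def sample_sparse_weights(n_dims=20, k_sparse=2, seed=None):
    """Return a weight vector with exactly k_sparse non-zero entries."""
    rng = random.Random(seed)
    w = [0.0] * n_dims
    # Pick k_sparse distinct coordinates and fill only those.
    for idx in rng.sample(range(n_dims), k_sparse):
        w[idx] = rng.gauss(0.0, 1.0)  # assumed Gaussian active weights
    return w


w_sparse = sample_sparse_weights(seed=0)
n_active = sum(1 for wi in w_sparse if wi != 0.0)
```

A dense shrinkage estimator such as ridge spreads mass over all 20 coordinates, which is presumably what the "Ridge Trap" note refers to.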
38 changes: 38 additions & 0 deletions src/conf/case2.yaml
@@ -0,0 +1,38 @@
+inherit:
+    - base.yaml
+
+model:
+    n_dims: 20
+    n_positions: 101
+
+training:
+    task: heavy_tail_noise_killer
+    task_kwargs:
+        noise_type: "t-student"
+        df: 3.0
+        noise_scale: 0.5
+    data: gaussian
+    data_kwargs: {}
+    curriculum:
+        dims:
+            start: 5
+            end: 20
+            inc: 1
+            interval: 2000
+        points:
+            start: 11
+            end: 41
+            inc: 2
+            interval: 2000
+    batch_size: 64
+    learning_rate: 0.0001
+    train_steps: 500001
+
+out_dir: ../models/heavy_tail_noise_killer
+
+wandb:
+    project: "in-context-training"
+    entity: "hai-trinh220970-ho-chi-minh-city-university-of-technology"
+    name: "case2_heavy_tail_t_student"
+    notes: "Case 2: Heavy-tail noise (t-student df=3, scale=0.5) - OLS Enemy"
+    log_every_steps: 100
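To make the noise model in this config concrete, here is a minimal standard-library sketch of Student-t noise with `df` degrees of freedom multiplied by `noise_scale`. It is an illustration only: `student_t_noise` is a hypothetical helper, and the PR's task code (not shown here) may draw the noise differently, for example through PyTorch.

```python
import math
import random


def student_t_noise(n, df=3.0, noise_scale=0.5, seed=None):
    """Draw n samples of noise_scale * T, where T is Student-t with df degrees of freedom."""
    rng = random.Random(seed)
    samples = []
    for _ in range(n):
        z = rng.gauss(0.0, 1.0)
        # Chi-square with df degrees of freedom is Gamma(shape=df/2, scale=2).
        v = rng.gammavariate(df / 2.0, 2.0)
        samples.append(noise_scale * z / math.sqrt(v / df))
    return samples


noise = student_t_noise(1000, seed=0)
```

For `df = 3` the variance of the unscaled noise is `df / (df - 2) = 3`, so the scaled noise has variance `0.75`, but its fourth moment does not exist. Occasional very large residuals are therefore expected, which is the situation in which squared-error fits such as OLS degrade.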
36 changes: 36 additions & 0 deletions src/conf/case4.yaml
@@ -0,0 +1,36 @@
+inherit:
+    - base.yaml
+
+model:
+    n_dims: 20
+    n_positions: 101
+
+training:
+    task: mixture_tasks_killer
+    task_kwargs:
+        scale: 1.0
+    data: gaussian
+    data_kwargs: {}
+    curriculum:
+        dims:
+            start: 5
+            end: 20
+            inc: 1
+            interval: 2000
+        points:
+            start: 11
+            end: 41
+            inc: 2
+            interval: 2000
+    batch_size: 64
+    learning_rate: 0.0001
+    train_steps: 500001
+
+out_dir: ../models/mixture_tasks_killer
+
+wandb:
+    project: "in-context-training"
+    entity: "hai-trinh220970-ho-chi-minh-city-university-of-technology"
+    name: "case4_mixture_tasks"
+    notes: "Case 4: Mixture of Tasks - 50% y=w^T x, 50% y=-w^T x - Averaging Death"
+    log_every_steps: 100
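The run notes describe labels that follow `y = w^T x` half of the time and `y = -w^T x` the other half. The diff does not say whether the sign is drawn once per prompt or once per example; the sketch below assumes once per prompt, and `sample_mixture_prompt` is a hypothetical name rather than the PR's implementation.

```python
import random


def dot(a, b):
    """Inner product of two equal-length lists."""
    return sum(ai * bi for ai, bi in zip(a, b))


def sample_mixture_prompt(xs, w, scale=1.0, seed=None):
    """Label every x in one prompt with sign * w^T x, sign drawn once with p = 0.5."""
    rng = random.Random(seed)
    # Assumption: one sign per prompt, shared by all in-context examples.
    sign = 1.0 if rng.random() < 0.5 else -1.0
    ys = [scale * sign * dot(w, x) for x in xs]
    return sign, ys


w_demo = [1.0, -2.0]
xs_demo = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
sign, ys = sample_mixture_prompt(xs_demo, w_demo, seed=0)
```

Because the two branches are equally likely, the label averaged over both tasks is `0` for every `x`, so a predictor that averages instead of identifying the active task from context collapses to zero.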
38 changes: 38 additions & 0 deletions src/conf/case5.yaml
@@ -0,0 +1,38 @@
+inherit:
+    - base.yaml
+
+model:
+    n_dims: 20
+    n_positions: 101
+
+training:
+    task: transfer_tradeoff_task
+    task_kwargs:
+        prior_type: "mixture_gaussian"
+        mixture_std: 2.0
+        scale: 1.0
+    data: gaussian
+    data_kwargs: {}
+    curriculum:
+        dims:
+            start: 20
+            end: 20
+            inc: 1
+            interval: 2000
+        points:
+            start: 5
+            end: 10
+            inc: 1
+            interval: 2000
+    batch_size: 64
+    learning_rate: 0.0001
+    train_steps: 500001
+
+out_dir: ../models/transfer_tradeoff_task
+
+wandb:
+    project: "in-context-training"
+    entity: "hai-trinh220970-ho-chi-minh-city-university-of-technology"
+    name: "case5_transfer_tradeoff"
+    notes: "Case 5: Transfer Tradeoff - p×N experiment (Wakayama) - Mixture Gaussian prior"
+    log_every_steps: 100
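The `curriculum` block gives each quantity a `start`, `end`, `inc` and `interval`. Assuming the usual reading of those fields (add `inc` every `interval` training steps and cap at `end`), the effective schedule can be sketched as below; `curriculum_value` is a hypothetical helper, not the repository's curriculum class.

```python
def curriculum_value(step, start, end, inc, interval):
    """Value of a curriculum variable at a given training step.

    Assumes the variable grows by `inc` once every `interval` steps
    and never exceeds `end`.
    """
    return min(start + inc * (step // interval), end)


# Values from this config: points go 5 -> 10, dims stay fixed at 20.
points_at_start = curriculum_value(0, start=5, end=10, inc=1, interval=2000)
points_at_10k = curriculum_value(10_000, start=5, end=10, inc=1, interval=2000)
dims_at_10k = curriculum_value(10_000, start=20, end=20, inc=1, interval=2000)
```

Under this reading the number of in-context points reaches its cap of 10 after 10,000 steps and stays there for the rest of the 500,001 training steps, while the dimension is constant at 20 throughout.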
41 changes: 41 additions & 0 deletions src/conf/case_3.yaml
@@ -0,0 +1,41 @@
+inherit:
+    - base.yaml
+
+model:
+    n_dims: 20
+    n_positions: 101
+
+training:
+    task: bounded_support_killer
+    task_kwargs:
+        rate: 1.0
+        scale: 1.0
+    # Use positive-only input distribution
+    data: uniform
+    data_kwargs: {}
+    # data: exponential
+    # data_kwargs:
+    #     rate: 1.0
+    curriculum:
+        dims:
+            start: 5
+            end: 20
+            inc: 1
+            interval: 2000
+        points:
+            start: 11
+            end: 41
+            inc: 2
+            interval: 2000
+    batch_size: 64
+    learning_rate: 0.0001
+    train_steps: 500001
+
+out_dir: ../models/bounded_support_killer
+
+wandb:
+    project: "in-context-training"
+    entity: "hai-trinh220970-ho-chi-minh-city-university-of-technology"
+    name: "case3_bounded_support"
+    notes: "Case 3: Bounded Support - w~Exp(1), x~Exp(1) - Sign Constraint"
+    log_every_steps: 100
43 changes: 43 additions & 0 deletions src/conf/exponential_weighted_regression.yaml
@@ -0,0 +1,43 @@
+inherit:
+    - base.yaml
+
+model:
+    family: gpt2
+    n_dims: 20
+    n_embd: 128
+    n_head: 8
+    n_layer: 4
+    n_positions: 101
+
+training:
+    task: exponential_weighted_regression
+    task_kwargs:
+        rate: 1.0  # exponential distribution rate parameter
+        scale: 1.0
+    data: gaussian
+    data_kwargs: {}
+    curriculum:
+        dims:
+            start: 5
+            end: 20
+            inc: 1
+            interval: 2000
+        points:
+            start: 11
+            end: 41
+            inc: 2
+            interval: 2000
+    batch_size: 64
+    learning_rate: 0.0001
+    train_steps: 500001
+    save_every_steps: 100
+    keep_every_steps: 10000
+
+out_dir: /content/models/exponential_weighted_regression
+
+wandb:
+    project: "in-context-training"
+    entity: "hai-trinh220970-ho-chi-minh-city-university-of-technology"
+    name: "exponential_weights_experiment"
+    notes: "Training with exponential-distributed weights (non-uniform on hypersphere)"
+    log_every_steps: 100
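As an illustration of what "exponential-distributed weights" with a `rate` parameter could mean, the sketch below draws each coordinate independently from `Exp(rate)`. The task implementation is not part of this diff, so `sample_exponential_weights` is hypothetical, and `scale` is omitted on the assumption that it multiplies the labels rather than the weights.

```python
import math
import random


def sample_exponential_weights(n_dims=20, rate=1.0, seed=None):
    """Draw n_dims i.i.d. weights from an exponential distribution."""
    rng = random.Random(seed)
    # random.expovariate takes the rate (lambda); each draw has mean 1 / rate.
    return [rng.expovariate(rate) for _ in range(n_dims)]


w_exp = sample_exponential_weights(seed=0)
norm = math.sqrt(sum(wi * wi for wi in w_exp))
direction = [wi / norm for wi in w_exp]  # unit vector in the positive orthant
```

Every coordinate is non-negative, so the normalized direction always lies in the positive orthant. That is one way to read the "non-uniform on hypersphere" remark in the run notes.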
43 changes: 43 additions & 0 deletions src/conf/laplace_weighted_regression.yaml
@@ -0,0 +1,43 @@
+inherit:
+    - base.yaml
+
+model:
+    family: gpt2
+    n_dims: 20
+    n_embd: 128
+    n_head: 8
+    n_layer: 4
+    n_positions: 101
+
+training:
+    task: laplace_weighted_regression
+    task_kwargs:
+        weight_scale: 1.0  # laplace distribution weight scale parameter
+        scale: 1.0
+    data: gaussian
+    data_kwargs: {}
+    curriculum:
+        dims:
+            start: 5
+            end: 20
+            inc: 1
+            interval: 2000
+        points:
+            start: 11
+            end: 41
+            inc: 2
+            interval: 2000
+    batch_size: 64
+    learning_rate: 0.0001
+    train_steps: 500001
+    save_every_steps: 100
+    keep_every_steps: 10000
+
+out_dir: /content/models/laplace_weighted_regression
+
+wandb:
+    project: "in-context-training"
+    entity: "hai-trinh220970-ho-chi-minh-city-university-of-technology"
+    name: "laplace_weights_experiment"
+    notes: "Training with laplace-distributed weights (non-uniform on hypersphere)"
+    log_every_steps: 100
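For the Laplace case, a minimal standard-library sketch is given below. It uses the fact that the difference of two independent `Exp(1)` draws is a standard Laplace variable. `sample_laplace_weights` is a hypothetical helper, `weight_scale` is assumed to be the Laplace scale parameter `b`, and the PR's own sampler (not shown in this diff) may differ.

```python
import random


def sample_laplace_weights(n_dims=20, weight_scale=1.0, seed=None):
    """Draw n_dims i.i.d. weights from Laplace(0, weight_scale)."""
    rng = random.Random(seed)
    weights = []
    for _ in range(n_dims):
        # Difference of two independent Exp(1) draws is Laplace(0, 1).
        standard = rng.expovariate(1.0) - rng.expovariate(1.0)
        weights.append(weight_scale * standard)
    return weights


w_lap = sample_laplace_weights(n_dims=1000, seed=0)
mean_abs = sum(abs(wi) for wi in w_lap) / len(w_lap)
```

The mean absolute value of a `Laplace(0, b)` variable is `b`, so `mean_abs` should be close to `weight_scale` for a large sample.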
12 changes: 10 additions & 2 deletions src/conf/linear_regression.yaml
@@ -10,7 +10,15 @@ training:
             inc: 2
             interval: 2000
 
-out_dir: ../models/linear_regression
+# out_dir: ../models/linear_regression
+out_dir: D:\Henry-Projects\ChestXray\data\in-context-learning\models\linear_regression
+
 
 wandb:
-    name: "linear_regression_standard"
+    project: "in-context-training"
+    entity: "hai-trinh220970-ho-chi-minh-city-university-of-technology"
+    name: "noisy_linear_regression"
+    notes: "Training with laplace-distributed weights (non-uniform on hypersphere)"
+    log_every_steps: 100
+
+
31 changes: 31 additions & 0 deletions src/conf/lr_wx.yaml
@@ -0,0 +1,31 @@
+model:
+    family: gpt2
+    n_dims: 20
+    n_embd: 256
+    n_head: 12
+    n_layer: 8
+    n_positions: 101
+
+training:
+    batch_size: 64
+    curriculum:
+        dims:
+            start: 5
+            end: 20
+            inc: 1
+            interval: 2000
+        points:
+            start: 11
+            end: 41
+            inc: 2
+            interval: 2000
+    learning_rate: 0.0001
+    train_steps: 500001
+    data: tstudent  # e.g. gaussian, uniform, laplace, tstudent, cauchy, poisson, rayleigh
+    task: linear_regression
+    task_kwargs:
+        w_distribution: ${w_distribution}  # e.g. gaussian, uniform, laplace, tstudent, cauchy, poisson, rayleigh
+
+wandb:
+    project: in-context-training
+    name: linear_regression_custom
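The `${w_distribution}` placeholder has to be filled in before the config is usable, and this diff does not show which mechanism does that. One possible approach, sketched here purely as an assumption, is to substitute the value into the YAML text before loading it. `render_config` and `CONFIG_TEMPLATE` are hypothetical and only mirror the relevant part of the file.

```python
from string import Template

# Hypothetical excerpt that mirrors the placeholder used in lr_wx.yaml.
CONFIG_TEMPLATE = """\
training:
    task: linear_regression
    task_kwargs:
        w_distribution: ${w_distribution}
"""


def render_config(w_distribution):
    """Fill the ${w_distribution} placeholder in the YAML text."""
    # substitute() raises KeyError if a placeholder is left without a value.
    return Template(CONFIG_TEMPLATE).substitute(w_distribution=w_distribution)


rendered = render_config("laplace")
```

Rendering one file per distribution keeps a single template for all weight distributions listed in the comment (gaussian, uniform, laplace, tstudent, cauchy, poisson, rayleigh).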