Skip to content

Commit ea52c76

Browse files
afrozenatorcopybara-github
authored andcommitted
Gin files for env_server and ppo binaries for learning rate tuning.
PiperOrigin-RevId: 261588529
1 parent 62cf66f commit ea52c76

File tree

2 files changed

+62
-0
lines changed

2 files changed

+62
-0
lines changed
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import tensor2tensor.trax.inputs
2+
import tensor2tensor.trax.models
3+
import tensor2tensor.trax.optimizers
4+
import tensor2tensor.trax.rlax
5+
import tensor2tensor.trax.rlax.envs
6+
7+
# Parameters for batch_fun:
8+
# ==============================================================================
9+
batch_fun.batch_size = 32
10+
batch_fun.bucket_length = 32
11+
batch_fun.buckets = None
12+
batch_fun.eval_batch_size = 32
13+
14+
# Parameters for inputs:
15+
# ==============================================================================
16+
inputs.data_dir = None
17+
inputs.dataset_name = 'cifar10'
18+
19+
# Parameters for Momentum:
20+
# ==============================================================================
21+
Momentum.mass = 0.9
22+
23+
# Parameters for shuffle_and_batch_data:
24+
# ==============================================================================
25+
shuffle_and_batch_data.preprocess_fun = @trax.inputs.cifar10_no_augmentation_preprocess
26+
27+
# Parameters for WideResnet:
28+
# ==============================================================================
29+
WideResnet.widen_factor = 2
30+
WideResnet.n_blocks = 3
31+
WideResnet.n_output_classes = 10
32+
33+
# Parameters for OnlineTuneEnv:
34+
# ==============================================================================
35+
OnlineTuneEnv.inputs = @trax.inputs.inputs
36+
OnlineTuneEnv.model = @trax.models.WideResnet
37+
OnlineTuneEnv.optimizer = @trax.optimizers.Momentum
38+
OnlineTuneEnv.start_lr = 0.01
39+
OnlineTuneEnv.train_steps = 500
40+
OnlineTuneEnv.eval_steps = 50
41+
OnlineTuneEnv.env_steps = 128
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import tensor2tensor.trax.rlax
2+
3+
# Parameters for ppo.training_loop:
4+
# ==============================================================================
5+
ppo.training_loop.n_optimizer_steps = 30
6+
ppo.training_loop.boundary = 128
7+
ppo.training_loop.max_timestep = 128
8+
ppo.training_loop.max_timestep_eval = 128
9+
ppo.training_loop.random_seed = 0
10+
ppo.training_loop.gamma = 0.99
11+
ppo.training_loop.lambda_ = 0.95
12+
ppo.training_loop.epsilon = 0.1
13+
ppo.training_loop.c1 = 1.0
14+
ppo.training_loop.c2 = 0.01
15+
ppo.training_loop.eval_every_n = 10
16+
ppo.training_loop.done_frac_for_policy_save = 0
17+
ppo.training_loop.enable_early_stopping = True
18+
ppo.training_loop.n_evals = 1
19+
ppo.training_loop.len_history_for_policy = 1 # this needs to be bumped up.
20+
ppo.training_loop.eval_temperatures = (1.0,)
21+
ppo.training_loop.epochs = 1000

0 commit comments

Comments
 (0)