|
| 1 | +import tensor2tensor.trax.inputs |
| 2 | +import tensor2tensor.trax.models |
| 3 | +import tensor2tensor.trax.optimizers |
| 4 | +import tensor2tensor.trax.rlax |
| 5 | +import tensor2tensor.trax.rlax.envs |
| 6 | + |
| 7 | +# Parameters for batch_fun: |
| 8 | +# ============================================================================== |
| 9 | +batch_fun.batch_size = 32 |
| 10 | +batch_fun.bucket_length = 32 |
| 11 | +batch_fun.buckets = None |
| 12 | +batch_fun.eval_batch_size = 32 |
| 13 | + |
| 14 | +# Parameters for inputs: |
| 15 | +# ============================================================================== |
| 16 | +inputs.data_dir = None |
| 17 | +inputs.dataset_name = 'cifar10' |
| 18 | + |
| 19 | +# Parameters for Momentum: |
| 20 | +# ============================================================================== |
| 21 | +Momentum.mass = 0.9 |
| 22 | + |
| 23 | +# Parameters for shuffle_and_batch_data: |
| 24 | +# ============================================================================== |
| 25 | +shuffle_and_batch_data.preprocess_fun = @trax.inputs.cifar10_no_augmentation_preprocess |
| 26 | + |
| 27 | +# Parameters for WideResnet: |
| 28 | +# ============================================================================== |
| 29 | +WideResnet.widen_factor = 2 |
| 30 | +WideResnet.n_blocks = 3 |
| 31 | +WideResnet.n_output_classes = 10 |
| 32 | + |
| 33 | +# Parameters for OnlineTuneEnv: |
| 34 | +# ============================================================================== |
| 35 | +OnlineTuneEnv.inputs = @trax.inputs.inputs |
| 36 | +OnlineTuneEnv.model = @trax.models.WideResnet |
| 37 | +OnlineTuneEnv.optimizer = @trax.optimizers.Momentum |
| 38 | +OnlineTuneEnv.start_lr = 0.01 |
| 39 | +OnlineTuneEnv.train_steps = 500 |
| 40 | +OnlineTuneEnv.eval_steps = 50 |
| 41 | +OnlineTuneEnv.env_steps = 128 |
0 commit comments