1- .. _elastic_train_script :
2-
31Train script
42-------------
53
@@ -9,20 +7,18 @@ working with ``torch.distributed.run`` with these differences:
971. No need to manually pass ``RANK ``, ``WORLD_SIZE ``,
108 ``MASTER_ADDR ``, and ``MASTER_PORT ``.
119
12- 2. ``rdzv_backend `` and ``rdzv_endpoint `` can be provided. For most users
13- this will be set to ``c10d `` (see `rendezvous <rendezvous.html >`_). The default
14- ``rdzv_backend `` creates a non-elastic rendezvous where ``rdzv_endpoint `` holds
15- the master address.
10+ 2. ``rdzv_backend `` and ``rdzv_endpoint `` must be provided. For most users
11+ this will be set to ``c10d `` (see `rendezvous <rendezvous.html >`_).
1612
17133. Make sure you have a ``load_checkpoint(path) `` and
18- ``save_checkpoint(path) `` logic in your script. When any number of
19- workers fail we restart all the workers with the same program
20- arguments so you will lose progress up to the most recent checkpoint
14+ ``save_checkpoint(path) `` logic in your script. When workers fail
15+ we restart all the workers with the same program arguments so you will
16+ lose progress up to the most recent checkpoint
2117 (see `elastic launch <distributed.html >`_).
2218
23194. ``use_env `` flag has been removed. If you were parsing local rank by parsing
2420 the ``--local_rank `` option, you need to get the local rank from the
25- environment variable ``LOCAL_RANK `` (e.g. ``int( os.environ["LOCAL_RANK"]) ``).
21+ environment variable ``LOCAL_RANK `` (e.g. ``os.environ["LOCAL_RANK"] ``).
2622
2723Below is an expository example of a training script that checkpoints on each
2824epoch, hence the worst-case progress lost on failure is one full epoch worth
@@ -35,7 +31,7 @@ of training.
3531 state = load_checkpoint(args.checkpoint_path)
3632 initialize(state)
3733
38- # torch.distributed.run ensures that this will work
34+ # torch.distributed.run ensure that this will work
3935 # by exporting all the env vars needed to initialize the process group
4036 torch.distributed.init_process_group(backend = args.backend)
4137
0 commit comments