forked from konpatp/diffae
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdist_utils.py
More file actions
executable file
·42 lines (30 loc) · 804 Bytes
/
dist_utils.py
File metadata and controls
executable file
·42 lines (30 loc) · 804 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from typing import List
from torch import distributed
def barrier():
    """Block until every process in the default process group reaches this call.

    Falls back to a no-op when ``torch.distributed`` is not available in this
    torch build or no process group has been initialized (single-process run).
    """
    # is_available() must be checked first: builds without distributed
    # support do not expose is_initialized() (or any other c10d API).
    if distributed.is_available() and distributed.is_initialized():
        distributed.barrier()
def broadcast(data, src):
    """Broadcast ``data`` from rank ``src`` to all processes.

    Thin wrapper over ``torch.distributed.broadcast``, which fills ``data``
    in place on the receiving ranks. When distributed is unavailable or not
    initialized there is only one process, so ``data`` already holds the
    value and the call is a no-op.

    Args:
        data: object passed straight to ``torch.distributed.broadcast``
            (a tensor when running distributed).
        src: rank that owns the value to broadcast.
    """
    # Guard with is_available(): without distributed support in the torch
    # build, is_initialized() does not exist.
    if distributed.is_available() and distributed.is_initialized():
        distributed.broadcast(data, src)
def all_gather(data: List, src):
    """Gather ``src`` from every process into the output list ``data``.

    Args:
        data: output list, written in place. When running distributed it is
            handed to ``torch.distributed.all_gather`` (which expects one
            pre-allocated tensor per rank).
        src: this process's contribution.

    In single-process mode (distributed unavailable or not initialized) the
    only contribution is our own, so ``src`` is stored at ``data[0]``. Note
    that this stores a reference to ``src`` rather than copying into the
    existing element, unlike the distributed path.
    """
    # is_available() first: is_initialized() is missing on builds without
    # distributed support.
    if distributed.is_available() and distributed.is_initialized():
        distributed.all_gather(data, src)
    else:
        data[0] = src
def get_rank():
    """Return this process's rank in the default process group.

    Returns 0 when ``torch.distributed`` is unavailable in this build or no
    process group has been initialized (single-process run).
    """
    # is_available() must precede is_initialized(): the latter is not
    # defined on torch builds without distributed support.
    if distributed.is_available() and distributed.is_initialized():
        return distributed.get_rank()
    return 0
def get_world_size():
    """Return the number of processes in the default process group.

    Returns 1 when ``torch.distributed`` is unavailable in this build or no
    process group has been initialized (single-process run).
    """
    # is_available() must precede is_initialized(): the latter is not
    # defined on torch builds without distributed support.
    if distributed.is_available() and distributed.is_initialized():
        return distributed.get_world_size()
    return 1
def chunk_size(size, rank, world_size):
    """Return how many of ``size`` items are assigned to ``rank``.

    The items are split as evenly as possible across ``world_size`` ranks:
    each rank gets ``size // world_size`` items, and ranks below
    ``size % world_size`` get one extra. For non-negative integer ``size``
    and positive ``world_size``, the counts over ranks ``0..world_size-1``
    sum to ``size``.
    """
    base, leftover = divmod(size, world_size)
    # The first `leftover` ranks absorb the remainder, one item each.
    if rank < leftover:
        return base + 1
    return base