[shardformer] supported bloom model (#4098)

Frank Lee
2023-06-28 15:04:35 +08:00
parent 8af29ee47a
commit b1c2901530
20 changed files with 724 additions and 154 deletions


@@ -3,6 +3,7 @@ from contextlib import contextmanager
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
+from torch.distributed.distributed_c10d import _get_global_rank


 class Randomizer:
@@ -112,27 +113,90 @@ class Randomizer:
"""
idx = Randomizer._INDEX
Randomizer._INDEX += 1
return idx
@staticmethod
def increment_index():
"""
Increment the index of the randomizer by one.
"""
Randomizer._INDEX += 1
def create_randomizer_with_offset(seed: int, process_group: ProcessGroup = None):
@staticmethod
def is_randomizer_index_synchronized(process_group: ProcessGroup = None):
"""
Return whether the randomizer index is synchronized across processes.
"""
index = Randomizer.index()
if dist.is_initialized():
# convert the index to tensor
index_tensor = torch.tensor(index, dtype=torch.int32).cuda()
# all gather the index
gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
dist.all_gather(gathered_index, index_tensor, process_group)
# make sure all the gathered index are the same
for i in range(1, dist.get_world_size(process_group)):
if gathered_index[i] != gathered_index[0]:
return False
return True
@staticmethod
def synchronize_index(process_group: ProcessGroup = None):
"""
All gather the index and pick the largest value.
"""
index = Randomizer.index()
if dist.is_initialized():
# convert the index to tensor
index_tensor = torch.tensor(index, dtype=torch.int32).cuda()
# all gather the index
gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
dist.all_gather(gathered_index, index_tensor, process_group)
# pick the largest index
for i in range(1, dist.get_world_size(process_group)):
if gathered_index[i] > index_tensor:
index_tensor = gathered_index[i]
# set the index
Randomizer._INDEX = index_tensor.item()
def create_randomizer_with_offset(seed: int,
process_group: ProcessGroup = None,
offset_by_rank: bool = True,
offset_by_index: bool = True):
"""
Create a randomizer with an offset. The offset is equal to the rank of the process and the index of the randomizer.
Args:
seed (int): The base random seed to set.
enable_cpu (bool): fork the CPU RNG state as well.
process_group (ProcessGroup): the process group to get the rank from.
offset_by_rank (bool): whether to offset by the rank of the process, i.e., the rank of the process will be added to the seed. Default: True.
offset_by_index (bool): whether to offset by the index of the randomizer, i.e., the index of the randomizer will be added to the seed. Default: True.
Returns:
Randomizer: the randomizer with offset.
"""
offset = Randomizer.index()
base_seed = seed
if dist.is_initialized():
if offset_by_rank and dist.is_initialized():
rank = dist.get_rank(process_group)
offset += rank
base_seed += rank
seed += offset
return Randomizer(seed=seed)
if offset_by_index:
# check if the randomizer index is synchronized
is_synchronized = Randomizer.is_randomizer_index_synchronized(process_group)
assert is_synchronized, ("We detect that the randomizer index is not synchronized across processes."
"This is not allowed when we want to create a randomizer with offset by index."
"Please call Randomizer.synchronize_index() first.")
base_seed += Randomizer.index()
Randomizer.increment_index()
return Randomizer(seed=base_seed)
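
For reviewers, a minimal, self-contained sketch of the seed arithmetic that the patched create_randomizer_with_offset performs. The helper name effective_seed is illustrative only and not part of this commit; it simply mirrors the offset-by-rank and offset-by-index branches shown above.

def effective_seed(base_seed: int, rank: int, index: int,
                   offset_by_rank: bool = True, offset_by_index: bool = True) -> int:
    # Mirrors the patched logic: start from the base seed, then optionally
    # shift it by the process rank and by the shared randomizer index.
    seed = base_seed
    if offset_by_rank:
        seed += rank
    if offset_by_index:
        seed += index
    return seed

# With base seed 1024 and randomizer index 0, rank 0 is seeded with 1024 and
# rank 1 with 1025; after Randomizer.increment_index(), the next randomizer
# created on each rank is shifted by one more.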
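
A hedged usage sketch of the new index-synchronization helpers. It assumes torch.distributed is already initialized (for example via torchrun), a CUDA device is available, and the symbols are importable from colossalai.shardformer.layer.utils; the commit does not show the file path, so treat the import as an assumption.

from colossalai.shardformer.layer.utils import Randomizer, create_randomizer_with_offset

# offset_by_index requires every rank to hold the same randomizer index,
# otherwise the assert inside create_randomizer_with_offset fires.
if not Randomizer.is_randomizer_index_synchronized():
    # all-gather the per-rank indices and adopt the largest one everywhere
    Randomizer.synchronize_index()

# each rank now derives its own seed: base seed + rank + shared index
randomizer = create_randomizer_with_offset(seed=1024)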