mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
[shardformer] supported bloom model (#4098)
This commit is contained in:
@@ -3,6 +3,7 @@ from contextlib import contextmanager
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.distributed.distributed_c10d import _get_global_rank
|
||||
|
||||
|
||||
class Randomizer:
|
||||
@@ -112,27 +113,90 @@ class Randomizer:
|
||||
|
||||
"""
|
||||
idx = Randomizer._INDEX
|
||||
Randomizer._INDEX += 1
|
||||
return idx
|
||||
|
||||
@staticmethod
|
||||
def increment_index():
|
||||
"""
|
||||
Increment the index of the randomizer by one.
|
||||
"""
|
||||
Randomizer._INDEX += 1
|
||||
|
||||
def create_randomizer_with_offset(seed: int, process_group: ProcessGroup = None):
|
||||
@staticmethod
|
||||
def is_randomizer_index_synchronized(process_group: ProcessGroup = None):
|
||||
"""
|
||||
Return whether the randomizer index is synchronized across processes.
|
||||
"""
|
||||
index = Randomizer.index()
|
||||
if dist.is_initialized():
|
||||
# convert the index to tensor
|
||||
index_tensor = torch.tensor(index, dtype=torch.int32).cuda()
|
||||
|
||||
# all gather the index
|
||||
gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
|
||||
dist.all_gather(gathered_index, index_tensor, process_group)
|
||||
|
||||
# make sure all the gathered index are the same
|
||||
for i in range(1, dist.get_world_size(process_group)):
|
||||
if gathered_index[i] != gathered_index[0]:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def synchronize_index(process_group: ProcessGroup = None):
|
||||
"""
|
||||
All gather the index and pick the largest value.
|
||||
"""
|
||||
index = Randomizer.index()
|
||||
|
||||
if dist.is_initialized():
|
||||
# convert the index to tensor
|
||||
index_tensor = torch.tensor(index, dtype=torch.int32).cuda()
|
||||
|
||||
# all gather the index
|
||||
gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
|
||||
dist.all_gather(gathered_index, index_tensor, process_group)
|
||||
|
||||
# pick the largest index
|
||||
for i in range(1, dist.get_world_size(process_group)):
|
||||
if gathered_index[i] > index_tensor:
|
||||
index_tensor = gathered_index[i]
|
||||
|
||||
# set the index
|
||||
Randomizer._INDEX = index_tensor.item()
|
||||
|
||||
|
||||
def create_randomizer_with_offset(seed: int,
|
||||
process_group: ProcessGroup = None,
|
||||
offset_by_rank: bool = True,
|
||||
offset_by_index: bool = True):
|
||||
"""
|
||||
Create a randomizer with an offset. The offset is equal to the rank of the process and the index of the randomizer.
|
||||
|
||||
Args:
|
||||
seed (int): The base random seed to set.
|
||||
enable_cpu (bool): fork the CPU RNG state as well.
|
||||
process_group (ProcessGroup): the process group to get the rank from.
|
||||
offset_by_rank (bool): whether to offset by the rank of the process, i.e., the rank of the process will be added to the seed. Default: True.
|
||||
offset_by_index (bool): whether to offset by the index of the randomizer, i.e., the index of the randomizer will be added to the seed. Default: True.
|
||||
|
||||
Returns:
|
||||
Randomizer: the randomizer with offset.
|
||||
"""
|
||||
offset = Randomizer.index()
|
||||
base_seed = seed
|
||||
|
||||
if dist.is_initialized():
|
||||
if offset_by_rank and dist.is_initialized():
|
||||
rank = dist.get_rank(process_group)
|
||||
offset += rank
|
||||
base_seed += rank
|
||||
|
||||
seed += offset
|
||||
return Randomizer(seed=seed)
|
||||
if offset_by_index:
|
||||
# check if the randomizer index is synchronized
|
||||
is_synchronized = Randomizer.is_randomizer_index_synchronized(process_group)
|
||||
assert is_synchronized, ("We detect that the randomizer index is not synchronized across processes."
|
||||
"This is not allowed when we want to create a randomizer with offset by index."
|
||||
"Please call Randomizer.synchronize_index() first.")
|
||||
|
||||
base_seed += Randomizer.index()
|
||||
Randomizer.increment_index()
|
||||
|
||||
return Randomizer(seed=base_seed)
|
||||
|
Reference in New Issue
Block a user