diff --git a/colossalai/tensor/colo_parameter.py b/colossalai/tensor/colo_parameter.py
index 92220d9e2..b384579fe 100644
--- a/colossalai/tensor/colo_parameter.py
+++ b/colossalai/tensor/colo_parameter.py
@@ -71,7 +71,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
         return tensor
 
     def __repr__(self):
-        return f'ColoParameter: {ColoTensor.__repr__(self)}'
+        return super(ColoParameter, self).__repr__()
 
     @classmethod
     def __torch_function__(cls, func, types, args=..., kwargs=None):
diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index b27f5dea7..474dc7a1e 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -189,7 +189,12 @@ class ColoTensor(torch.Tensor):
         return _convert_output(ret, colo_spec)
 
     def __repr__(self):
-        return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}\n{self.compute_spec}'
+        output_list = [super(ColoTensor, self).__repr__()]
+        output_list.append(str(self.process_group))
+        output_list.append(str(self.dist_spec))
+        if self.compute_spec is not None:
+            output_list.append(str(self.compute_spec))
+        return "\n".join(output_list)
 
     def _redistribute(self, dist_spec: _DistSpec) -> None:
         """_redistribute
diff --git a/colossalai/tensor/compute_spec.py b/colossalai/tensor/compute_spec.py
index a9774c34c..73328285e 100644
--- a/colossalai/tensor/compute_spec.py
+++ b/colossalai/tensor/compute_spec.py
@@ -9,9 +9,9 @@ class ComputePattern(Enum):
 
 
 class ComputeSpec(object):
-    """ComputeSpec 
+    """ComputeSpec
     The Specification for compuattion pattern
-    
+
     Args:
         compute_pattern (ComputePattern): an Enum instance for compute pattern.
     """
@@ -23,7 +23,7 @@ class ComputeSpec(object):
         self.output_replicate = True
 
     def __repr__(self):
-        return f'Compute pattern: {self.compute_pattern}'
+        return f'ComputeSpec(pattern={self.compute_pattern}, replicate_output={self.output_replicate})'
 
     def set_output_replicate(self, flag: bool = True):
         self.output_replicate = flag
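# Illustrative sketch, not part of the patch above: what the reworked __repr__
# methods are expected to produce. ComputePattern.TP1D is assumed to be the
# usual 1-D tensor-parallel member of the enum; the exact string comes from the
# runtime.
from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec

spec = ComputeSpec(ComputePattern.TP1D)
print(repr(spec))
# expected, per the new f-string and the default output_replicate=True:
#   ComputeSpec(pattern=ComputePattern.TP1D, replicate_output=True)

# ColoTensor.__repr__ now stacks its parts on separate lines -- tensor data,
# then ProcessGroup, then DistSpec -- and appends the ComputeSpec only when one
# has been set, instead of always printing a possibly-None compute spec.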
diff --git a/colossalai/tensor/distspec.py b/colossalai/tensor/distspec.py
index 0b62cbdda..8dd0d8791 100644
--- a/colossalai/tensor/distspec.py
+++ b/colossalai/tensor/distspec.py
@@ -11,7 +11,7 @@ class DistPlacementPattern(Enum):
 
 class _DistSpec:
     """_DistSpec
-    
+
     A class indicates Distributed Specification.
     The DistSpec is only works for the tensor parallel process groups.
     Because the dist spec of data parallel process group can be automatically deduced.
@@ -39,11 +39,12 @@ class _DistSpec:
         return True
 
     def __repr__(self) -> str:
-        res_list = ["DistSpec:"]
+        attr_list = []
         for attr in dir(self):
             if not attr.startswith('__'):
-                res_list.append(f'\n\t{attr}: {str(getattr(self, attr))}')
-        return ''.join(res_list)
+                attr_list.append(f'{attr}={str(getattr(self, attr))}')
+        attr_str = ", ".join(attr_list)
+        return "DistSpec(" + attr_str + ")"
 
 
 def ReplicaSpec() -> _DistSpec:
diff --git a/colossalai/tensor/process_group.py b/colossalai/tensor/process_group.py
index e7e565071..f108bdc24 100644
--- a/colossalai/tensor/process_group.py
+++ b/colossalai/tensor/process_group.py
@@ -1,29 +1,36 @@
-import torch
 from typing import List, Optional
-from colossalai.logging import get_dist_logger
+
+import torch
+
 from colossalai.context.singleton_meta import SingletonMeta
+from colossalai.logging import get_dist_logger
 
 
 class PyTorchProcessGroupDict(metaclass=SingletonMeta):
 
     def __init__(self):
         # distributed settings
+        # use this dict to record all Pytorch ProcessGroups
         self.dict = {}
+        # set a distributed logger
+        self.logger = get_dist_logger('ProcessGroup')
+
+    def log_pg_init(self, rank_list: List[int], backend: str):
+        str_list = ["Pytorch ProcessGroup Init:"]
+        str_list.append(f"backend: {backend}")
+        str_list.append(f"ranks: {rank_list}")
+        self.logger.info("\n\t".join(str_list), ranks=[0])
 
     def get(self, rank_list: List[int], backend: str = 'nccl'):
         """Reuse Pytorch ProcessGroup when such a group is initialized
         """
-        rank_tuple = tuple(rank_list)
         # we need to convert the passed list to a tuple
         # since List is unhashable
-        pg_key = (backend, rank_tuple)
-
-        if pg_key not in self.dict:
-
-            self.logger = get_dist_logger('ProcessGroup')
-            self.logger.info(f'NCCL initialize ProcessGroup on {rank_list}', ranks=[0])
-            self.dict[pg_key] = torch.distributed.new_group(ranks=rank_list, backend=backend)
-        return self.dict[pg_key]
+        processgroup_key = (backend, tuple(rank_list))
+        if processgroup_key not in self.dict:
+            self.log_pg_init(rank_list=rank_list, backend=backend)
+            self.dict[processgroup_key] = torch.distributed.new_group(ranks=rank_list, backend=backend)
+        return self.dict[processgroup_key]
 
 
 PYTORCHPGDICT_ = PyTorchProcessGroupDict()
@@ -40,7 +47,7 @@ class ProcessGroup:
         rank: the global rank of the current process.
         ranks: List[int], a list of rank id belongings to this process group.
         backend: str, the backend of the process group.
-        tp_degree: Optional[int], tensor parallelism degree. How many processes are inside a tp process group. default None means 1. 
+        tp_degree: Optional[int], tensor parallelism degree. How many processes are inside a tp process group. default None means 1.
         dp_degree: Optional[int], data parallelism degree. How many processes are inside a dp process group. . default None means len(ranks).
     """
 
@@ -54,10 +61,10 @@ class ProcessGroup:
             return
 
         assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"
-        if rank is None:
-            self._rank = torch.distributed.get_rank()
-        else:
-            self._rank = rank
+
+        self._rank = torch.distributed.get_rank()
+        if rank is not None:
+            assert self._rank == rank    # make sure that the global rank is correct
 
         if ranks is None:
             self._rank_list = list(range(torch.distributed.get_world_size()))
@@ -104,7 +111,7 @@ class ProcessGroup:
         self.is_init = True
 
     def set_cpu_groups(self):
-        """set_cpu_groups 
+        """set_cpu_groups
         Initialize Pytorch process groups for cpu communications.
         """
         if self.has_cpu_groups:
@@ -122,7 +129,7 @@ class ProcessGroup:
 
     @property
     def has_cpu_groups(self) -> bool:
-        """has_cpu_groups 
+        """has_cpu_groups
         If cpu groups have been initailized.
 
         Returns:
@@ -132,8 +139,9 @@ class ProcessGroup:
 
     def __repr__(self):
         if self.is_init:
-            return "ProcessGroup:\n\tRank: {}, World size: {}, DP degree: {}, TP degree: {}\n\tRanks in group: {}".\
-                format(self._rank, self._world_size, self._dp_degree, self._tp_degree, self._rank_list)
+            ranks_str = f"ProcessGroup(ranks={self._rank_list},\n"
+            personal_str = f" rank={self._rank}, dp={self._dp_degree}, tp={self._tp_degree})"
+            return ranks_str + personal_str
         else:
             return "ProcessGroup not initialized"
 
@@ -155,7 +163,7 @@ class ProcessGroup:
         return True
 
     def rank(self) -> int:
-        """rank 
+        """rank
 
         The current rank in the global process group.
 
@@ -165,9 +173,9 @@ class ProcessGroup:
         return self._rank
 
     def ranks_in_group(self) -> List[int]:
-        """ranks_in_group 
+        """ranks_in_group
 
-        a list of rank number in in the global process group. 
+        a list of rank number in in the global process group.
 
         Returns:
             List[int]: a list of rank number.
@@ -177,7 +185,7 @@ class ProcessGroup:
 
     def world_size(self) -> int:
         """world_size
-        The world size of the global process group. 
+        The world size of the global process group.
 
         Returns:
             int: world size
@@ -185,7 +193,7 @@ class ProcessGroup:
         return self._world_size
 
     def tp_rank_list(self) -> List[int]:
-        """tp_rank_list 
+        """tp_rank_list
 
         the rank list in the TP process group containing the current rank.
 
@@ -195,7 +203,7 @@ class ProcessGroup:
         return self._tp_rank_list
 
     def dp_rank_list(self) -> List[int]:
-        """dp_rank_list 
+        """dp_rank_list
 
         the rank list in the DP process group containing the current rank.
 
@@ -205,7 +213,7 @@ class ProcessGroup:
         return self._dp_rank_list
 
     def tp_local_rank(self) -> int:
-        """tp_local_rank 
+        """tp_local_rank
 
         The local rank number in the current TP process group.
 
@@ -268,7 +276,7 @@ class ProcessGroup:
         """cpu_dp_process_group
 
         the pytorch CPU DP process group containing the current rank.
-        
+
         assert failed if cpu process group is not initialized.
 
         Returns:
@@ -281,7 +289,7 @@ class ProcessGroup:
         """cpu_tp_process_group
 
         the pytorch CPU TP process group containing the current rank.
-        
+
         assert failed if cpu process group is not initialized.
 
         Returns:
diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py
index 93c91e099..ab354ea70 100644
--- a/colossalai/utils/model/colo_init_context.py
+++ b/colossalai/utils/model/colo_init_context.py
@@ -37,12 +37,11 @@ def _convert_to_coloparam(param: torch.nn.Parameter,
     # detaching tensor is necessary for optimizers.
     requires_grad = param.requires_grad
     # param is the global tensor.
-    
+
     if param.device.type == "meta":
         colo_param = ColoParameter(param, requires_grad=requires_grad)
-    else: 
+    else:
         colo_param = ColoParameter(param.to(device=device, dtype=dtype), requires_grad=requires_grad)
-
     # if default_shard_plan exists, shard the param during initialization.
     # This can reduce the model size after initialization.
@@ -129,32 +128,29 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
             delattr(submodule, param_name)
             setattr(submodule, param_name, colo_param)
             colo_param.shared_param_modules.append(submodule)
-
-        meta_param_flag = 0
-        meta_buffer_flag = 0
+
+        param_number = 0
+        meta_param_number = 0
+        buffer_number = 0
+        meta_buffer_number = 0
+
         for param in module.parameters():
-            if param.device.type=="meta":
-                meta_param_flag = 1
-            if meta_param_flag == 1 and param.device.type!="meta":
-                raise ValueError("Meta parameters and valued parameters can not be in the same model")
-
+            param_number += 1
+            meta_param_number += (param.device.type == 'meta')
+
         for buffer in module.buffers():
-            if buffer.device.type=="meta":
-                meta_buffer_flag = 1
-            if meta_buffer_flag == 1 and buffer.device.type!="meta":
-                raise ValueError("Meta buffers and valued buffers can not be in the same model")
-
-        if meta_param_flag==1 and meta_buffer_flag==1:
-            pass
-        elif meta_buffer_flag==0 and meta_param_flag==1:
-            for name, buf in module.named_buffers():
-                module._buffers[name] = module._buffers[name].to(device=self._device)
-        elif meta_param_flag==0 and meta_buffer_flag==1:
-            for name, param in module.named_parameters():
-                module._parameters[name] = module._parameters[name].to(device=self._device)
-        else:
-            module.to(self._device)
-
+            buffer_number += 1
+            meta_buffer_number += (buffer.device.type == 'meta')
+
+        if meta_param_number > 0 and meta_param_number != param_number:
+            raise ValueError("Meta parameters and valued parameters can not be in the same model")
+        if meta_buffer_number > 0 and meta_buffer_number != buffer_number:
+            raise ValueError("Meta buffers and valued buffers can not be in the same model")
+
+        if meta_buffer_number == 0:
+            for buffer in module.buffers():
+                buffer.data = buffer.data.to(device=self._device)
+
 
 def post_process_colo_init_ctx(model: torch.nn.Module,
                                device: torch.device = torch.device('cpu'),
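# ---------------------------------------------------------------------------
# Illustrative sketches, not part of the patch above. The names GroupCache and
# check_meta_consistency are hypothetical; they only mirror the ideas in the
# process_group.py and colo_init_context.py hunks.

# 1) Why PyTorchProcessGroupDict.get keys the cache with (backend, tuple(ranks)):
#    a list cannot be a dict key, so the rank list is converted to a tuple and
#    each unique rank set creates its group exactly once. The lambda below is a
#    stand-in for torch.distributed.new_group.
from typing import Callable, Dict, List, Tuple


class GroupCache:

    def __init__(self, factory: Callable[[List[int], str], object]):
        self._factory = factory
        self._groups: Dict[Tuple[str, Tuple[int, ...]], object] = {}

    def get(self, rank_list: List[int], backend: str = 'nccl'):
        key = (backend, tuple(rank_list))    # tuple: lists are unhashable
        if key not in self._groups:
            self._groups[key] = self._factory(rank_list, backend)
        return self._groups[key]


cache = GroupCache(lambda ranks, backend: object())
assert cache.get([0, 1]) is cache.get([0, 1])        # same ranks -> group reused
assert cache.get([0, 1]) is not cache.get([2, 3])    # new rank set -> new group

# 2) The counting check introduced in the ColoInitContext hunk above: count all
#    parameters and the meta ones, then reject models that mix meta and
#    materialized parameters.
import torch
import torch.nn as nn


def check_meta_consistency(module: nn.Module) -> None:
    param_number = 0
    meta_param_number = 0
    for param in module.parameters():
        param_number += 1
        meta_param_number += (param.device.type == 'meta')
    if meta_param_number > 0 and meta_param_number != param_number:
        raise ValueError("Meta parameters and valued parameters can not be in the same model")


class Mixed(nn.Module):

    def __init__(self):
        super().__init__()
        self.a = nn.Parameter(torch.empty(2, 2, device='meta'))    # meta parameter
        self.b = nn.Parameter(torch.empty(2, 2))    # materialized parameter


check_meta_consistency(nn.Linear(2, 2))    # passes: no parameter on the meta device
try:
    check_meta_consistency(Mixed())    # raises: meta and valued parameters mixed
except ValueError as err:
    print(err)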