[polish] polish ColoTensor and its submodules (#2537)

HELSON 2023-02-03 11:44:10 +08:00 committed by GitHub
parent 51d4d6e718
commit 552183bb74
6 changed files with 75 additions and 65 deletions


@@ -71,7 +71,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
         return tensor
 
     def __repr__(self):
-        return f'ColoParameter: {ColoTensor.__repr__(self)}'
+        return super(ColoParameter, self).__repr__()
 
     @classmethod
     def __torch_function__(cls, func, types, args=..., kwargs=None):


@@ -189,7 +189,12 @@ class ColoTensor(torch.Tensor):
         return _convert_output(ret, colo_spec)
 
     def __repr__(self):
-        return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}\n{self.compute_spec}'
+        output_list = [super(ColoTensor, self).__repr__()]
+        output_list.append(str(self.process_group))
+        output_list.append(str(self.dist_spec))
+        if self.compute_spec is not None:
+            output_list.append(str(self.compute_spec))
+        return "\n".join(output_list)
 
     def _redistribute(self, dist_spec: _DistSpec) -> None:
         """_redistribute


@@ -9,9 +9,9 @@ class ComputePattern(Enum):
 class ComputeSpec(object):
     """ComputeSpec
     The Specification for compuattion pattern
 
     Args:
         compute_pattern (ComputePattern): an Enum instance for compute pattern.
     """
@@ -23,7 +23,7 @@ class ComputeSpec(object):
         self.output_replicate = True
 
     def __repr__(self):
-        return f'Compute pattern: {self.compute_pattern}'
+        return f'ComputeSpec(pattern={self.compute_pattern}, replicate_output={self.output_replicate})'
 
     def set_output_replicate(self, flag: bool = True):
         self.output_replicate = flag


@@ -11,7 +11,7 @@ class DistPlacementPattern(Enum):
 class _DistSpec:
     """_DistSpec
     A class indicates Distributed Specification.
     The DistSpec is only works for the tensor parallel process groups.
     Because the dist spec of data parallel process group can be automatically deduced.
@@ -39,11 +39,12 @@ class _DistSpec:
         return True
 
     def __repr__(self) -> str:
-        res_list = ["DistSpec:"]
+        attr_list = []
         for attr in dir(self):
             if not attr.startswith('__'):
-                res_list.append(f'\n\t{attr}: {str(getattr(self, attr))}')
-        return ''.join(res_list)
+                attr_list.append(f'{attr}={str(getattr(self, attr))}')
+        attr_str = ", ".join(attr_list)
+        return "DistSpec(" + attr_str + ")"
 
 
 def ReplicaSpec() -> _DistSpec:
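For reference, a self-contained sketch of the attribute-introspection idea the new `_DistSpec.__repr__` relies on (the `_SpecSketch` class and its fields are invented for the example): `dir()` lists every attribute, dunder names are filtered out, and the rest are rendered as `name=value` pairs on one line.

```python
class _SpecSketch:

    def __init__(self):
        self.placement = 'r'          # invented field
        self.num_partitions = 4       # invented field

    def __repr__(self) -> str:
        attr_list = []
        for attr in dir(self):                 # dir() also lists methods and dunders
            if not attr.startswith('__'):      # keep only the plain attributes
                attr_list.append(f'{attr}={getattr(self, attr)}')
        return "DistSpec(" + ", ".join(attr_list) + ")"


print(_SpecSketch())    # DistSpec(num_partitions=4, placement=r)
```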


@@ -1,29 +1,36 @@
-import torch
 from typing import List, Optional
-from colossalai.logging import get_dist_logger
+
+import torch
+
 from colossalai.context.singleton_meta import SingletonMeta
+from colossalai.logging import get_dist_logger
 
 
 class PyTorchProcessGroupDict(metaclass=SingletonMeta):
 
     def __init__(self):
         # distributed settings
+        # use this dict to record all Pytorch ProcessGroups
         self.dict = {}
+        # set a distributed logger
+        self.logger = get_dist_logger('ProcessGroup')
+
+    def log_pg_init(self, rank_list: List[int], backend: str):
+        str_list = ["Pytorch ProcessGroup Init:"]
+        str_list.append(f"backend: {backend}")
+        str_list.append(f"ranks: {rank_list}")
+        self.logger.info("\n\t".join(str_list), ranks=[0])
 
     def get(self, rank_list: List[int], backend: str = 'nccl'):
         """Reuse Pytorch ProcessGroup when such a group is initialized
         """
-        rank_tuple = tuple(rank_list)
         # we need to convert the passed list to a tuple
         # since List is unhashable
-        pg_key = (backend, rank_tuple)
-
-        if pg_key not in self.dict:
-            self.logger = get_dist_logger('ProcessGroup')
-            self.logger.info(f'NCCL initialize ProcessGroup on {rank_list}', ranks=[0])
-            self.dict[pg_key] = torch.distributed.new_group(ranks=rank_list, backend=backend)
-        return self.dict[pg_key]
+        processgroup_key = (backend, tuple(rank_list))
+        if processgroup_key not in self.dict:
+            self.log_pg_init(rank_list=rank_list, backend=backend)
+            self.dict[processgroup_key] = torch.distributed.new_group(ranks=rank_list, backend=backend)
+        return self.dict[processgroup_key]
 
 
 PYTORCHPGDICT_ = PyTorchProcessGroupDict()
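The refactored `get` is essentially a keyed cache. Below is a minimal sketch of the same idea, with a dummy factory standing in for `torch.distributed.new_group` so the snippet runs without an initialized distributed backend; `_fake_new_group` and `get_group` are illustrative names only.

```python
from typing import List


def _fake_new_group(ranks: List[int], backend: str) -> str:
    # placeholder for torch.distributed.new_group(ranks=..., backend=...)
    return f"<group backend={backend} ranks={ranks}>"


_group_cache = {}


def get_group(rank_list: List[int], backend: str = 'nccl'):
    key = (backend, tuple(rank_list))       # lists are unhashable, tuples work as dict keys
    if key not in _group_cache:             # only create the group on a cache miss
        _group_cache[key] = _fake_new_group(rank_list, backend)
    return _group_cache[key]


assert get_group([0, 1, 2, 3]) is get_group([0, 1, 2, 3])    # the second call reuses the cached group
```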
@@ -40,7 +47,7 @@ class ProcessGroup:
         rank: the global rank of the current process.
         ranks: List[int], a list of rank id belongings to this process group.
         backend: str, the backend of the process group.
         tp_degree: Optional[int], tensor parallelism degree. How many processes are inside a tp process group. default None means 1.
         dp_degree: Optional[int], data parallelism degree. How many processes are inside a dp process group. . default None means len(ranks).
     """
@@ -54,10 +61,10 @@ class ProcessGroup:
             return
 
         assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"
 
-        if rank is None:
-            self._rank = torch.distributed.get_rank()
-        else:
-            self._rank = rank
+        self._rank = torch.distributed.get_rank()
+        if rank is not None:
+            assert self._rank == rank    # make sure that the global rank is correct
 
         if ranks is None:
             self._rank_list = list(range(torch.distributed.get_world_size()))
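The constructor hunk above swaps an either/or assignment for a derive-then-verify scheme: the rank always comes from the runtime, and a caller-supplied rank is only checked against it. A tiny sketch of that pattern, with `_runtime_rank` standing in for `torch.distributed.get_rank`:

```python
from typing import Optional


def _runtime_rank() -> int:
    return 0    # placeholder for torch.distributed.get_rank()


def resolve_rank(rank: Optional[int] = None) -> int:
    resolved = _runtime_rank()
    if rank is not None:
        # the explicit value is only verified, never trusted blindly
        assert resolved == rank, f"passed rank {rank} does not match runtime rank {resolved}"
    return resolved


print(resolve_rank())      # 0
print(resolve_rank(0))     # 0
```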
@@ -104,7 +111,7 @@ class ProcessGroup:
         self.is_init = True
 
     def set_cpu_groups(self):
         """set_cpu_groups
         Initialize Pytorch process groups for cpu communications.
         """
         if self.has_cpu_groups:
@@ -122,7 +129,7 @@ class ProcessGroup:
     @property
     def has_cpu_groups(self) -> bool:
         """has_cpu_groups
         If cpu groups have been initailized.
 
         Returns:
@@ -132,8 +139,9 @@ class ProcessGroup:
 
     def __repr__(self):
         if self.is_init:
-            return "ProcessGroup:\n\tRank: {}, World size: {}, DP degree: {}, TP degree: {}\n\tRanks in group: {}".\
-                format(self._rank, self._world_size, self._dp_degree, self._tp_degree, self._rank_list)
+            ranks_str = f"ProcessGroup(ranks={self._rank_list},\n"
+            personal_str = f" rank={self._rank}, dp={self._dp_degree}, tp={self._tp_degree})"
+            return ranks_str + personal_str
         else:
             return "ProcessGroup not initialized"
@@ -155,7 +163,7 @@ class ProcessGroup:
         return True
 
     def rank(self) -> int:
         """rank
         The current rank in the global process group.
@@ -165,9 +173,9 @@ class ProcessGroup:
         return self._rank
 
     def ranks_in_group(self) -> List[int]:
         """ranks_in_group
         a list of rank number in in the global process group.
 
         Returns:
             List[int]: a list of rank number.
@@ -177,7 +185,7 @@ class ProcessGroup:
     def world_size(self) -> int:
         """world_size
         The world size of the global process group.
 
         Returns:
             int: world size
@@ -185,7 +193,7 @@ class ProcessGroup:
         return self._world_size
 
     def tp_rank_list(self) -> List[int]:
         """tp_rank_list
         the rank list in the TP process group containing the current rank.
@@ -195,7 +203,7 @@ class ProcessGroup:
         return self._tp_rank_list
 
     def dp_rank_list(self) -> List[int]:
         """dp_rank_list
         the rank list in the DP process group containing the current rank.
@@ -205,7 +213,7 @@ class ProcessGroup:
         return self._dp_rank_list
 
     def tp_local_rank(self) -> int:
         """tp_local_rank
         The local rank number in the current TP process group.
@ -268,7 +276,7 @@ class ProcessGroup:
"""cpu_dp_process_group """cpu_dp_process_group
the pytorch CPU DP process group containing the current rank. the pytorch CPU DP process group containing the current rank.
assert failed if cpu process group is not initialized. assert failed if cpu process group is not initialized.
Returns: Returns:
@ -281,7 +289,7 @@ class ProcessGroup:
"""cpu_tp_process_group """cpu_tp_process_group
the pytorch CPU TP process group containing the current rank. the pytorch CPU TP process group containing the current rank.
assert failed if cpu process group is not initialized. assert failed if cpu process group is not initialized.
Returns: Returns:


@@ -37,12 +37,11 @@ def _convert_to_coloparam(param: torch.nn.Parameter,
     # detaching tensor is necessary for optimizers.
     requires_grad = param.requires_grad
     # param is the global tensor.
     if param.device.type == "meta":
         colo_param = ColoParameter(param, requires_grad=requires_grad)
     else:
         colo_param = ColoParameter(param.to(device=device, dtype=dtype), requires_grad=requires_grad)
 
     # if default_shard_plan exists, shard the param during initialization.
     # This can reduce the model size after initialization.
@@ -129,32 +128,29 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
                 delattr(submodule, param_name)
                 setattr(submodule, param_name, colo_param)
                 colo_param.shared_param_modules.append(submodule)
 
-        meta_param_flag = 0
-        meta_buffer_flag = 0
+        param_number = 0
+        meta_param_number = 0
+        buffer_number = 0
+        meta_buffer_number = 0
+
         for param in module.parameters():
-            if param.device.type=="meta":
-                meta_param_flag = 1
-            if meta_param_flag == 1 and param.device.type!="meta":
-                raise ValueError("Meta parameters and valued parameters can not be in the same model")
+            param_number += 1
+            meta_param_number += (param.device.type == 'meta')
+
         for buffer in module.buffers():
-            if buffer.device.type=="meta":
-                meta_buffer_flag = 1
-            if meta_buffer_flag == 1 and buffer.device.type!="meta":
-                raise ValueError("Meta buffers and valued buffers can not be in the same model")
-        if meta_param_flag==1 and meta_buffer_flag==1:
-            pass
-        elif meta_buffer_flag==0 and meta_param_flag==1:
-            for name, buf in module.named_buffers():
-                module._buffers[name] = module._buffers[name].to(device=self._device)
-        elif meta_param_flag==0 and meta_buffer_flag==1:
-            for name, param in module.named_parameters():
-                module._parameters[name] = module._parameters[name].to(device=self._device)
-        else:
-            module.to(self._device)
+            buffer_number += 1
+            meta_buffer_number += (buffer.device.type == 'meta')
+
+        if meta_param_number > 0 and meta_param_number != param_number:
+            raise ValueError("Meta parameters and valued parameters can not be in the same model")
+        if meta_buffer_number > 0 and meta_buffer_number != buffer_number:
+            raise ValueError("Meta buffers and valued buffers can not be in the same model")
+
+        if meta_buffer_number == 0:
+            for buffer in module.buffers():
+                buffer.data = buffer.data.to(device=self._device)
 
 
 def post_process_colo_init_ctx(model: torch.nn.Module,
                                device: torch.device = torch.device('cpu'),
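The flag-based checks are replaced by counters in this hunk: count all parameters and buffers, count the ones on the `meta` device, and reject any model that mixes meta and materialized tensors. A standalone sketch of that check on an arbitrary `nn.Module` (the helper name `check_meta_consistency` is made up; the counting mirrors the hunk above):

```python
import torch.nn as nn


def check_meta_consistency(module: nn.Module) -> None:
    param_number = meta_param_number = 0
    buffer_number = meta_buffer_number = 0
    for param in module.parameters():
        param_number += 1
        meta_param_number += (param.device.type == 'meta')    # bool counts as 0 or 1
    for buffer in module.buffers():
        buffer_number += 1
        meta_buffer_number += (buffer.device.type == 'meta')
    # a model must be entirely meta or entirely materialized
    if meta_param_number > 0 and meta_param_number != param_number:
        raise ValueError("Meta parameters and valued parameters can not be in the same model")
    if meta_buffer_number > 0 and meta_buffer_number != buffer_number:
        raise ValueError("Meta buffers and valued buffers can not be in the same model")


check_meta_consistency(nn.Linear(4, 4))    # fully materialized module: passes silently
```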