From cec05b25c9e34a15190f528776dea0d99f9ed875 Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Tue, 8 Mar 2022 14:45:01 +0800
Subject: [PATCH] [zero] update zero context init with the updated test utils (#327)

---
 .../engine/ophooks/_memtracer_ophook.py       | 16 +++++-----
 colossalai/zero/init_ctx/init_context.py      | 29 ++++++++++++-------
 .../zero/shard_utils/base_shard_strategy.py   | 11 ++++---
 .../zero/sharded_param/sharded_param.py       | 16 ++++++++--
 tests/components_to_test/nested_model.py      | 19 ++++++++----
 .../repeated_computed_layer.py                | 12 ++++++--
 tests/components_to_test/resnet.py            | 12 ++++++--
 tests/test_engine/test_engine.py              |  5 ++--
 .../test_init_context.py                      | 23 +++++++++------
 .../test_shard_param.py                       |  2 ++
 10 files changed, 96 insertions(+), 49 deletions(-)

diff --git a/colossalai/engine/ophooks/_memtracer_ophook.py b/colossalai/engine/ophooks/_memtracer_ophook.py
index 16d112550..663d3d8a9 100644
--- a/colossalai/engine/ophooks/_memtracer_ophook.py
+++ b/colossalai/engine/ophooks/_memtracer_ophook.py
@@ -1,4 +1,3 @@
-from re import S
 from colossalai.context.parallel_mode import ParallelMode
 import torch
 from . import BaseOpHook
@@ -7,7 +6,7 @@
 from colossalai.registry import OPHOOKS
 from colossalai.logging import get_dist_logger
 from time import sleep, time
 import pickle
-from typing import Union, Optional
+from typing import Optional
 
 from colossalai.core import global_context as gpc
@@ -19,12 +18,13 @@ def get_cuda_memory_used(device: Optional[torch.device]) -> int:
     """
     ret: int = torch.cuda.memory_allocated(device)
     # get the peak memory to report correct data, so reset the counter for the next call
-    if hasattr(torch.cuda, "reset_peak_memory_stats"):  # pytorch 1.4+
+    if hasattr(torch.cuda, "reset_peak_memory_stats"):    # pytorch 1.4+
         torch.cuda.reset_peak_memory_stats(device)
     return ret
 
 
 class AsyncMemoryMonitor:
+
     def __init__(self, power=10):
         """
         An Async Mem Monitor runing during computing.
@@ -81,7 +81,7 @@ class AsyncMemoryMonitor:
     def save(self, filename):
         with open(filename, "wb") as f:
             pickle.dump(self.state_dict(), f)
-    
+
     def clear(self):
         self.mem_stats.clear()
         self.time_stamps.clear()
@@ -92,7 +92,7 @@ class MemTracerOpHook(BaseOpHook):
     '''
     Collect GPU memory usage information
 
-    Args: 
+    Args:
         warmup (int): This parameter indicates how many iterations to truncate
             before profiling, e.g. set to 5 and the data will start from 6-th iteration
         refreshrate (int): This parameter decides the frequency of write file.
@@ -106,6 +106,7 @@ class MemTracerOpHook(BaseOpHook):
         _data_prefix (string): the prefix of the stats data file
         _rank (int): the rank of current node
     '''
+
     def __init__(self, warmup: int = 50, refreshrate: int = 10, data_prefix: str = "memstats"):
         super().__init__()
         self.async_mem_monitor = AsyncMemoryMonitor()
@@ -128,7 +129,7 @@ class MemTracerOpHook(BaseOpHook):
     @property
     def refreshrate(self) -> int:
         return self._refreshrate
-    
+
     @property
     def warmup(self) -> int:
         return self._warmup
@@ -178,8 +179,7 @@ class MemTracerOpHook(BaseOpHook):
             # every `refreshrate` times, refresh the file
             if self.valid_iter != 0 and self.valid_iter % self.refreshrate == 0:
                 # output file info
-                self._logger.info(
-                    f'dump a memory statistics as pickle to {self._dataprefix}-{self._rank}.pkl')
+                self._logger.info(f'dump a memory statistics as pickle to {self._dataprefix}-{self._rank}.pkl')
                 self.save_results()
                 self._count += 1
                 self._logger.debug(f'data file has been refreshed {self._count} times')
diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py
index d7bd82c27..70818ad33 100644
--- a/colossalai/zero/init_ctx/init_context.py
+++ b/colossalai/zero/init_ctx/init_context.py
@@ -82,25 +82,31 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     3. Shard the param and grad according to flags.
     """
 
-    def __init__(
-        self,
-        convert_fp16: bool,
-        convert_cuda: bool,
-        shard_strategy: BaseShardStrategy,
-        shard_param: bool = False,
-        shard_grad: bool = False,
-    ):
+    def __init__(self,
+                 convert_fp16: bool,
+                 convert_cuda: bool,
+                 shard_strategy: BaseShardStrategy,
+                 shard_param: bool = False,
+                 shard_grad: bool = False,
+                 rm_torch_payload_on_the_fly=False):
         super().__init__()
         self.convert_fp16 = convert_fp16
         self.convert_cuda = convert_cuda
         self.shard_param = shard_param
         self.shard_grad = shard_grad
         self.shard_strategy = shard_strategy
+        self.rm_torch_payload_on_the_fly = rm_torch_payload_on_the_fly
+        self.initialized_param_list = []
 
     def _post_context_exec(self):
         """The callback function when the context exits.
         """
-        pass
+        if not self.rm_torch_payload_on_the_fly:
+            for param in self.initialized_param_list:
+                assert hasattr(param, 'ca_attr')
+                param.ca_attr.remove_torch_payload()
+
+        del self.initialized_param_list
 
     def _post_init_method(self, module):
         r"""The function to call at the end of the constructor of each nn.Module.
@@ -121,7 +127,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
                 if param.grad is not None:
                     param.grad = param.grad.to(torch.half).to(target_device)
 
-            param.ca_attr = ShardedParamV2(param)
+            param.ca_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)
+
+            self.initialized_param_list.append(param)
+
             if self.shard_param:
                 self.shard_strategy.shard(tensor_list=[param.ca_attr._data_sharded_tensor])
             if param.ca_attr.grad and self.shard_grad:
diff --git a/colossalai/zero/shard_utils/base_shard_strategy.py b/colossalai/zero/shard_utils/base_shard_strategy.py
index e3f57eca4..ddae476cc 100644
--- a/colossalai/zero/shard_utils/base_shard_strategy.py
+++ b/colossalai/zero/shard_utils/base_shard_strategy.py
@@ -7,6 +7,11 @@ from typing import List, Optional
 class BaseShardStrategy(ABC):
 
     def __init__(self, process_group: Optional[dist.ProcessGroup] = None) -> None:
+        """Abstract shard strategy. Used to shard tensors on multiple GPUs.
+
+        Args:
+            process_group (Optional[dist.ProcessGroup], optional): the process group. Defaults to None.
+        """
         self.process_group = process_group
         self.world_size = dist.get_world_size(self.process_group)
         self.local_rank = dist.get_rank(self.process_group)
@@ -14,14 +19,8 @@ class BaseShardStrategy(ABC):
 
     @abstractmethod
     def shard(self, tensor_list: List[ShardedTensor]):
-        r"""
-        sharded the memory of tensor on multiple processes.
-        """
         pass
 
     @abstractmethod
     def gather(self, tensor_list: List[ShardedTensor]):
-        r"""
-        duplicate tensor payload on each processes.
-        """
         pass
diff --git a/colossalai/zero/sharded_param/sharded_param.py b/colossalai/zero/sharded_param/sharded_param.py
index 1358bcc3a..b050430a9 100644
--- a/colossalai/zero/sharded_param/sharded_param.py
+++ b/colossalai/zero/sharded_param/sharded_param.py
@@ -10,7 +10,10 @@ from typing import Union, Tuple, Optional
 
 class ShardedParamV2(object):
 
-    def __init__(self, param: torch.nn.Parameter, process_group: Optional[dist.ProcessGroup] = None) -> None:
+    def __init__(self,
+                 param: torch.nn.Parameter,
+                 process_group: Optional[dist.ProcessGroup] = None,
+                 rm_torch_payload=False) -> None:
         self._data_sharded_tensor = ShardedTensor(param.data, process_group)
         if param.requires_grad and param.grad is not None:
             self._grad_sharded_tensor = ShardedTensor(param.grad, process_group)
@@ -19,7 +22,16 @@ class ShardedParamV2(object):
             self._grad_sharded_tensor = None
 
         # make sure the shared param is the only owner of payload
-        param.data = torch.empty([], dtype=param.dtype, device=param.device)
+        # The param.data may be used to init the other parts of the model.
+        # For example: File "resnet.py", line 190, in __init__
+        # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+        # So we can not empty the .data at this time
+        self.param = param
+        if rm_torch_payload:
+            self.remove_torch_payload()
+
+    def remove_torch_payload(self):
+        self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)
 
     @property
     def data(self):
diff --git a/tests/components_to_test/nested_model.py b/tests/components_to_test/nested_model.py
index 316c1e18c..5f32b08e9 100644
--- a/tests/components_to_test/nested_model.py
+++ b/tests/components_to_test/nested_model.py
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from colossalai.nn import CheckpointModule
 from .utils import DummyDataGenerator
 from .registry import non_distributed_component_funcs
 
@@ -15,10 +16,10 @@ class SubNet(nn.Module):
         return F.linear(x, weight, self.bias)
 
 
-class NestedNet(nn.Module):
+class NestedNet(CheckpointModule):
 
-    def __init__(self) -> None:
-        super().__init__()
+    def __init__(self, checkpoint=False) -> None:
+        super().__init__(checkpoint)
         self.fc1 = nn.Linear(5, 5)
         self.sub_fc = SubNet(5)
         self.fc2 = nn.Linear(5, 2)
@@ -41,9 +42,15 @@ class DummyDataLoader(DummyDataGenerator):
 
 @non_distributed_component_funcs.register(name='nested_model')
 def get_training_components():
-    model = NestedNet()
+
+    def model_builder(checkpoint):
+        return NestedNet(checkpoint)
+
     trainloader = DummyDataLoader()
     testloader = DummyDataLoader()
-    optim = torch.optim.Adam(model.parameters(), lr=0.001)
+
+    def optim_builder(model):
+        return torch.optim.Adam(model.parameters(), lr=0.001)
+
     criterion = torch.nn.CrossEntropyLoss()
-    return model, trainloader, testloader, optim, criterion
+    return model_builder, trainloader, testloader, optim_builder, criterion
diff --git a/tests/components_to_test/repeated_computed_layer.py b/tests/components_to_test/repeated_computed_layer.py
index a0f742041..bc035d4b5 100644
--- a/tests/components_to_test/repeated_computed_layer.py
+++ b/tests/components_to_test/repeated_computed_layer.py
@@ -36,9 +36,15 @@ class DummyDataLoader(DummyDataGenerator):
 
 @non_distributed_component_funcs.register(name='repeated_computed_layers')
 def get_training_components():
-    model = NetWithRepeatedlyComputedLayers(checkpoint=True)
+
+    def model_builder(checkpoint=True):
+        return NetWithRepeatedlyComputedLayers(checkpoint)
+
     trainloader = DummyDataLoader()
     testloader = DummyDataLoader()
-    optim = torch.optim.Adam(model.parameters(), lr=0.001)
+
+    def optim_builder(model):
+        return torch.optim.Adam(model.parameters(), lr=0.001)
+
     criterion = torch.nn.CrossEntropyLoss()
-    return model, trainloader, testloader, optim, criterion
+    return model_builder, trainloader, testloader, optim_builder, criterion
diff --git a/tests/components_to_test/resnet.py b/tests/components_to_test/resnet.py
index f1448fa62..20a4be8e2 100644
--- a/tests/components_to_test/resnet.py
+++ b/tests/components_to_test/resnet.py
@@ -22,9 +22,15 @@ def get_cifar10_dataloader(train):
 
 @non_distributed_component_funcs.register(name='resnet18')
 def get_resnet_training_components():
-    model = resnet18(num_classes=10)
+
+    def model_builder(checkpoint=False):
+        return resnet18(num_classes=10)
+
     trainloader = get_cifar10_dataloader(train=True)
     testloader = get_cifar10_dataloader(train=False)
-    optim = torch.optim.Adam(model.parameters(), lr=0.001)
+
+    def optim_builder(model):
+        return torch.optim.Adam(model.parameters(), lr=0.001)
+
     criterion = torch.nn.CrossEntropyLoss()
-    return model, trainloader, testloader, optim, criterion
+    return model_builder, trainloader, testloader, optim_builder, criterion
diff --git a/tests/test_engine/test_engine.py b/tests/test_engine/test_engine.py
index f6aa0a6e3..53c3ebd3e 100644
--- a/tests/test_engine/test_engine.py
+++ b/tests/test_engine/test_engine.py
@@ -16,10 +16,11 @@ CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None
 
 
 def run_train():
     for get_components_func in non_distributed_component_funcs:
-        model, train_dataloader, _, optimizer, criterion = get_components_func()
+        model_builder, train_dataloader, _, optimizer_builder, criterion = get_components_func()
+        model = model_builder(checkpoint=False)
         engine, train_dataloader, *args = colossalai.initialize(model=model,
-                                                                optimizer=optimizer,
+                                                                optimizer=optimizer_builder(model),
                                                                 criterion=criterion,
                                                                 train_dataloader=train_dataloader)
 
diff --git a/tests/test_zero_data_parallel/test_init_context.py b/tests/test_zero_data_parallel/test_init_context.py
index cf038844c..b3f214d98 100644
--- a/tests/test_zero_data_parallel/test_init_context.py
+++ b/tests/test_zero_data_parallel/test_init_context.py
@@ -9,22 +9,27 @@
 import torch
 import torch.multiprocessing as mp
 from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
 from colossalai.zero.init_ctx import ZeroInitContext
-from common import CONFIG, Net
+from common import CONFIG
 from colossalai.utils import free_port
+from tests.components_to_test.registry import non_distributed_component_funcs
 
 
 def run_dist(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    with ZeroInitContext(convert_fp16=True, convert_cuda=True, shard_strategy=TensorShardStrategy(), shard_param=True):
-        # Note Net(checkpoint=True).cuda() moving to cuda is useless
-        model = Net(checkpoint=True)
+    for get_components_func in non_distributed_component_funcs:
+        model_builder, _, _, _, _ = get_components_func()
+        with ZeroInitContext(convert_fp16=True,
+                             convert_cuda=True,
+                             shard_strategy=TensorShardStrategy(),
+                             shard_param=True):
+            model = model_builder(checkpoint=True)
 
-    for param in model.parameters():
-        assert hasattr(param, 'ca_attr')
-        assert param.ca_attr.data.dtype == torch.half
-        assert param.ca_attr._data_sharded_tensor.is_sharded
-        assert param.ca_attr.data.device.type == 'cuda'
+        for param in model.parameters():
+            assert hasattr(param, 'ca_attr')
+            assert param.ca_attr.data.dtype == torch.half
+            assert param.ca_attr._data_sharded_tensor.is_sharded
+            assert param.ca_attr.data.device.type == 'cuda'
 
 
 @pytest.mark.dist
diff --git a/tests/test_zero_data_parallel/test_shard_param.py b/tests/test_zero_data_parallel/test_shard_param.py
index 640292f31..bd04db77c 100644
--- a/tests/test_zero_data_parallel/test_shard_param.py
+++ b/tests/test_zero_data_parallel/test_shard_param.py
@@ -46,6 +46,8 @@ def _run_shard_param_v2(rank, world_size, port):
     sparam = ShardedParamV2(param=param, process_group=None)
 
     allclose(sparam.data, param_ref.data)
+
+    sparam.remove_torch_payload()
     assert (param.data.numel() == 1)
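
Usage note (illustrative sketch, not taken verbatim from the repository): the registered test components now return builder callables rather than constructed objects, so each test decides when the model and optimizer are instantiated. The loop below mirrors tests/test_engine/test_engine.py from the diff, and every name in it comes from the diff itself.

    from tests.components_to_test.registry import non_distributed_component_funcs

    for get_components_func in non_distributed_component_funcs:
        # each component now hands back builders instead of instances
        model_builder, train_dataloader, test_dataloader, optim_builder, criterion = get_components_func()
        model = model_builder(checkpoint=False)    # build the model lazily, optionally with activation checkpointing
        optimizer = optim_builder(model)           # the optimizer can only be created once the model exists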
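
A minimal sketch of the new ZeroInitContext flag introduced above: rm_torch_payload_on_the_fly defaults to False, so each param.data survives until the context exits (letting later submodules initialize weights from it) and is only dropped in _post_context_exec. The snippet assumes colossalai.launch has already set up the distributed environment, as in tests/test_zero_data_parallel/test_init_context.py.

    from colossalai.zero.init_ctx import ZeroInitContext
    from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy

    with ZeroInitContext(convert_fp16=True,
                         convert_cuda=True,
                         shard_strategy=TensorShardStrategy(),
                         shard_param=True,
                         rm_torch_payload_on_the_fly=False):
        # model_builder comes from the component registry shown above
        model = model_builder(checkpoint=True)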
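
Likewise, a sketch of the deferred payload removal on ShardedParamV2. The import path is assumed from the file location in the diff, and the snippet is expected to run inside an initialized process group, as test_shard_param.py does.

    import torch
    from colossalai.zero.sharded_param.sharded_param import ShardedParamV2

    param = torch.nn.Parameter(torch.randn(2, 3))
    sparam = ShardedParamV2(param=param, process_group=None)    # param.data is kept intact by default
    # ... other modules may still read param.data here, e.g. via nn.init.kaiming_normal_ ...
    sparam.remove_torch_payload()    # now .data is replaced by an empty scalar tensor
    assert param.data.numel() == 1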