[zero] update zero context init with the updated test utils (#327)

This commit is contained in:
Jiarui Fang
2022-03-08 14:45:01 +08:00
committed by Frank Lee
parent 6268446b81
commit 11bddb6e55
10 changed files with 96 additions and 49 deletions

View File

@@ -1,4 +1,3 @@
from re import S
from colossalai.context.parallel_mode import ParallelMode
import torch
from . import BaseOpHook
@@ -7,7 +6,7 @@ from colossalai.registry import OPHOOKS
from colossalai.logging import get_dist_logger
from time import sleep, time
import pickle
from typing import Union, Optional
from typing import Optional
from colossalai.core import global_context as gpc
@@ -19,12 +18,13 @@ def get_cuda_memory_used(device: Optional[torch.device]) -> int:
"""
ret: int = torch.cuda.memory_allocated(device)
# get the peak memory to report correct data, so reset the counter for the next call
if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+
if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+
torch.cuda.reset_peak_memory_stats(device)
return ret
class AsyncMemoryMonitor:
def __init__(self, power=10):
"""
An async memory monitor running during computation.
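
The hunk above touches get_cuda_memory_used, which resets torch.cuda's peak-memory counter after reading so that the next call reports a fresh measurement window. A minimal standalone sketch of that measure-then-reset idea (the helper name is hypothetical and it reads the peak counter directly, unlike the hook's function which reads memory_allocated):

from typing import Optional
import torch

def peak_cuda_memory_and_reset(device: Optional[torch.device] = None) -> int:
    """Return the peak CUDA memory allocated (in bytes) since the last reset,
    then clear the peak counter so the next call measures a fresh window."""
    peak = torch.cuda.max_memory_allocated(device)
    # reset_peak_memory_stats exists in PyTorch 1.4+; guard for older versions
    if hasattr(torch.cuda, "reset_peak_memory_stats"):
        torch.cuda.reset_peak_memory_stats(device)
    return peak
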
@@ -81,7 +81,7 @@ class AsyncMemoryMonitor:
def save(self, filename):
with open(filename, "wb") as f:
pickle.dump(self.state_dict(), f)
def clear(self):
self.mem_stats.clear()
self.time_stamps.clear()
@@ -92,7 +92,7 @@ class MemTracerOpHook(BaseOpHook):
'''
Collect GPU memory usage information
Args:
warmup (int): how many iterations to skip before profiling starts,
e.g. set to 5 and the data will start from the 6th iteration
refreshrate (int): how often (in valid iterations) the stats file is rewritten.
@@ -106,6 +106,7 @@ class MemTracerOpHook(BaseOpHook):
_data_prefix (string): the prefix of the stats data file
_rank (int): the rank of the current node
'''
def __init__(self, warmup: int = 50, refreshrate: int = 10, data_prefix: str = "memstats"):
super().__init__()
self.async_mem_monitor = AsyncMemoryMonitor()
@@ -128,7 +129,7 @@ class MemTracerOpHook(BaseOpHook):
@property
def refreshrate(self) -> int:
return self._refreshrate
@property
def warmup(self) -> int:
return self._warmup
@@ -178,8 +179,7 @@ class MemTracerOpHook(BaseOpHook):
# every `refreshrate` times, refresh the file
if self.valid_iter != 0 and self.valid_iter % self.refreshrate == 0:
# output file info
self._logger.info(
f'dump memory statistics as pickle to {self._dataprefix}-{self._rank}.pkl')
self._logger.info(f'dump memory statistics as pickle to {self._dataprefix}-{self._rank}.pkl')
self.save_results()
self._count += 1
self._logger.debug(f'data file has been refreshed {self._count} times')
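
The refresh logic above only dumps after the warmup window and then once every refreshrate valid iterations. A hedged sketch of that gating rule in isolation (function name hypothetical; not the hook's actual code path):

def should_dump(cur_iter: int, warmup: int = 50, refreshrate: int = 10) -> bool:
    """Illustrative only: skip iterations inside the warmup window, then
    rewrite the stats file every `refreshrate` valid (post-warmup) iterations."""
    if cur_iter < warmup:
        return False
    valid_iter = cur_iter - warmup
    return valid_iter != 0 and valid_iter % refreshrate == 0

# e.g. with the defaults, iterations 60, 70, 80, ... trigger a dump
assert should_dump(60) and not should_dump(61)
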

View File

@@ -82,25 +82,31 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
3. Shard the param and grad according to flags.
"""
def __init__(
self,
convert_fp16: bool,
convert_cuda: bool,
shard_strategy: BaseShardStrategy,
shard_param: bool = False,
shard_grad: bool = False,
):
def __init__(self,
convert_fp16: bool,
convert_cuda: bool,
shard_strategy: BaseShardStrategy,
shard_param: bool = False,
shard_grad: bool = False,
rm_torch_payload_on_the_fly=False):
super().__init__()
self.convert_fp16 = convert_fp16
self.convert_cuda = convert_cuda
self.shard_param = shard_param
self.shard_grad = shard_grad
self.shard_strategy = shard_strategy
self.rm_torch_payload_on_the_fly = rm_torch_payload_on_the_fly
self.initialized_param_list = []
def _post_context_exec(self):
"""The callback function when the context exits.
"""
pass
if not self.rm_torch_payload_on_the_fly:
for param in self.initialized_param_list:
assert hasattr(param, 'ca_attr')
param.ca_attr.remove_torch_payload()
del self.initialized_param_list
def _post_init_method(self, module):
r"""The function to call at the end of the constructor of each nn.Module.
@@ -121,7 +127,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
if param.grad is not None:
param.grad = param.grad.to(torch.half).to(target_device)
param.ca_attr = ShardedParamV2(param)
param.ca_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)
self.initialized_param_list.append(param)
if self.shard_param:
self.shard_strategy.shard(tensor_list=[param.ca_attr._data_sharded_tensor])
if param.ca_attr.grad and self.shard_grad:
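
For context, a hedged usage sketch of the updated context manager: with rm_torch_payload_on_the_fly=False (the default) each parameter's dense torch payload survives until the context exits, so weight-init code that touches param.data during module construction still works. The import paths and TensorShardStrategy below are assumptions about the surrounding codebase, not taken from this diff, and an initialized torch.distributed process group is assumed:

import torch.nn as nn
from colossalai.zero.init_ctx import ZeroInitContext         # assumed import path
from colossalai.zero.shard_utils import TensorShardStrategy  # assumed import path

# Sketch: parameters are converted to fp16, moved to CUDA and sharded as the
# modules are constructed; the torch payloads are only dropped on context exit
# because rm_torch_payload_on_the_fly is left False.
with ZeroInitContext(convert_fp16=True,
                     convert_cuda=True,
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True,
                     rm_torch_payload_on_the_fly=False):
    model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 10))
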

View File

@@ -7,6 +7,11 @@ from typing import List, Optional
class BaseShardStrategy(ABC):
def __init__(self, process_group: Optional[dist.ProcessGroup] = None) -> None:
"""Abstract Shard Strategy. Use to shard a tensors on multiple GPUs.
Args:
process_group (Optional[dist.ProcessGroup], optional): the process group. Defaults to None.
"""
self.process_group = process_group
self.world_size = dist.get_world_size(self.process_group)
self.local_rank = dist.get_rank(self.process_group)
@@ -14,14 +19,8 @@ class BaseShardStrategy(ABC):
@abstractmethod
def shard(self, tensor_list: List[ShardedTensor]):
r"""
Shard the memory of a tensor across multiple processes.
"""
pass
@abstractmethod
def gather(self, tensor_list: List[ShardedTensor]):
r"""
Duplicate the tensor payload on each process.
"""
pass
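
To make the shard/gather contract concrete, here is a self-contained sketch that works on plain tensors instead of the repo's ShardedTensor (class name hypothetical; assumes torch.distributed is initialized and each tensor's numel divides evenly by the world size):

from typing import List, Optional
import torch
import torch.distributed as dist

class NaiveChunkShardStrategy:
    """Illustrative shard/gather pair: each rank keeps one contiguous chunk of
    the flattened tensor, and gather all-gathers the chunks back."""

    def __init__(self, process_group: Optional[dist.ProcessGroup] = None) -> None:
        self.process_group = process_group
        self.world_size = dist.get_world_size(process_group)
        self.local_rank = dist.get_rank(process_group)

    def shard(self, tensor_list: List[torch.Tensor]) -> List[torch.Tensor]:
        # keep only this rank's chunk of each flattened tensor
        return [t.flatten().chunk(self.world_size)[self.local_rank].clone()
                for t in tensor_list]

    def gather(self, shard_list: List[torch.Tensor]) -> List[torch.Tensor]:
        # all-gather the chunks so every rank holds the full (flattened) payload again
        gathered = []
        for shard in shard_list:
            bufs = [torch.empty_like(shard) for _ in range(self.world_size)]
            dist.all_gather(bufs, shard, group=self.process_group)
            gathered.append(torch.cat(bufs))
        return gathered
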

View File

@@ -10,7 +10,10 @@ from typing import Union, Tuple, Optional
class ShardedParamV2(object):
def __init__(self, param: torch.nn.Parameter, process_group: Optional[dist.ProcessGroup] = None) -> None:
def __init__(self,
param: torch.nn.Parameter,
process_group: Optional[dist.ProcessGroup] = None,
rm_torch_payload=False) -> None:
self._data_sharded_tensor = ShardedTensor(param.data, process_group)
if param.requires_grad and param.grad is not None:
self._grad_sharded_tensor = ShardedTensor(param.grad, process_group)
@@ -19,7 +22,16 @@ class ShardedParamV2(object):
self._grad_sharded_tensor = None
# make sure the sharded param is the only owner of the payload
param.data = torch.empty([], dtype=param.dtype, device=param.device)
# The param.data may still be used to init other parts of the model.
# For example: File "resnet.py", line 190, in __init__
# nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
# So we cannot empty the .data at this time
self.param = param
if rm_torch_payload:
self.remove_torch_payload()
def remove_torch_payload(self):
self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)
@property
def data(self):
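
The comment above is the motivation for the new rm_torch_payload flag: weight-init routines such as nn.init.kaiming_normal_ still read and write param.data while the model is being built, so the dense payload can only be dropped once initialization is done. A hedged, standalone illustration of that deferred-removal idea (the helper is hypothetical, not the repo's code):

import torch
import torch.nn as nn

def remove_torch_payload(param: nn.Parameter) -> None:
    # replace the dense payload with an empty scalar-shaped stub once nothing
    # needs it for initialization anymore
    param.data = torch.empty([], dtype=param.dtype, device=param.device)

layer = nn.Linear(16, 16)
nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')  # reads .data
remove_torch_payload(layer.weight)  # safe now: init has already used the payload
print(layer.weight.data.shape)      # torch.Size([])
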