[zero] refactor model data tracing (#537)

Jiarui Fang 2022-03-28 16:38:18 +08:00 committed by GitHub
parent a590ed0ba3
commit 705f56107c
13 changed files with 98 additions and 132 deletions
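Summary of the change: the GLOBAL_MODEL_DATA_TRACER singleton no longer tracks memory by incrementally hooking add_tensor/delete_tensor calls scattered through the ZeRO code paths. Instead, a model is registered on the singleton once, and CUDA/CPU model data usage is computed on demand by walking model.parameters() with the new col_model_data_mem_usage helper. A minimal sketch of the new flow (the Linear model below is a hypothetical placeholder, not part of this commit):

# sketch of the registration-based tracing introduced by this commit
import torch
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER

model = torch.nn.Linear(128, 128).cuda()        # any nn.Module; placeholder example
GLOBAL_MODEL_DATA_TRACER.register_model(model)  # replaces the old start()/add_tensor()/close() calls
print(GLOBAL_MODEL_DATA_TRACER.cuda_usage)      # bytes of model data currently on CUDA, computed on demand
print(GLOBAL_MODEL_DATA_TRACER.cpu_usage)       # bytes of model data currently on CPU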

View File

@@ -5,8 +5,6 @@ import torch.distributed as dist
 from colossalai.registry import OPHOOKS
 from colossalai.utils import get_current_device
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
-from colossalai.utils.memory_tracer.model_data_memtracer import \
-    GLOBAL_MODEL_DATA_TRACER
 from colossalai.zero.shard_utils import BaseShardStrategy
 from ._base_ophook import BaseOpHook

View File

@@ -3,6 +3,7 @@ from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
 from colossalai.utils import get_current_device
 import torch
+from typing import Tuple

 class SamplingCounter:
@@ -40,6 +41,20 @@ class MemStatsCollector:
         self._start_flag = False

+    @property
+    def overall_cuda(self):
+        return self._overall_cuda
+
+    @property
+    def model_data_cuda(self):
+        return self._model_data_cuda
+
+    @property
+    def non_model_data_cuda(self):
+        """Non model data stats
+        """
+        return [(v1 - v2) for v1, v2 in zip(self._overall_cuda, self._model_data_cuda)]
+
     def start_collection(self):
         self._start_flag = True
@@ -58,7 +73,7 @@ class MemStatsCollector:
         self._overall_cuda.append(colo_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')))
         self._sampling_cnter.advance()

-    def fetch_memstats(self) -> (int, int):
+    def fetch_memstats(self) -> Tuple[int, int]:
         """
         returns cuda usage of model data and overall cuda usage.
         """

View File

@@ -1,7 +1,8 @@
 from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
 import torch
-from typing import Union
+from typing import Union, Tuple, Optional
+from colossalai.logging import DistributedLogger

 def _col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int:
@@ -12,60 +13,78 @@ def _col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int:
     return target.numel() * target.element_size()

+
+def col_model_data_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
+    """
+    Trace the model memory usage.
+    Args:
+        model (torch.nn.Module): a torch model
+    Returns:
+        Tuple[int, int]: cuda memory usage in Byte, cpu memory usage in Byte
+    """
+
+    def _get_tensor_mem_use(t: Optional[torch.Tensor]):
+        if t is None:
+            return 0, 0
+        assert isinstance(t, torch.Tensor)
+        _cpu_mem_usage, _cuda_mem_usage = 0, 0
+        if t.device.type == 'cpu':
+            _cpu_mem_usage += t.numel() * t.element_size()
+        elif t.device.type == 'cuda':
+            _cuda_mem_usage += t.numel() * t.element_size()
+        return _cuda_mem_usage, _cpu_mem_usage
+
+    cuda_mem_usage = 0
+    cpu_mem_usage = 0
+    for param in model.parameters():
+        if hasattr(param, 'col_attr'):
+            param_cuda, param_cpu = param.col_attr.get_memory_usage()
+            cuda_mem_usage += param_cuda
+            cpu_mem_usage += param_cpu
+        else:
+            t_cuda, t_cpu = _get_tensor_mem_use(param.data)
+            cuda_mem_usage += t_cuda
+            cpu_mem_usage += t_cpu
+            t_cuda, t_cpu = _get_tensor_mem_use(param.grad)
+            cuda_mem_usage += t_cuda
+            cpu_mem_usage += t_cpu
+    return cuda_mem_usage, cpu_mem_usage
+
+
 class ModelDataTracer(metaclass=SingletonMeta):
     """
     A tracer singleton to trace model data usage during runtime.
-    The tracer is designed to trace the memory layout change during model-data tensors allocation, releasing, and moving.
-    To achieve this goal, the developers have to call `ModelDataTracer` in the corresponding code explicitly.
-    NOTE() now the class only trace cuda memory usage
+    You have to register a model on the singleton first.
     """

     def __init__(self) -> None:
-        self._cuda_usage = 0
-        self._cpu_usage = 0
-        self._start_flag = False
+        self._logger = DistributedLogger("ModelDataTracer")
+        self._model = None

-    def start(self) -> None:
-        self._start_flag = True
+    def _get_mem_usage(self) -> Tuple[int, int]:
+        """
+        get the memory usage of the model registered.
+        Returns:
+            Tuple[int, int]: cuda, cpu mem usage
+        """
+        if self._model is None:
+            self._logger.warning("The Global ModelDataTracer is using, but no model is registered on it.")
+            return 0, 0
+        return col_model_data_mem_usage(self._model)

-    def close(self) -> None:
-        self._start_flag = False
+    def register_model(self, model) -> None:
+        self._model = model

-    def add_tensor(self, t: Union[torch.Tensor, ShardedTensor]) -> None:
-        if not self._start_flag:
-            return
-        t_payload = t.payload if isinstance(t, ShardedTensor) else t
-        mem_use = _col_tensor_mem_usage(t_payload)
-        if t_payload.device.type == 'cuda':
-            self._cuda_usage += mem_use
-        elif t_payload.device.type == 'cpu':
-            self._cpu_usage += mem_use
-        else:
-            raise TypeError
-
-    def delete_tensor(self, t: Union[torch.Tensor, ShardedTensor]) -> None:
-        if not self._start_flag:
-            return
-        t_payload = t.payload if isinstance(t, ShardedTensor) else t
-        mem_use = _col_tensor_mem_usage(t_payload)
-        if t_payload.device.type == 'cuda':
-            self._cuda_usage -= mem_use
-        elif t_payload.device.type == 'cpu':
-            self._cpu_usage -= mem_use
-        else:
-            raise TypeError
-
-    def clear(self) -> None:
-        self._cuda_usage = 0
-        self._cpu_usage = 0
-
     @property
     def cpu_usage(self):
-        return self._cpu_usage
+        _, cpu_usage = self._get_mem_usage()
+        return cpu_usage

     @property
     def cuda_usage(self):
-        return self._cuda_usage
+        cuda_usage, _ = self._get_mem_usage()
+        return cuda_usage

 GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()
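The new col_model_data_mem_usage helper is also usable on its own, as the reworked ZeroInitContext test at the end of this commit does. A small sketch, assuming a plain fp32 module whose parameters carry no col_attr (sizes are placeholders):

import torch
from colossalai.utils.memory_tracer.model_data_memtracer import col_model_data_mem_usage

model = torch.nn.Linear(64, 64)                 # hypothetical CPU-resident model
cuda_bytes, cpu_bytes = col_model_data_mem_usage(model)
# parameters (and their grads, if present) are summed per device type
print(cuda_bytes, cpu_bytes)                    # expect 0 and (64*64 + 64) * 4 here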

View File

@@ -1,5 +1,4 @@
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
-from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 import torch
@@ -14,7 +13,6 @@ def test_mem_collector():
     collector.sample_memstats()
     m_a = torch.randn(10).cuda()
-    GLOBAL_MODEL_DATA_TRACER.add_tensor(m_a)
     b = torch.randn(10).cuda()
     # sampling at time 1
@@ -35,8 +33,7 @@ def test_mem_collector():
     cuda_use, overall_use = collector.fetch_memstats()
     print(cuda_use, overall_use)
-    print(collector._model_data_cuda)
-    print(collector._overall_cuda)
+    print(collector.overall_cuda)

 if __name__ == '__main__':

View File

@@ -1,7 +1,6 @@
 import torch
 from colossalai.utils import get_current_device
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
-from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from typing import Union
@@ -52,9 +51,7 @@ def colo_model_data_tensor_move(src_t: Union[ShardedTensor, torch.Tensor], tgt_t
     tgt_t_payload = tgt_t.data
     tgt_dev = tgt_t_payload.device

-    GLOBAL_MODEL_DATA_TRACER.delete_tensor(src_t_payload)
     tgt_t_payload.copy_(src_t_payload)
-    GLOBAL_MODEL_DATA_TRACER.add_tensor(tgt_t_payload)

     # remove payload of src_t
     if isinstance(src_t, ShardedTensor):
@@ -84,11 +81,7 @@ def colo_model_data_tensor_move_inline(t: Union[ShardedTensor, torch.Tensor],
     # deal with torch.device('cpu') and torch.device('cpu:0)
     if t_payload.device.type == target_device.type:
         return
-    if use_tracer:
-        GLOBAL_MODEL_DATA_TRACER.delete_tensor(t_payload)
     t_payload.data = t_payload.data.to(target_device)
-    if use_tracer:
-        GLOBAL_MODEL_DATA_TRACER.add_tensor(t_payload)

 def colo_model_data_move_to_cpu(t: Union[ShardedTensor, torch.Tensor]) -> None:
@@ -111,9 +104,7 @@ def colo_model_data_move_to_cpu(t: Union[ShardedTensor, torch.Tensor]) -> None:
         return

     # TODO() optimize the tensor moving with non-blocking
-    GLOBAL_MODEL_DATA_TRACER.delete_tensor(t_payload)
     t_payload.data = t_payload.data.cpu()
-    GLOBAL_MODEL_DATA_TRACER.add_tensor(t_payload)

 def colo_model_tensor_clone(t: Union[ShardedTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor:
@@ -129,5 +120,4 @@ def colo_model_tensor_clone(t: Union[ShardedTensor, torch.Tensor], target_device
     t_payload = t.payload if isinstance(t, ShardedTensor) else t
     ret = t_payload.to(target_device)
-    GLOBAL_MODEL_DATA_TRACER.add_tensor(ret)
     return ret

View File

@@ -4,8 +4,6 @@ from typing import Optional
 import torch
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.utils.memory_tracer.model_data_memtracer import \
-    GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
@@ -130,7 +128,6 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         The Callback function when entering the context
         """
         self.logger = get_dist_logger("ZeroInitContext")
-        GLOBAL_MODEL_DATA_TRACER.start()

     def _post_context_exec(self):
         """The callback function when exiting context.
@@ -141,12 +138,6 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
                 param.col_attr.remove_torch_payload()

         del self.initialized_param_list
-        GLOBAL_MODEL_DATA_TRACER.close()
-        model_data_cuda_mem_MB = GLOBAL_MODEL_DATA_TRACER.cuda_usage / 1e6
-        self.logger.info(f"Existing ZeRO Context.\nModel Data CUDA Memory {model_data_cuda_mem_MB} MB", ranks=[0])
-        sys_cuda_mem_MB = colo_cuda_memory_used() / 1e6
-        self.logger.info(f"System CUDA Memory Usage {sys_cuda_mem_MB} MB", ranks=[0])
-        self.logger.info(f"Model Number Parameter {self.model_numel_tensor.numpy()[0]/1e6} M", ranks=[0])

     def _post_init_method(self, module: torch.nn.Module):
         """
@@ -176,9 +167,6 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
            param.col_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)

            self.initialized_param_list.append(param)
-           GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor)

            if self.shard_param:
                self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)

View File

@@ -7,7 +7,6 @@ from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
 from torch._utils import _flatten_dense_tensors as flatten

 from .tensor_shard_strategy import TensorShardStrategy
-from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER

 class BucketTensorShardStrategy(TensorShardStrategy):
@@ -18,8 +17,6 @@ class BucketTensorShardStrategy(TensorShardStrategy):
     """

     def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
-        for t in tensor_list:
-            GLOBAL_MODEL_DATA_TRACER.delete_tensor(t)
         tensor_list: List[ShardedTensor] = [t for t in tensor_list if t.is_sharded]
         if len(tensor_list) == 0:
@@ -50,6 +47,3 @@ class BucketTensorShardStrategy(TensorShardStrategy):
             t.reset_payload(gathered_payload)
             t.is_sharded = False
             offset += tensor_numels[i]
-        for t in tensor_list:
-            GLOBAL_MODEL_DATA_TRACER.add_tensor(t)

View File

@@ -7,7 +7,6 @@ from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, col
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.shard_utils.commons import get_shard
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
-from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER

 class TensorShardStrategy(BaseShardStrategy):
@@ -36,10 +35,8 @@ class TensorShardStrategy(BaseShardStrategy):
         if t.payload.device.type == 'cuda':
             assert t.payload.device.index == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\
                 f" but current cuda device is {get_current_device()}"
-            GLOBAL_MODEL_DATA_TRACER.delete_tensor(t.payload)
         sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
         t.reset_payload(sharded_payload)
-        GLOBAL_MODEL_DATA_TRACER.add_tensor(t.payload)
         t.is_sharded = True

     def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
@@ -56,10 +53,8 @@ class TensorShardStrategy(BaseShardStrategy):
             else:
                 buffer_list.append(torch.zeros(payload_numel, dtype=t.dtype, device=get_current_device()))

-        GLOBAL_MODEL_DATA_TRACER.delete_tensor(t.payload)
         dist.all_gather(buffer_list, buffer_list[rank], group=process_group, async_op=False)
         gathered_payload = torch.narrow(torch.cat(buffer_list), 0, 0, t.origin_numel).reshape(t.origin_shape)
         t.reset_payload(gathered_payload)
         colo_model_data_tensor_move_inline(t, target_device, use_tracer=False)
-        GLOBAL_MODEL_DATA_TRACER.delete_tensor(t.payload)
         t.is_sharded = False

View File

@@ -11,6 +11,7 @@ from colossalai.engine.ophooks import register_ophooks_recursively
 from colossalai.engine.ophooks.zero_hook import ZeroHook
 from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory_utils.utils import colo_model_data_move_to_cpu, colo_cuda_memory_capacity, colo_model_tensor_clone
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.zero.shard_utils import BaseShardStrategy
@@ -83,6 +84,7 @@ class ShardedModelV2(nn.Module):
         # Init Memory Statistics Collector
         self._use_memory_tracer = use_memory_tracer
         if self._use_memory_tracer:
+            GLOBAL_MODEL_DATA_TRACER.register_model(self)
             self._memstats_collector = MemStatsCollector()
         else:
             self._memstats_collector = None
@@ -147,14 +149,16 @@ class ShardedModelV2(nn.Module):
     def _update_memstats(self):
         if self._iter_cnter == 0 and self._memstats_collector:
             self._memstats_collector.finish_collection()
+            self.logger.info(f'model data cuda, {self._memstats_collector.model_data_cuda}')
+            self.logger.info(f'non-model data cuda, {self._memstats_collector.non_model_data_cuda}')
         if self._memstats_collector:
             self._memstats_collector.reset_sampling_cnter()
             # cuda margin space = cuda mem capacity - max fwd/bwd cuda mem used.
             # the way to calculate margin space is based on the assumption that
             # model data is fixed in cuda during training.
             # cuda margin space can be used to store OS.
-            self._cuda_margin_space = colo_cuda_memory_capacity() - max(self._memstats_collector._overall_cuda)
+            self._cuda_margin_space = colo_cuda_memory_capacity() - max(self._memstats_collector.overall_cuda)
         self._iter_cnter += 1

     @torch.no_grad()

View File

@@ -9,7 +9,6 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
 from torch import Tensor
@@ -218,9 +217,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # We must set grad to None
         # Because we will judge whether local grad accumulation
         # is enabled by wheter grad is None
-        for group in self.param_groups:
-            for p in group['params']:
-                GLOBAL_MODEL_DATA_TRACER.delete_tensor(p.grad)
         self.optim.zero_grad(set_to_none=True)

     def sync_grad(self):

View File

@@ -1,4 +1,3 @@
-from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
 from colossalai.utils import free_port
 from colossalai.testing import rerun_on_exception
@@ -13,22 +12,15 @@ import torch.multiprocessing as mp
 def run_tensor_move(rank):
     colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')

-    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0)
-    GLOBAL_MODEL_DATA_TRACER.start()
-
     src_t = torch.ones(2, 3).cuda()
-    GLOBAL_MODEL_DATA_TRACER.add_tensor(src_t)
-    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 24)
     tgt_t = torch.zeros(2, 3)

     colo_model_data_tensor_move(src_t, tgt_t)
-    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0)
     assert (torch.sum(tgt_t) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0"

     src_t = torch.ones(2, 3)
     tgt_t = torch.zeros(2, 3).cuda().half()
     colo_model_data_tensor_move(src_t, tgt_t)
-    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 12), f"cuda usage {GLOBAL_MODEL_DATA_TRACER.cuda_usage}"
     # the src_t has been removed
     assert (src_t.numel() == 0)
     assert (torch.sum(tgt_t) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0"
@@ -36,15 +28,11 @@ def run_tensor_move(rank):
     src_t = ShardedTensor(torch.ones(2, 3))
     tgt_t = ShardedTensor(torch.zeros(2, 3).cuda().half())
     colo_model_data_tensor_move(src_t, tgt_t)
-    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 24), f"cuda usage {GLOBAL_MODEL_DATA_TRACER.cuda_usage}"
     assert (torch.sum(tgt_t.payload) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0"
     assert (tgt_t.device.type == 'cuda')

     colo_model_data_tensor_move_inline(tgt_t, torch.device('cpu'))
     assert (tgt_t.device.type == 'cpu')
-    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 12), f"cuda usage {GLOBAL_MODEL_DATA_TRACER.cuda_usage}"
-
-    GLOBAL_MODEL_DATA_TRACER.close()

 @rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")

View File

@@ -1,52 +1,28 @@
 import pytest
 from colossalai.utils.cuda import get_current_device
-from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
-from colossalai.utils import free_port
 from colossalai.zero.sharded_param import ShardedTensor
 import colossalai
 import torch
 from functools import partial
 import torch.multiprocessing as mp
+from colossalai.utils import free_port

 def _run_colo_model_data_tensor_move_inline():
-    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0)
-    GLOBAL_MODEL_DATA_TRACER.start()
     for t in [torch.randn(2, 3), ShardedTensor(torch.randn(2, 3))]:
-        GLOBAL_MODEL_DATA_TRACER.add_tensor(t)
-        assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 2 * 3 * 4
-        assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0
         colo_model_data_tensor_move_inline(t, torch.device(f"cuda:{get_current_device()}"))
         assert t.device == torch.device(f"cuda:{get_current_device()}")
-        assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 0
-        assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 2 * 3 * 4
-        GLOBAL_MODEL_DATA_TRACER.clear()
-    GLOBAL_MODEL_DATA_TRACER.close()

 def _run_colo_model_data_tensor_move():
-    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0)
-    GLOBAL_MODEL_DATA_TRACER.start()
     for t in [(torch.ones(2, 3), torch.zeros(2, 3).cuda(get_current_device())),
               (ShardedTensor(torch.ones(2, 3)), ShardedTensor(torch.zeros(2, 3).cuda(get_current_device())))]:
         cpu_t, cuda_t = t
-        GLOBAL_MODEL_DATA_TRACER.add_tensor(cpu_t)
-        assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 2 * 3 * 4
-        assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0
         colo_model_data_tensor_move(cpu_t, cuda_t)
-        assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 0
-        assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 2 * 3 * 4
-        GLOBAL_MODEL_DATA_TRACER.clear()
-    GLOBAL_MODEL_DATA_TRACER.close()

 def run_dist(rank, world_size, port):

View File

@@ -10,19 +10,21 @@ import torch.multiprocessing as mp
 from colossalai.testing import parameterize
 from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
-from colossalai.utils.memory_tracer.model_data_memtracer import \
-    GLOBAL_MODEL_DATA_TRACER
+from colossalai.utils.memory_tracer.model_data_memtracer import col_model_data_mem_usage
 from colossalai.zero.init_ctx import ZeroInitContext
+from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.testing import rerun_on_exception
 from tests.components_to_test.registry import non_distributed_component_funcs
+from colossalai.logging import get_dist_logger
 from common import CONFIG

 @parameterize("init_device_type", ['cpu', 'cuda'])
 @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
 def run_model_test(init_device_type, shard_strategy_class):
+    logger = get_dist_logger("test_zero_init")
     for get_components_func in non_distributed_component_funcs:
         model_builder, _, _, _, _ = get_components_func()
         model_numel_tensor = torch.zeros(1, dtype=torch.int)
@@ -32,6 +34,8 @@ def run_model_test(init_device_type, shard_strategy_class):
             init_device = torch.device("cpu")
         else:
             continue

+        model_numel_tensor = torch.zeros(1, dtype=torch.int)
         with ZeroInitContext(convert_fp16=True,
                              target_device=init_device,
                              shard_strategy=shard_strategy_class(),
@@ -46,11 +50,13 @@ def run_model_test(init_device_type, shard_strategy_class):
             assert param.col_attr.sharded_data_tensor.is_sharded
             assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
                 f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
-        if init_device.type == 'cuda':
-            assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
-        else:
-            assert (GLOBAL_MODEL_DATA_TRACER.cpu_usage > 0)
-        GLOBAL_MODEL_DATA_TRACER.clear()
+
+        cuda_mem_use, cpu_mem_use = col_model_data_mem_usage(model)
+        model_data_cuda_mem_MB = cuda_mem_use / 1e6
+        logger.info(f"Existing ZeRO Context.\nModel Data CUDA Memory {model_data_cuda_mem_MB} MB", ranks=[0])
+        sys_cuda_mem_MB = colo_cuda_memory_used() / 1e6
+        logger.info(f"System CUDA Memory Usage {sys_cuda_mem_MB} MB", ranks=[0])
+        logger.info(f"Model Number Parameter {model_numel_tensor.numpy()[0]/1e6} M", ranks=[0])

 def run_dist(rank, world_size, port):