mirror of https://github.com/hpcaitech/ColossalAI.git
[zero] refactor model data tracing (#522)
This commit makes ModelDataTracer device-aware: it adds a _cpu_usage counter alongside _cuda_usage, lets add_tensor()/delete_tensor() accept a ShardedTensor as well as a torch.Tensor and charge the payload to the counter matching its device, and simplifies the tensor-move helpers accordingly (a new use_tracer flag on colo_model_data_tensor_move_inline, and a previously missing add_tensor() after moving data to CPU).
@@ -22,6 +22,7 @@ class ModelDataTracer(metaclass=SingletonMeta):

     def __init__(self) -> None:
         self._cuda_usage = 0
+        self._cpu_usage = 0
         self._start_flag = False

     def start(self) -> None:
@@ -30,22 +31,33 @@ class ModelDataTracer(metaclass=SingletonMeta):
     def close(self) -> None:
         self._start_flag = False

-    def add_tensor(self, t: torch.Tensor) -> None:
+    def add_tensor(self, t: Union[torch.Tensor, ShardedTensor]) -> None:
         if not self._start_flag:
             return
-        assert isinstance(t, torch.Tensor), f"ModelDataTracer add_tensor() should accept a torch.Tensor"
-        mem_use = _col_tensor_mem_usage(t)
-        self._cuda_usage += mem_use
+        t_payload = t.payload if isinstance(t, ShardedTensor) else t
+        mem_use = _col_tensor_mem_usage(t_payload)
+        if t_payload.device.type == 'cuda':
+            self._cuda_usage += mem_use
+        elif t_payload.device.type == 'cpu':
+            self._cpu_usage += mem_use
+        else:
+            raise TypeError

-    def delete_tensor(self, t: torch.Tensor) -> None:
+    def delete_tensor(self, t: Union[torch.Tensor, ShardedTensor]) -> None:
         if not self._start_flag:
             return
-        assert isinstance(t, torch.Tensor), f"ModelDataTracer delete_tensor() should accept a torch.Tensor"
-        mem_use = _col_tensor_mem_usage(t)
-        self._cuda_usage -= mem_use
+        t_payload = t.payload if isinstance(t, ShardedTensor) else t
+        mem_use = _col_tensor_mem_usage(t_payload)
+        if t_payload.device.type == 'cuda':
+            self._cuda_usage -= mem_use
+        elif t_payload.device.type == 'cpu':
+            self._cpu_usage -= mem_use
+        else:
+            raise TypeError

     def clear(self) -> None:
         self._cuda_usage = 0
+        self._cpu_usage = 0

     @property
     def cpu_usage(self):
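Taken together, the tracer now buckets model data by the device the payload actually lives on, instead of assuming everything is CUDA-resident. A minimal usage sketch of the refactored interface (standalone and hypothetical; it assumes _col_tensor_mem_usage(t) returns t.numel() * t.element_size(), which this diff does not show):

import torch
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER

# add_tensor()/delete_tensor() are no-ops until start() sets _start_flag.
GLOBAL_MODEL_DATA_TRACER.start()

cpu_t = torch.zeros(1024, dtype=torch.float32)    # 1024 * 4 bytes, CPU-resident
GLOBAL_MODEL_DATA_TRACER.add_tensor(cpu_t)

# CPU tensors now land in cpu_usage; before this commit they were
# charged to the CUDA counter unconditionally.
print(GLOBAL_MODEL_DATA_TRACER.cpu_usage)         # 4096 under the assumption above

GLOBAL_MODEL_DATA_TRACER.delete_tensor(cpu_t)
GLOBAL_MODEL_DATA_TRACER.clear()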
@@ -3,7 +3,7 @@ from colossalai.utils import get_current_device
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER

-from typing import Union, Optional
+from typing import Union

 _GLOBAL_CUDA_MEM_FRACTION = 1.0
@@ -52,11 +52,9 @@ def colo_model_data_tensor_move(src_t: Union[ShardedTensor, torch.Tensor], tgt_t
     tgt_t_payload = tgt_t.data
     tgt_dev = tgt_t_payload.device

-    if src_dev.type == 'cuda' and tgt_dev.type == 'cpu':
-        GLOBAL_MODEL_DATA_TRACER.delete_tensor(src_t_payload)
-    elif src_dev.type == 'cpu' and tgt_dev.type == 'cuda':
-        GLOBAL_MODEL_DATA_TRACER.add_tensor(tgt_t_payload)
+    GLOBAL_MODEL_DATA_TRACER.delete_tensor(src_t_payload)
     tgt_t_payload.copy_(src_t_payload)
+    GLOBAL_MODEL_DATA_TRACER.add_tensor(tgt_t_payload)

     # remove payload of src_t
     if isinstance(src_t, ShardedTensor):
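Because the tracer is now device-aware, the move helper no longer special-cases the CUDA-to-CPU and CPU-to-CUDA directions: it unregisters the source payload, copies, and registers the target payload, whichever way the data moves. A sketch of a call (the import path is assumed, since the diff does not show the file name):

import torch
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move  # path assumed

src = torch.randn(4, 4, device='cuda')
tgt = torch.empty(4, 4, device='cpu')

# The source payload leaves the tracer's CUDA bucket and the target payload
# enters the CPU bucket; afterwards src's payload is emptied in place.
colo_model_data_tensor_move(src, tgt)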
@@ -65,7 +63,9 @@ def colo_model_data_tensor_move(src_t: Union[ShardedTensor, torch.Tensor], tgt_t
         src_t.data = torch.tensor([], device=src_dev, dtype=src_t_payload.dtype)


-def colo_model_data_tensor_move_inline(t: Union[ShardedTensor, torch.Tensor], target_device: torch.device) -> None:
+def colo_model_data_tensor_move_inline(t: Union[ShardedTensor, torch.Tensor],
+                                       target_device: torch.device,
+                                       use_tracer: bool = True) -> None:
     """
     move a tensor to the target_device
     Args:
@@ -84,13 +84,11 @@ def colo_model_data_tensor_move_inline(t: Union[ShardedTensor, torch.Tensor], ta
     # deal with torch.device('cpu') and torch.device('cpu:0')
     if t_payload.device.type == target_device.type:
         return

-    if target_device.type == 'cuda':
-        GLOBAL_MODEL_DATA_TRACER.add_tensor(t_payload)
-    elif target_device.type == 'cpu':
+    if use_tracer:
         GLOBAL_MODEL_DATA_TRACER.delete_tensor(t_payload)
-
     t_payload.data = t_payload.data.to(target_device)
+    if use_tracer:
+        GLOBAL_MODEL_DATA_TRACER.add_tensor(t_payload)


 def colo_model_data_move_to_cpu(t: Union[ShardedTensor, torch.Tensor]) -> None:
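The new use_tracer flag lets a caller skip the bookkeeping for tensors the tracer never registered; by default the payload is unregistered on its old device, moved, and re-registered on the new one. A sketch (import path assumed, as above):

import torch
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move_inline  # path assumed

t = torch.randn(8, device='cpu')

# Default: tracer counters move with the payload (cpu_usage down, cuda_usage up).
colo_model_data_tensor_move_inline(t, torch.device('cuda'))

# Skip the accounting, e.g. for scratch tensors outside the tracer.
colo_model_data_tensor_move_inline(t, torch.device('cpu'), use_tracer=False)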
@@ -115,3 +113,4 @@ def colo_model_data_move_to_cpu(t: Union[ShardedTensor, torch.Tensor]) -> None:
     # TODO() optimize the tensor moving with non-blocking
     GLOBAL_MODEL_DATA_TRACER.delete_tensor(t_payload)
     t_payload.data = t_payload.data.cpu()
+    GLOBAL_MODEL_DATA_TRACER.add_tensor(t_payload)
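The added add_tensor() call closes a hole in the accounting: previously the payload was removed from the tracer before the move to CPU and never re-registered, so those bytes vanished from both counters. A sketch (import path assumed):

import torch
from colossalai.utils.memory_utils.utils import colo_model_data_move_to_cpu  # path assumed

t = torch.randn(16, device='cuda')
colo_model_data_move_to_cpu(t)
# cuda_usage drops by the payload size and cpu_usage grows by the same
# amount; before this commit the CPU copy went untracked.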