Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-13 05:01:44 +00:00
[legacy] clean up legacy code (#4743)
* [legacy] remove outdated codes of pipeline (#4692)
* [legacy] remove cli of benchmark and update optim (#4690)
* [legacy] remove cli of benchmark and update optim
* [doc] fix cli doc test
* [legacy] fix engine clip grad norm
* [legacy] remove outdated colo tensor (#4694)
* [legacy] remove outdated colo tensor
* [test] fix test import
* [legacy] move outdated zero to legacy (#4696)
* [legacy] clean up utils (#4700)
* [legacy] clean up utils
* [example] update examples
* [legacy] clean up amp
* [legacy] fix amp module
* [legacy] clean up gpc (#4742)
* [legacy] clean up context
* [legacy] clean core, constants and global vars
* [legacy] refactor initialize
* [example] fix examples ci
* [example] fix examples ci
* [legacy] fix tests
* [example] fix gpt example
* [example] fix examples ci
* [devops] fix ci installation
* [example] fix examples ci
colossalai/legacy/zero/sharded_param/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from .sharded_param import ShardedParamV2
from .sharded_tensor import ShardedTensor

__all__ = ['ShardedTensor', 'ShardedParamV2']
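Since `__all__` re-exports both classes, legacy call sites can import them straight from the package root rather than from the individual modules. A minimal sketch (not part of the commit, assuming a colossalai installation with this legacy package layout):

```python
# Hedged sketch: both names come from the __all__ list above.
from colossalai.legacy.zero.sharded_param import ShardedParamV2, ShardedTensor

print(ShardedParamV2.__module__)  # colossalai.legacy.zero.sharded_param.sharded_param
print(ShardedTensor.__module__)   # colossalai.legacy.zero.sharded_param.sharded_tensor
```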
colossalai/legacy/zero/sharded_param/sharded_param.py (new file, 110 lines)
@@ -0,0 +1,110 @@
from typing import List, Optional, Tuple

import torch

from colossalai.legacy.zero.gemini.stateful_tensor import StatefulTensor, TensorState
from colossalai.legacy.zero.gemini.tensor_utils import colo_tensor_mem_usage

from .sharded_tensor import ShardedTensor

EMPTY_TENSOR_DICT = {}


def get_empty_tensor(device: torch.device, dtype: torch.dtype):
    key = (device, dtype)
    if key not in EMPTY_TENSOR_DICT:
        EMPTY_TENSOR_DICT[key] = torch.empty(0, dtype=dtype, device=device)

    return EMPTY_TENSOR_DICT[key]


class ShardedParamV2(object):

    def __init__(self, param: torch.nn.Parameter, set_data_none: bool = False) -> None:
        self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data)
        self.saved_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
        # This attribute must be initialized in ShardedModel
        self.offload_grad: bool = False

        # Make sure the sharded param is the only owner of the payload.
        # param.data may still be used to init other parts of the model.
        # For example: File "resnet.py", line 190, in __init__
        #   nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        # So we cannot empty .data at this time.
        self.param = param
        if set_data_none:
            self.set_data_none()

    def get_payload_tensors(self) -> List[StatefulTensor]:
        """Returns the stateful tensors kept by this class."""
        return [self._sharded_data_tensor]

    def set_data_none(self):
        self.param.data = get_empty_tensor(self.sharded_data_tensor.device, self.sharded_data_tensor.dtype)

    def set_grad_none(self):
        self.saved_grad.set_null()

    @property
    def sharded_data_tensor(self):
        return self._sharded_data_tensor

    @property
    def data_payload(self):
        assert not self.sharded_data_tensor.is_null()
        return self.sharded_data_tensor.payload

    @property
    def grad_payload(self):
        assert not self.saved_grad.is_null()
        return self.saved_grad.payload

    @property
    def param_is_sharded(self):
        return self.sharded_data_tensor.is_sharded

    def data_payload_reset(self, tensor: torch.Tensor):
        assert type(tensor) is torch.Tensor
        assert tensor.requires_grad is False
        self.sharded_data_tensor.payload_reset(tensor)

    def grad_payload_reset(self, tensor: torch.Tensor):
        assert type(tensor) is torch.Tensor
        assert tensor.requires_grad is False
        self.saved_grad.payload_reset(tensor)

    def get_memory_usage(self) -> Tuple[int, int]:
        """Get the memory usage of the param, including data and grad.

        Returns:
            Tuple[int, int]: CUDA memory usage in bytes, CPU memory usage in bytes
        """
        cuda_mem_use, cpu_mem_use = 0, 0

        def _update_mem_use(t: Optional[torch.Tensor]):
            if t is None:
                return
            assert isinstance(t, torch.Tensor)
            nonlocal cuda_mem_use
            nonlocal cpu_mem_use
            t_cuda, t_cpu = colo_tensor_mem_usage(t)
            cuda_mem_use += t_cuda
            cpu_mem_use += t_cpu

        address_set = set()
        _update_mem_use(self.data_payload)
        address_set.add(self.data_payload.data_ptr())

        if not self.saved_grad.is_null() and self.saved_grad.data_ptr() not in address_set:
            _update_mem_use(self.grad_payload)
            address_set.add(self.saved_grad.data_ptr())

        if self.param.data is not None and self.param.data.data_ptr() not in address_set:
            _update_mem_use(self.param.data)
            address_set.add(self.param.data.data_ptr())

        if self.param.grad is not None and self.param.grad.data_ptr() not in address_set:
            _update_mem_use(self.param.grad)

        return cuda_mem_use, cpu_mem_use
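For readers skimming the diff, here is a minimal usage sketch of ShardedParamV2 based only on the methods shown above. The parameter and the flat local shard are hypothetical placeholders; in practice the wrapping is driven by the legacy ShardedModel/zero machinery rather than done by hand.

```python
import torch

from colossalai.legacy.zero.sharded_param import ShardedParamV2

# Hypothetical parameter; normally it comes from a model handled by the
# legacy sharded-model machinery.
param = torch.nn.Parameter(torch.randn(1024, 1024))

# Wrap the parameter. param.data is kept alive because other modules may
# still read it during initialization (see the comment in __init__ above).
sp = ShardedParamV2(param, set_data_none=False)

# Replace the data payload with a placeholder "local shard" that does not
# require grad, then release the original param.data via set_data_none().
local_shard = torch.empty(1024 * 1024 // 2)  # hypothetical shard for this rank
sp.data_payload_reset(local_shard)
sp.set_data_none()

cuda_mem, cpu_mem = sp.get_memory_usage()
print(f'param_is_sharded={sp.param_is_sharded}, cuda={cuda_mem} B, cpu={cpu_mem} B')
```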
colossalai/legacy/zero/sharded_param/sharded_tensor.py (new file, 40 lines)
@@ -0,0 +1,40 @@
import torch

from colossalai.legacy.zero.gemini.stateful_tensor import StatefulTensor, TensorState


class ShardedTensor(StatefulTensor):

    def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD) -> None:
        r"""
        A tensor sharded across multiple processes, constructed from an existing torch.Tensor instance.
        """
        assert tensor.requires_grad is False
        super().__init__(tensor, state)

        # Keep the shape, numel and dtype of the initial tensor.
        self._origin_shape = tensor.shape
        self._origin_numel = tensor.numel()
        self._origin_dtype = tensor.dtype
        self._is_sharded = False

    @property
    def dtype(self) -> torch.dtype:
        assert self._payload.dtype == self._origin_dtype
        return self._payload.dtype

    @property
    def origin_numel(self) -> int:
        return self._origin_numel

    @property
    def origin_shape(self) -> torch.Size:
        return self._origin_shape

    @property
    def is_sharded(self):
        return self._is_sharded

    @is_sharded.setter
    def is_sharded(self, flag: bool):
        self._is_sharded = flag
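A minimal sketch of ShardedTensor on its own, again using only what the file above defines; the wrapped tensor is a hypothetical stand-in:

```python
import torch

from colossalai.legacy.zero.sharded_param import ShardedTensor

# The wrapped tensor must not require grad (enforced by the assert in __init__).
t = torch.randn(4, 8)
st = ShardedTensor(t)

print(st.origin_shape, st.origin_numel, st.dtype)  # torch.Size([4, 8]) 32 torch.float32
print(st.is_sharded)                               # False

# After some external sharding step swaps the payload for this rank's slice,
# the owner is expected to flip the flag explicitly:
st.is_sharded = True
```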