Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-01 09:07:51 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
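The bullets above summarize the pre-commit update only at a high level. For orientation, a hook configuration of this kind normally lives in `.pre-commit-config.yaml`; the sketch below is a hypothetical example of such a file, and the chosen hooks, revisions, and exclude pattern are assumptions for illustration, not the contents of this commit:

```yaml
# Hypothetical .pre-commit-config.yaml sketch (not taken from this commit).
repos:
  - repo: https://github.com/psf/black
    rev: 23.9.1          # assumed revision
    hooks:
      - id: black
  - repo: https://github.com/PyCQA/isort
    rev: 5.12.0          # assumed revision
    hooks:
      - id: isort
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v16.0.6         # assumed revision
    hooks:
      - id: clang-format
        # One plausible reading of "ignore cuda for clang-format".
        exclude: \.cu$
```

Running `pre-commit run --all-files` with a config like this is what produces the repository-wide mechanical changes (quote style, line wrapping, blank-line normalization) shown in the diff below.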
@@ -4,6 +4,11 @@ from .stateful_tensor_mgr import StatefulTensorMgr
 from .tensor_placement_policy import AutoTensorPlacementPolicy, CPUTensorPlacementPolicy, CUDATensorPlacementPolicy

 __all__ = [
-    'StatefulTensorMgr', 'StatefulTensor', 'CPUTensorPlacementPolicy', 'CUDATensorPlacementPolicy',
-    'AutoTensorPlacementPolicy', 'register_ophooks_recursively', 'BaseOpHook'
+    "StatefulTensorMgr",
+    "StatefulTensor",
+    "CPUTensorPlacementPolicy",
+    "CUDATensorPlacementPolicy",
+    "AutoTensorPlacementPolicy",
+    "register_ophooks_recursively",
+    "BaseOpHook",
 ]
@@ -2,16 +2,15 @@ from enum import EnumMeta


 class GeminiMemoryManager(object):
-
     def __init__(self, states_cls: EnumMeta):
         super().__init__()
         self.states_cls = states_cls
-        self._cnter = 0 # the counter of instances
+        self._cnter = 0  # the counter of instances

         self.total_mem = dict()
         self.state_mem = dict()
-        self.state_mem['cpu'] = dict()
-        self.state_mem['cuda'] = dict()
+        self.state_mem["cpu"] = dict()
+        self.state_mem["cuda"] = dict()

         self.reset()

@@ -20,15 +19,15 @@ class GeminiMemoryManager(object):
         return self._cnter

     def reset(self):
-        self._cnter = 0 # the counter of instances
+        self._cnter = 0  # the counter of instances

-        self.total_mem['cpu'] = 0 # memory occupation of instances in cpu
-        self.total_mem['cuda'] = 0 # memory of occupation of instances in cuda
+        self.total_mem["cpu"] = 0  # memory occupation of instances in cpu
+        self.total_mem["cuda"] = 0  # memory of occupation of instances in cuda

         # memory conditions for all states
         for state in self.states_cls:
-            self.state_mem['cpu'][state] = 0
-            self.state_mem['cuda'][state] = 0
+            self.state_mem["cpu"][state] = 0
+            self.state_mem["cuda"][state] = 0

     def register_new_instance(self):
         self._cnter += 1
@@ -37,12 +36,16 @@ class GeminiMemoryManager(object):
         self._cnter -= 1

     def print_info(self):
-        print(f"Total number: {self.total_number}",
-              f"Total CPU memory occupation: {self.total_mem['cpu']}",
-              f"Total CUDA memory occupation: {self.total_mem['cuda']}\n",
-              sep='\n')
+        print(
+            f"Total number: {self.total_number}",
+            f"Total CPU memory occupation: {self.total_mem['cpu']}",
+            f"Total CUDA memory occupation: {self.total_mem['cuda']}\n",
+            sep="\n",
+        )

         for state in self.states_cls:
-            print(f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}",
-                  f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n",
-                  sep='\n')
+            print(
+                f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}",
+                f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n",
+                sep="\n",
+            )
@@ -22,7 +22,7 @@ class ShardGradMemTracerHook(BaseOpHook):

     def pre_bwd_exec(self, module: torch.nn.Module, input, output):
         for param in module.parameters():
-            assert hasattr(param, '_sharded_grad')
+            assert hasattr(param, "_sharded_grad")
             param._sharded_grad.setup()

     def post_bwd_exec(self, module: torch.nn.Module, input):
@@ -19,25 +19,25 @@ class ShardParamHook(BaseOpHook):

     def pre_fwd_exec(self, module: torch.nn.Module, *args):
         for param in module.parameters():
-            assert hasattr(param, 'ca_attr')
+            assert hasattr(param, "ca_attr")
             param.ca_attr.gather()
             param.data = param.ca_attr.payload()

     def post_fwd_exec(self, module: torch.nn.Module, *args):
         for param in module.parameters():
-            assert hasattr(param, 'ca_attr')
+            assert hasattr(param, "ca_attr")
             param.ca_attr.shard()
             param.data = param.ca_attr.payload()

     def pre_bwd_exec(self, module: torch.nn.Module, input, output):
         for param in module.parameters():
-            assert hasattr(param, 'ca_attr')
+            assert hasattr(param, "ca_attr")
             param.ca_attr.gather()
             param.data = param.ca_attr.payload()

     def post_bwd_exec(self, module: torch.nn.Module, input):
         for param in module.parameters():
-            assert hasattr(param, 'ca_attr')
+            assert hasattr(param, "ca_attr")
             param.ca_attr.shard()
             param.data = param.ca_attr.payload()

@@ -15,8 +15,7 @@ class TrainingPhase(Enum):
     BACKWARD = 1


-class GradMemStats():
-
+class GradMemStats:
     def __init__(self) -> None:
         self.unreleased_grad_flag = {}
         self.unreleased_grad_volume = 0
@@ -26,8 +25,7 @@ class GradMemStats():
         self.unreleased_grad_volume = 0


-class GradMemTracerHook():
-
+class GradMemTracerHook:
     def __init__(self, grad_stats: GradMemStats):
         self.grad_hook_list = []
         self._grad_stats = grad_stats
@@ -50,7 +48,6 @@ class GradMemTracerHook():


 class ParamMemTracerHook(ColoParamOpHook):
-
     def __init__(self, memstats: MemStats, gradstats: GradMemStats) -> None:
         super().__init__()
         self._training_phase = TrainingPhase.FORWARD
@@ -79,10 +76,9 @@ class ParamMemTracerHook(ColoParamOpHook):
             if cur_dev == "cpu":
                 if p.grad is not None and p.grad.device.type == "cpu":
                     raise NotImplementedError("Only run in forward propagation")
-                p.data = torch.empty(p.data.shape,
-                                     device="cuda",
-                                     dtype=p.data.dtype,
-                                     requires_grad=p.data.requires_grad)
+                p.data = torch.empty(
+                    p.data.shape, device="cuda", dtype=p.data.dtype, requires_grad=p.data.requires_grad
+                )
             elif cur_dev == "cuda":
                 alloc_storage(p.data)

@@ -48,7 +48,6 @@ def _apply_to_tensors_only(module, functional, backward_function, outputs):


 class PreBackwardFunction(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, module, pre_backward_function, outputs):
         ctx.module = module
@@ -64,7 +63,6 @@ class PreBackwardFunction(torch.autograd.Function):


 class PostBackwardFunction(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, module, pre_backward_function, output):
         ctx.module = module
@@ -84,16 +82,15 @@ class PostBackwardFunction(torch.autograd.Function):
         return (None, None) + args


-def register_ophooks_recursively(module: torch.nn.Module,
-                                 ophook_list: List[BaseOpHook],
-                                 name: str = "",
-                                 filter_fn: Optional[Callable] = None):
+def register_ophooks_recursively(
+    module: torch.nn.Module, ophook_list: List[BaseOpHook], name: str = "", filter_fn: Optional[Callable] = None
+):
     r"""Recursively register pre/post hooks for all submodules in the module in FWD and BWD."""
     assert isinstance(module, torch.nn.Module)
     assert isinstance(ophook_list, (list, tuple))
-    assert len(ophook_list) > 0, 'expected at least 1 hook in the argument ophook_list but found 0'
+    assert len(ophook_list) > 0, "expected at least 1 hook in the argument ophook_list but found 0"
     for hook in ophook_list:
-        assert (isinstance(hook, BaseOpHook))
+        assert isinstance(hook, BaseOpHook)

     # Add hooks for submodules
     for child_name, child in module.named_children():
@@ -118,7 +115,6 @@ def register_ophooks_recursively(module: torch.nn.Module,
             hook.post_fwd_exec(submodule, *args)

     def _pre_backward_module_hook(submodule, inputs, output):
-
         def _run_before_backward_function(submodule):
             for hook in ophook_list:
                 assert isinstance(submodule, torch.nn.Module)
@@ -127,7 +123,6 @@ def register_ophooks_recursively(module: torch.nn.Module,
         return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output)

     def _post_backward_module_hook(submodule, inputs):
-
         def _run_after_backward_function(submodule):
             for hook in ophook_list:
                 assert isinstance(submodule, torch.nn.Module)
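The hunks above only re-wrap `register_ophooks_recursively` and its nested helpers; behaviour is unchanged. For orientation, a minimal usage sketch follows. The import path and the exact `BaseOpHook` interface are assumptions inferred from what is visible in this diff (the `__all__` list and the four `*_exec` methods), not something this commit defines:

```python
import torch
import torch.nn as nn

# Assumed import path; the real module location differs between ColossalAI versions.
from colossalai.legacy.zero.gemini.ophooks import BaseOpHook, register_ophooks_recursively


class ShapeLoggingHook(BaseOpHook):
    """Toy hook that prints parameter shapes around FWD/BWD (illustrative only)."""

    def pre_fwd_exec(self, module: nn.Module, *args):
        for p in module.parameters(recurse=False):
            print("pre-fwd", type(module).__name__, tuple(p.shape))

    def post_fwd_exec(self, module: nn.Module, *args):
        pass

    def pre_bwd_exec(self, module: nn.Module, input, output):
        pass

    def post_bwd_exec(self, module: nn.Module, input):
        pass

    def post_iter(self):
        # Some versions declare this as an abstract method; a no-op override is harmless otherwise.
        pass


model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
# Registers pre/post FWD and BWD hooks on every submodule, per the docstring in the hunk above.
register_ophooks_recursively(model, [ShapeLoggingHook()])
loss = model(torch.randn(4, 8)).sum()
loss.backward()
```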
@@ -5,7 +5,6 @@ import torch


 class BaseParamHookMgr(object):
-
     def __init__(self, param_list: List[torch.nn.Parameter]) -> None:
         r"""
         register backward hook on every parameters of module
@@ -23,9 +22,9 @@ class BaseParamHookMgr(object):
         ```
         """
         if not torch.is_grad_enabled():
-            return # don't register grad hooks if grad isn't enabled
+            return  # don't register grad hooks if grad isn't enabled
         for p in self._param_list:
-            if p.requires_grad and not hasattr(p, '_base_param_hook'):
+            if p.requires_grad and not hasattr(p, "_base_param_hook"):
                 handle = p.register_hook(functools.partial(hook_call, p))
                 p._base_param_hook = handle

@@ -35,5 +34,5 @@ class BaseParamHookMgr(object):
         """

         for p in self._param_list:
-            if p.requires_grad and hasattr(p, '_base_param_hook'):
+            if p.requires_grad and hasattr(p, "_base_param_hook"):
                 p._base_param_hook.remove()
@@ -25,13 +25,14 @@ class StatefulTensor(object):

     https://arxiv.org/abs/2108.05818
     """
+
     # Global Stateful Tensor Manager
     GST_MGR = GeminiMemoryManager(TensorState)

     def __init__(self, maybe_tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
         self._state = state
         self._payload = None
-        self._payload_size = 0 # byte size of current payload
+        self._payload_size = 0  # byte size of current payload

         StatefulTensor.GST_MGR.register_new_instance()

@@ -47,7 +48,7 @@ class StatefulTensor(object):

     def data_ptr(self):
         if self._payload is None:
-            return 0 # if a tensor has no storage, 0 should be returned
+            return 0  # if a tensor has no storage, 0 should be returned
         return self._payload.data_ptr()

     def set_null(self) -> None:
@@ -80,7 +81,7 @@ class StatefulTensor(object):
         assert self.state is not TensorState.FREE, "Can't move free stateful tensor"

         if not isinstance(device, torch.device):
-            to_device = torch.device('cuda', device)
+            to_device = torch.device("cuda", device)
         else:
             to_device = device

@@ -97,7 +98,6 @@ class StatefulTensor(object):
         self._payload.view(-1).copy_(tensor.view(-1))

     def payload_reset(self, tensor) -> None:
-
         assert tensor is not None, "Can't reset None for stateful tensors, please use set_null() instead"

         if self.payload is not None:
@@ -168,8 +168,7 @@ class StatefulTensor(object):
         self._payload_size = 0

     def __trans_state_update(self, from_state: TensorState, to_state: TensorState):
-        """Update global manager when changing the state of a tensor
-        """
+        """Update global manager when changing the state of a tensor"""
         manager = StatefulTensor.GST_MGR
         size = self.payload_size
         device_type = self.device.type
@@ -189,8 +188,7 @@ class StatefulTensor(object):
         manager.total_mem[device_type] -= size

     def __trans_device_update(self, from_type: str, to_type: str):
-        """Update global manager when changing the device of a tensor
-        """
+        """Update global manager when changing the device of a tensor"""
         manager = StatefulTensor.GST_MGR
         size = self.payload_size
         state = self.state
@@ -3,14 +3,11 @@ import types
 from time import time
 from typing import List

 import torch

 from colossalai.logging import get_dist_logger
 from colossalai.utils.cuda import get_current_device

 from .stateful_tensor import StatefulTensor, TensorState
 from .tensor_placement_policy import TensorPlacementPolicy
-from .tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
+from .tensor_utils import colo_model_data_tensor_move_inline


 class StatefulTensorMgr(object):
@@ -44,8 +41,7 @@ class StatefulTensorMgr(object):
         pass

     def finish_iter(self):
-        """This function must be called when each iteration finishes
-        """
+        """This function must be called when each iteration finishes"""
         self._warmup = False
         self._compute_idx = -1
         self._cpu_gpu_move_volume = 0
@@ -53,19 +49,21 @@ class StatefulTensorMgr(object):
         self._evict_time = 0

     def adjust_layout(self) -> None:
-        """ Adjust the layout of stateful tensor according to the information provided
+        """Adjust the layout of stateful tensor according to the information provided
         by mem_stats_collector, which should belongs to a Sharded Model.
         """
         # find stateful tensor in state COMPUTE
-        cuda_demand = StatefulTensor.GST_MGR.state_mem['cpu'][TensorState.COMPUTE]
+        cuda_demand = StatefulTensor.GST_MGR.state_mem["cpu"][TensorState.COMPUTE]
         start = time()
         move_to_cuda_tensor_list, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup)
         self._layout_time += time() - start
-        vol, evict_time = self._tensor_placement_policy.evict_tensors(hold_cuda_tensor_list,
-                                                                      cuda_demand=cuda_demand,
-                                                                      warmup=self._warmup,
-                                                                      compute_list=self._compute_list,
-                                                                      compute_idx=self._compute_idx)
+        vol, evict_time = self._tensor_placement_policy.evict_tensors(
+            hold_cuda_tensor_list,
+            cuda_demand=cuda_demand,
+            warmup=self._warmup,
+            compute_list=self._compute_list,
+            compute_idx=self._compute_idx,
+        )
         self._cpu_gpu_move_volume += vol
         self._evict_time += evict_time
         # move COMPUTE tensors to CUDA
@@ -92,10 +90,10 @@ class StatefulTensorMgr(object):
             if tensor.state == TensorState.FREE:
                 continue

-            if tensor.device.type == 'cuda':
+            if tensor.device.type == "cuda":
                 if tensor.state in [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.HOLD_AFTER_FWD]:
                     hold_cuda_tensor_list.append(tensor)
-            elif tensor.device.type == 'cpu':
+            elif tensor.device.type == "cpu":
                 if tensor.state == TensorState.COMPUTE:
                     move_to_cuda_tensor_list.append(tensor)
                 else:
@@ -10,11 +10,10 @@ from colossalai.utils import get_current_device
 from colossalai.zero.gemini.memory_tracer import MemStatsCollector

 from .stateful_tensor import StatefulTensor
-from .tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
+from .tensor_utils import colo_model_data_tensor_move_inline


 class TensorPlacementPolicy(ABC):
-
     def __init__(self, device: Optional[torch.device], mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
         self.device: Optional[torch.device] = device
         self.mem_stats_collector: Optional[MemStatsCollector] = mem_stats_collector
@@ -25,9 +24,8 @@ class TensorPlacementPolicy(ABC):


 class CPUTensorPlacementPolicy(TensorPlacementPolicy):
-
     def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
-        super().__init__(torch.device('cpu'), mem_stats_collector=mem_stats_collector)
+        super().__init__(torch.device("cpu"), mem_stats_collector=mem_stats_collector)

     def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int:
         volume = 0
@@ -38,9 +36,8 @@ class CPUTensorPlacementPolicy(TensorPlacementPolicy):


 class CUDATensorPlacementPolicy(TensorPlacementPolicy):
-
     def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
-        assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available'
+        assert torch.cuda.is_available(), "Cannot use CUDATensorPlacementPolicy when CUDA is not available"
         super().__init__(get_current_device(), mem_stats_collector=mem_stats_collector)

     def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int:
@@ -48,7 +45,6 @@ class CUDATensorPlacementPolicy(TensorPlacementPolicy):


 class AutoTensorPlacementPolicy(TensorPlacementPolicy):
-
     def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
         super().__init__(None, mem_stats_collector=mem_stats_collector)
         # model data will use 1-self._warmup_non_model_data_ratio CUDA memory in warmup phase
@@ -56,13 +52,15 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
         self._warmup_non_model_data_ratio: float = 0.8
         self._steady_cuda_cap_ratio: float = 0.9

-    def evict_tensors(self,
-                      hold_cuda_tensor_list: List[StatefulTensor],
-                      cuda_demand: int = 0,
-                      warmup: bool = True,
-                      compute_list: List[StatefulTensor] = [],
-                      compute_idx: int = 0,
-                      **kwargs) -> int:
+    def evict_tensors(
+        self,
+        hold_cuda_tensor_list: List[StatefulTensor],
+        cuda_demand: int = 0,
+        warmup: bool = True,
+        compute_list: List[StatefulTensor] = [],
+        compute_idx: int = 0,
+        **kwargs,
+    ) -> int:
         """
         Evict tensors from CUDA device.

@@ -81,13 +79,13 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
         """
         start = time()
         cuda_capacity = colo_device_memory_capacity(get_current_device())
-        used_cuda_model_data = StatefulTensor.GST_MGR.total_mem['cuda']
+        used_cuda_model_data = StatefulTensor.GST_MGR.total_mem["cuda"]
         if warmup:
             # We designate a part of CUDA memory for model data in warmup iterations.
             max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio
         else:
             # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
-            max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda')
+            max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage("cuda")
             cuda_capacity *= self._steady_cuda_cap_ratio
         total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
         avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
@@ -99,15 +97,16 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
             to_free_cuda_model_data = cuda_demand - avail_cuda_model_data
             to_free_tensor_list = hold_cuda_tensor_list
             if not warmup:
-                to_free_tensor_list = self._sort_hold_cuda_tensors(tuple(hold_cuda_tensor_list), compute_idx,
-                                                                   tuple(compute_list))
+                to_free_tensor_list = self._sort_hold_cuda_tensors(
+                    tuple(hold_cuda_tensor_list), compute_idx, tuple(compute_list)
+                )
             # print(self._sort_hold_cuda_tensors.cache_info())
             end = time()
             for t in to_free_tensor_list:
                 if freed_cuda_model_data >= to_free_cuda_model_data:
                     break
                 freed_cuda_model_data += t.payload_size
-                colo_model_data_tensor_move_inline(t, torch.device('cpu'))
+                colo_model_data_tensor_move_inline(t, torch.device("cpu"))
             if freed_cuda_model_data < to_free_cuda_model_data:
                 raise RuntimeError(
                     f"Adjust layout failed! No enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
@@ -126,14 +125,13 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):


 class TensorPlacementPolicyFactory:
-
     @staticmethod
     def create(policy_name: str) -> Type[TensorPlacementPolicy]:
-        if policy_name == 'cpu':
+        if policy_name == "cpu":
             return CPUTensorPlacementPolicy
-        elif policy_name == 'cuda':
+        elif policy_name == "cuda":
             return CUDATensorPlacementPolicy
-        elif policy_name == 'auto':
+        elif policy_name == "auto":
             return AutoTensorPlacementPolicy
         else:
             raise TypeError(f"Unknown tensor placement policy {policy_name}")
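The factory in the hunk above simply maps a policy name to one of the three policy classes. A short usage sketch follows; the import path is an assumption, and passing `mem_stats_collector=None` is only valid for construction (the auto policy needs a real collector once eviction runs outside warmup):

```python
# Assumed import path, for illustration only.
from colossalai.legacy.zero.gemini.tensor_placement_policy import TensorPlacementPolicyFactory

policy_cls = TensorPlacementPolicyFactory.create("auto")  # "cpu", "cuda" or "auto", per the diff
policy = policy_cls(mem_stats_collector=None)  # constructor signature as shown in the hunks above
```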
@@ -30,16 +30,17 @@ def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[
         cuda_use, cpu_use = 0, 0

     mem_use = t.storage().size() * t.element_size()
-    if t.device.type == 'cuda':
+    if t.device.type == "cuda":
         cuda_use += mem_use
-    elif t.device.type == 'cpu':
+    elif t.device.type == "cpu":
         cpu_use += mem_use

     return cuda_use, cpu_use


-def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_t: Union[StatefulTensor,
-                                                                                          torch.Tensor]) -> None:
+def colo_model_data_tensor_move(
+    src_t: Union[StatefulTensor, torch.Tensor], tgt_t: Union[StatefulTensor, torch.Tensor]
+) -> None:
     """
     A colossal API for model data tensor move.
     The src and target tensors could be resident on both CPU and GPU.
@@ -71,8 +72,9 @@ def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_
         src_t.data = torch.empty(0, device=src_dev, dtype=src_t_payload.dtype)


-def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], target_device: Union[torch.device,
-                                                                                                     int]) -> None:
+def colo_model_data_tensor_move_inline(
+    t: Union[StatefulTensor, torch.Tensor], target_device: Union[torch.device, int]
+) -> None:
     """
     move a tensor to the target_device
     Args:
@@ -80,14 +82,14 @@ def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], t
         target_device: a target device, if type is int, it the index of cuda card.
     """
     if not isinstance(target_device, torch.device):
-        target_device = torch.device(f'cuda:{target_device}')
+        target_device = torch.device(f"cuda:{target_device}")

     if isinstance(t, torch.Tensor):
         t.data = t.data.to(target_device)
     elif isinstance(t, StatefulTensor):
         t.move_to(target_device)
     else:
-        raise TypeError(f'colo_model_data_tensor_move_inline dose not accept type {type(t)}')
+        raise TypeError(f"colo_model_data_tensor_move_inline dose not accept type {type(t)}")


 def colo_model_data_move_to_cpu(t: Union[StatefulTensor, torch.Tensor]) -> None:
@@ -100,9 +102,9 @@ def colo_model_data_move_to_cpu(t: Union[StatefulTensor, torch.Tensor]) -> None:
     if isinstance(t, torch.Tensor):
         t.data = t.data.cpu()
     elif isinstance(t, StatefulTensor):
-        t.move_to(torch.device('cpu'))
+        t.move_to(torch.device("cpu"))
     else:
-        raise TypeError(f'colo_model_data_move_to_cpu dose not accept type {type(t)}')
+        raise TypeError(f"colo_model_data_move_to_cpu dose not accept type {type(t)}")


 def colo_model_tensor_clone(t: Union[StatefulTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor:
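The final hunks touch the small tensor-movement helpers whose docstrings describe moving model data between CPU and CUDA. A minimal usage sketch under assumptions follows (the import path is hypothetical; the call signatures follow the diff):

```python
import torch

# Assumed import path, for illustration only.
from colossalai.legacy.zero.gemini.tensor_utils import (
    colo_model_data_move_to_cpu,
    colo_model_data_tensor_move_inline,
)

t = torch.randn(4, 4)
if torch.cuda.is_available():
    # Accepts a torch.device or an int CUDA index, per the docstring in the diff.
    colo_model_data_tensor_move_inline(t, torch.device("cuda:0"))
    colo_model_data_move_to_cpu(t)
```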