Mirror of https://github.com/hpcaitech/ColossalAI.git
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -1,3 +1,3 @@
 from .sharded_model_v2 import ShardedModelV2
 
-__all__ = ['ShardedModelV2']
+__all__ = ["ShardedModelV2"]
@@ -25,7 +25,7 @@ def free_storage(data: torch.Tensor) -> None:
 @torch.no_grad()
 def alloc_storage(data: torch.Tensor, size: torch.Size) -> None:
     """Allocate storage for a tensor."""
-    if data.storage().size() == size.numel():    # no need to reallocate
+    if data.storage().size() == size.numel():  # no need to reallocate
         return
     assert data.storage().size() == 0
     data.storage().resize_(size.numel())
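For readers unfamiliar with the storage trick behind free_storage/alloc_storage, here is a minimal standalone sketch in plain PyTorch (not taken from this repository) of releasing and re-allocating a tensor's backing storage:

import torch

t = torch.randn(4, 4)
size = t.size()

# release the backing storage while keeping the tensor object and its metadata alive
t.storage().resize_(0)
assert t.storage().size() == 0

# later, re-allocate storage for the original number of elements, as alloc_storage() does
if t.storage().size() != size.numel():
    t.storage().resize_(size.numel())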
@@ -20,7 +20,6 @@ else:
 
 
 class Bucket:
-
     def __init__(self, shard_size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup):
         self.buffer = torch.zeros((group.size(), shard_size), dtype=dtype, device=device)
         self.group = group
@@ -35,18 +34,18 @@ class Bucket:
             return
         # reduce-scatter bucket
         if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives:
-            dist._reduce_scatter_base(self.output_shard[:self.offset],
-                                      self.buffer[:, :self.offset].contiguous(),
-                                      group=self.group)
+            dist._reduce_scatter_base(
+                self.output_shard[: self.offset], self.buffer[:, : self.offset].contiguous(), group=self.group
+            )
         else:
-            dist.reduce_scatter(self.output_shard[:self.offset],
-                                list(self.buffer[:, :self.offset].unbind(0)),
-                                group=self.group)
+            dist.reduce_scatter(
+                self.output_shard[: self.offset], list(self.buffer[:, : self.offset].unbind(0)), group=self.group
+            )
         # execute post-reduction callbacks
         for callback_fn in self.callbacks:
             callback_fn()
         # reuse input bucket but allocate a fresh output shard
-        self.buffer[:, :self.offset].zero_()
+        self.buffer[:, : self.offset].zero_()
         self.offset = 0
         self.callbacks.clear()
         self.output_shard = torch.zeros_like(self.buffer[0])
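For background on the two branches above: dist._reduce_scatter_base is a private flat-tensor variant of dist.reduce_scatter (the public name in newer PyTorch is dist.reduce_scatter_tensor). A minimal sketch of the same collective outside this class, assuming an NCCL process group with one GPU per rank (e.g. launched via torchrun), might look like:

import torch
import torch.distributed as dist

# assumption: NCCL backend, one GPU per rank, e.g. torchrun --nproc_per_node=2 this_script.py
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())
world_size = dist.get_world_size()

shard_size = 4
# each rank contributes a (world_size, shard_size) buffer filled with its own rank id
buffer = torch.full((world_size, shard_size), float(dist.get_rank()), device="cuda")
output = torch.empty(shard_size, device="cuda")

if hasattr(dist, "reduce_scatter_tensor"):  # flat-tensor collective, successor of _reduce_scatter_base
    dist.reduce_scatter_tensor(output, buffer.contiguous())
else:
    dist.reduce_scatter(output, list(buffer.unbind(0)))

# every element of output now equals sum(range(world_size)): rank r received row r of each
# rank's buffer, summed across ranks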
@@ -74,12 +73,12 @@ class Bucket:
         tensor_size = tensor_list[0].numel()
         stacked_input = torch.stack(tensor_list).view(self.group.size(), tensor_size)
         offset = self.offset
-        self.buffer[:, offset:offset + tensor_size].copy_(stacked_input)
+        self.buffer[:, offset : offset + tensor_size].copy_(stacked_input)
         self.offset += tensor_size
 
         # callback will be given the reduced result
         if callback_fn is not None:
-            result_view = self.output_shard[offset:offset + tensor_size].view_as(tensor_list[0])
+            result_view = self.output_shard[offset : offset + tensor_size].view_as(tensor_list[0])
             self.callbacks.append(functools.partial(callback_fn, result_view))
 
 
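The callback registration above is plain functools.partial usage: the view into the output shard is bound now, so the callback can be invoked with no arguments after flush() fills it. A tiny self-contained illustration (independent of this class):

import functools

results = []

def on_reduced(result_view, scale=1):
    # in the bucketer, result_view would be a slice of the freshly reduced output shard
    results.append(result_view * scale)

callbacks = [functools.partial(on_reduced, 10, scale=2)]  # bind arguments now
for callback_fn in callbacks:
    callback_fn()  # call later, no arguments needed

print(results)  # [20]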
@@ -142,8 +141,9 @@ class ReduceScatterBucketer:
         """
         world_size = group.size()
 
-        assert (len(input_list) == world_size
-               ), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})"
+        assert (
+            len(input_list) == world_size
+        ), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})"
 
         first_input = input_list[0]
         first_input_size = first_input.numel()
@@ -183,7 +183,7 @@ class ReduceScatterBucketer:
 
     @functools.lru_cache()
     def _get_shard_size(self, element_size: int, num_shards: int) -> int:
-        if self.bucket_size_mb <= 0:    # Values <= 0 disable bucketing.
+        if self.bucket_size_mb <= 0:  # Values <= 0 disable bucketing.
             return 0
         MB = 1024 * 1024
         bucket_size = self.bucket_size_mb * MB / element_size
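A worked example of the sizing arithmetic above, with illustrative values and assuming the remaining (unshown) lines divide the element budget across num_shards as the upstream fairscale bucketer does:

# hypothetical values: 25 MB bucket, fp16 gradients (2 bytes/element), 4 shards
bucket_size_mb, element_size, num_shards = 25, 2, 4
MB = 1024 * 1024
bucket_size = bucket_size_mb * MB / element_size  # 13_107_200 elements fit in the whole bucket
shard_size = int(bucket_size // num_shards)       # 3_276_800 elements per shard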
@@ -2,7 +2,6 @@
 import functools
 import itertools
 from collections import OrderedDict
 from copy import deepcopy
 from typing import Any, Iterator, Optional, Tuple
 
 import torch
@@ -40,7 +39,7 @@ from .zero_hook import ZeroHook
 try:
     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
 except ImportError:
-    _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
+    _EXTRA_STATE_KEY_SUFFIX = "_extra_state"
 
 
 class ShardedModelV2(nn.Module):
@@ -78,20 +77,22 @@ class ShardedModelV2(nn.Module):
         bf16 (bool, optional): Whether to use bfloat16 for param and grad. Defaults to False.
     """
 
-    def __init__(self,
-                 module: nn.Module,
-                 shard_strategy: BaseShardStrategy,
-                 process_group: Optional[ProcessGroup] = None,
-                 reduce_scatter_process_group: Optional[ProcessGroup] = None,
-                 reduce_scatter_bucket_size_mb: int = 25,
-                 fp32_reduce_scatter: bool = False,
-                 tensor_placement_policy: str = 'cuda',
-                 gradient_predivide_factor: Optional[float] = 1.0,
-                 reuse_fp16_shard: bool = False,
-                 bf16: bool = False,
-                 *args,
-                 **kwargs):
-        assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
+    def __init__(
+        self,
+        module: nn.Module,
+        shard_strategy: BaseShardStrategy,
+        process_group: Optional[ProcessGroup] = None,
+        reduce_scatter_process_group: Optional[ProcessGroup] = None,
+        reduce_scatter_bucket_size_mb: int = 25,
+        fp32_reduce_scatter: bool = False,
+        tensor_placement_policy: str = "cuda",
+        gradient_predivide_factor: Optional[float] = 1.0,
+        reuse_fp16_shard: bool = False,
+        bf16: bool = False,
+        *args,
+        **kwargs,
+    ):
+        assert not isinstance(module, ShardedModelV2), "Nested ShardedModelV2 is not supported."
         super().__init__()
         self.logger = get_dist_logger()
         self.bf16 = bf16
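For orientation, a rough usage sketch implied by this signature and by the ZeroInitContext requirement that __init__ asserts; the import paths, TensorShardStrategy, and MyModel are assumptions that vary across ColossalAI versions, so treat this as illustrative only:

import torch
from colossalai.zero.legacy.init_ctx import ZeroInitContext          # assumed path
from colossalai.zero.legacy.shard_utils import TensorShardStrategy   # assumed path
from colossalai.zero.legacy.sharded_model import ShardedModelV2

shard_strategy = TensorShardStrategy()
with ZeroInitContext(target_device=torch.device("cuda"), shard_strategy=shard_strategy, shard_param=True):
    model = MyModel()  # hypothetical nn.Module; parameters gain a colo_attr inside the context
sharded_model = ShardedModelV2(model, shard_strategy, tensor_placement_policy="auto", reuse_fp16_shard=False)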
@@ -101,13 +102,13 @@ class ShardedModelV2(nn.Module):
             sharded_cnt = 0
             unshard_cnt = 0
             for param in submodule.parameters(recurse=False):
-                assert hasattr(param, 'colo_attr'), 'You must use ZeroInitContext to init your module first.'
+                assert hasattr(param, "colo_attr"), "You must use ZeroInitContext to init your module first."
                 if param.colo_attr.param_is_sharded:
                     sharded_cnt += 1
                 else:
                     unshard_cnt += 1
-            assert (not sharded_cnt) or (not unshard_cnt), 'nn.Module can not both have shard param and unshard param'
-            submodule.param_is_sharded = (sharded_cnt > 0)
+            assert (not sharded_cnt) or (not unshard_cnt), "nn.Module can not both have shard param and unshard param"
+            submodule.param_is_sharded = sharded_cnt > 0
 
         self.sharded_params = []
         self.unshard_params = []
@@ -124,7 +125,7 @@ class ShardedModelV2(nn.Module):
         self.rank = dist.get_rank(self.process_group)
         self.shard_strategy = shard_strategy
 
-        self._use_memory_tracer = tensor_placement_policy == 'auto'
+        self._use_memory_tracer = tensor_placement_policy == "auto"
         if self._use_memory_tracer:
             self._memstats_collector = MemStatsCollector()
             self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
@@ -132,18 +133,19 @@ class ShardedModelV2(nn.Module):
         else:
             self._memstats_collector = None
         self._tensor_placement_policy: TensorPlacementPolicy = TensorPlacementPolicyFactory.create(
-            tensor_placement_policy)(mem_stats_collector=self._memstats_collector)
+            tensor_placement_policy
+        )(mem_stats_collector=self._memstats_collector)
 
-        if 'warmup_non_model_data_ratio' in kwargs:
-            if tensor_placement_policy != 'auto':
-                self.logger.warning('setting warmup_non_model_data_ratio is useless if not use auto placement')
+        if "warmup_non_model_data_ratio" in kwargs:
+            if tensor_placement_policy != "auto":
+                self.logger.warning("setting warmup_non_model_data_ratio is useless if not use auto placement")
             else:
-                ratio = kwargs['warmup_non_model_data_ratio']
+                ratio = kwargs["warmup_non_model_data_ratio"]
                 self._tensor_placement_policy._warmup_non_model_data_ratio = ratio
-                self.logger.info(f'setting warmup_non_model_data_ratio as {ratio} for auto placement')
+                self.logger.info(f"setting warmup_non_model_data_ratio as {ratio} for auto placement")
 
         self._stateful_tensor_mgr = StatefulTensorMgr(self._tensor_placement_policy)
-        param_tensor_list = [p.colo_attr.sharded_data_tensor for p in module.parameters() if hasattr(p, 'colo_attr')]
+        param_tensor_list = [p.colo_attr.sharded_data_tensor for p in module.parameters() if hasattr(p, "colo_attr")]
         self._stateful_tensor_mgr.register_stateful_tensor_list(param_tensor_list)
 
         # Register hooks
@@ -155,7 +157,7 @@ class ShardedModelV2(nn.Module):
         self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
 
         self.fp32_reduce_scatter = fp32_reduce_scatter
-        self._cpu_offload: bool = tensor_placement_policy != 'cuda'
+        self._cpu_offload: bool = tensor_placement_policy != "cuda"
         for param in module.parameters():
             # Init `offload_grad`
             param.colo_attr.offload_grad = self._cpu_offload
@@ -164,9 +166,11 @@ class ShardedModelV2(nn.Module):
         # So we use 1.0 as the default gradient_predivide_factor
         # However, if you set gradient_predivide_factor to None, we will set
         # gradient_predivide_factor to a value >= 1.0 automatically
-        self.gradient_predivide_factor: float = gradient_predivide_factor if \
-            gradient_predivide_factor is not None else \
-            get_gradient_predivide_factor(self.world_size)
+        self.gradient_predivide_factor: float = (
+            gradient_predivide_factor
+            if gradient_predivide_factor is not None
+            else get_gradient_predivide_factor(self.world_size)
+        )
         self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
 
         self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
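For intuition on the two factors above: gradients are divided by gradient_predivide_factor before the reduce-scatter and by gradient_postdivide_factor afterwards, and since the two multiply out to world_size the net effect is the same averaging DDP performs. A small numeric check with made-up values:

world_size = 8
gradient_predivide_factor = 2.0  # illustrative; splitting the division helps avoid fp16 over/underflow during reduction
gradient_postdivide_factor = world_size / gradient_predivide_factor  # 4.0

grad_sum = 8.0  # pretend each of the 8 ranks contributed a gradient of 1.0 to the reduction
averaged = grad_sum / gradient_predivide_factor / gradient_postdivide_factor
assert averaged == grad_sum / world_size  # identical to dividing once by world_size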
@@ -194,7 +198,7 @@ class ShardedModelV2(nn.Module):
     def cpu_offload(self):
         return self._cpu_offload
 
-    def dump_memory_stats(self, filename: Optional[str] = 'dump_mem_stats.log') -> None:
+    def dump_memory_stats(self, filename: Optional[str] = "dump_mem_stats.log") -> None:
         """
         dummy memory tracer collected information to a file.
         try:
@@ -205,18 +209,18 @@ class ShardedModelV2(nn.Module):
             exit(0)
         """
         if self._use_memory_tracer:
-            self.logger.error(f'dump memory tracer collected information to a {filename}', ranks=[0])
+            self.logger.error(f"dump memory tracer collected information to a {filename}", ranks=[0])
             if gpc.get_global_rank() == 0:
-                with open(filename, 'w+') as f:
-                    f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n')
-                    f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n')
-                    f.write('CUDA model data (GB)\n')
-                    f.write('\n')
-                    f.write('CUDA non model data (GB)\n')
-                    f.write(str(self._memstats_collector._memstats.non_model_data_list('cuda')))
-                    f.write('CPU non model data (GB)\n')
-                    f.write(str(self._memstats_collector._memstats.non_model_data_list('cpu')))
-                    f.write('\n')
+                with open(filename, "w+") as f:
+                    f.write(f"cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n")
+                    f.write(f"cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n")
+                    f.write("CUDA model data (GB)\n")
+                    f.write("\n")
+                    f.write("CUDA non model data (GB)\n")
+                    f.write(str(self._memstats_collector._memstats.non_model_data_list("cuda")))
+                    f.write("CPU non model data (GB)\n")
+                    f.write(str(self._memstats_collector._memstats.non_model_data_list("cpu")))
+                    f.write("\n")
 
     def _pre_forward_operations(self, *args):
         # the operation will affect the memory tracer behavior in ZeroHook
@@ -224,14 +228,14 @@ class ShardedModelV2(nn.Module):
             self._start_collect_memstats()
 
         for p in self.module.parameters():
-            if hasattr(p, 'colo_attr'):
+            if hasattr(p, "colo_attr"):
                 p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
 
         self._stateful_tensor_mgr.start_iter()
 
     def _post_forward_operations(self):
         for p in self.module.parameters():
-            if hasattr(p, 'colo_attr'):
+            if hasattr(p, "colo_attr"):
                 p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
 
     def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
@@ -261,8 +265,9 @@ class ShardedModelV2(nn.Module):
             # the way to calculate margin space is based on the assumption that
             # model data is fixed in cuda during training.
             # cuda margin space can be used to store OS.
-            self._cuda_margin_space = colo_device_memory_capacity(
-                get_current_device()) - self._memstats_collector._memstats.max_overall_cuda
+            self._cuda_margin_space = (
+                colo_device_memory_capacity(get_current_device()) - self._memstats_collector._memstats.max_overall_cuda
+            )
 
     @torch.no_grad()
     def _post_backward_operations(self) -> None:
@@ -330,7 +335,7 @@ class ShardedModelV2(nn.Module):
         """
         if grad is None:
             return
-        assert not grad.requires_grad, 'ShardedModel only works with gradients that don\'t require gradients'
+        assert not grad.requires_grad, "ShardedModel only works with gradients that don't require gradients"
         if not self._require_backward_grad_sync:
             return
         # used to cheat Pytorch, since we can't return None
@@ -354,16 +359,19 @@ class ShardedModelV2(nn.Module):
                 grad.data.div_(self.gradient_predivide_factor)
             if self.world_size > 1:
                 grad_chunks = chunk_and_pad(grad, self.reduce_scatter_process_group.size())
-                self.reducer.reduce_scatter_async(grad_chunks,
-                                                  group=self.reduce_scatter_process_group,
-                                                  callback_fn=functools.partial(self._reduce_scatter_callback, param))
+                self.reducer.reduce_scatter_async(
+                    grad_chunks,
+                    group=self.reduce_scatter_process_group,
+                    callback_fn=functools.partial(self._reduce_scatter_callback, param),
+                )
             else:
                 self._reduce_scatter_callback(param, grad)
         torch.cuda.current_stream().wait_stream(self.comm_stream)
 
     def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
-        assert isinstance(reduced_grad,
-                          torch.Tensor), f"_reduce_scatter_callback accept reduced_grad as {type(reduced_grad)}"
+        assert isinstance(
+            reduced_grad, torch.Tensor
+        ), f"_reduce_scatter_callback accept reduced_grad as {type(reduced_grad)}"
         reduced_grad.data = reduced_grad.data.contiguous().view(-1)
         if self.gradient_postdivide_factor > 1:
             # Average grad by world_size for consistency with PyTorch DDP.
@@ -372,7 +380,6 @@ class ShardedModelV2(nn.Module):
 
     # FIXME(ver217): refactor the below line when impl eviction policy
     def _save_grad(self, param: Parameter, grad: torch.Tensor):
-
         # record whether we have overflow
         self.overflow_counter += torch.isinf(grad).any().item()
         self.overflow_counter += torch.isnan(grad).any().item()
@@ -384,8 +391,9 @@ class ShardedModelV2(nn.Module):
         if self.reuse_fp16_shard:
             # make parameters point to gradient
 
-            assert param.colo_attr.saved_grad.is_null(
-            ), 'Gradient accumulation is not supported when reuse_fp16_shard=True'
+            assert (
+                param.colo_attr.saved_grad.is_null()
+            ), "Gradient accumulation is not supported when reuse_fp16_shard=True"
 
             param.colo_attr.grad_payload_reset(grad.data)
             # release the memory of param
@@ -396,7 +404,6 @@ class ShardedModelV2(nn.Module):
             if param.colo_attr.is_replicated:
                 param.colo_attr.sharded_data_tensor.is_sharded = True
         else:
-
             fp32_grad = cast_tensor_to_fp32(grad)
 
             if param.colo_attr.saved_grad.is_null():
@@ -410,39 +417,44 @@ class ShardedModelV2(nn.Module):
     def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
         return self.module.parameters(recurse=recurse)
 
-    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
+    def named_parameters(self, prefix: str = "", recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
         return self.module.named_parameters(prefix, recurse)
 
-    def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
-        return self._colo_state_dict(destination,
-                                     prefix,
-                                     keep_vars,
-                                     shard_strategy=self.shard_strategy,
-                                     state_dict_func=nn.Module.state_dict,
-                                     module_to_load=self.module,
-                                     sharded_params=self.sharded_params,
-                                     process_group=self.process_group)
+    def state_dict(self, destination=None, prefix="", keep_vars=False) -> "OrderedDict[str, torch.Tensor]":
+        return self._colo_state_dict(
+            destination,
+            prefix,
+            keep_vars,
+            shard_strategy=self.shard_strategy,
+            state_dict_func=nn.Module.state_dict,
+            module_to_load=self.module,
+            sharded_params=self.sharded_params,
+            process_group=self.process_group,
+        )
 
-    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True) -> None:
+    def load_state_dict(self, state_dict: "OrderedDict[str, torch.Tensor]", strict: bool = True) -> None:
         for name, p in self.named_parameters():
             if name in state_dict:
-                p.colo_attr.data_payload_reset(state_dict[name].to(dtype=p.colo_attr.data_payload.dtype,
-                                                                   device=p.colo_attr.data_payload.device))
+                p.colo_attr.data_payload_reset(
+                    state_dict[name].to(dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device)
+                )
                 # Force re-shard
                 p.colo_attr.sharded_data_tensor.is_sharded = False
                 self.shard_strategy.shard([p.colo_attr.sharded_data_tensor])
             elif strict:
-                raise RuntimeError(f'Missing key in state_dict: {name}')
+                raise RuntimeError(f"Missing key in state_dict: {name}")
 
-    def _colo_state_dict(self,
-                         destination=None,
-                         prefix='',
-                         keep_vars=False,
-                         shard_strategy: Optional[BaseShardStrategy] = None,
-                         state_dict_func=None,
-                         module_to_load=None,
-                         sharded_params=[],
-                         process_group=None) -> 'OrderedDict[str, torch.Tensor]':
+    def _colo_state_dict(
+        self,
+        destination=None,
+        prefix="",
+        keep_vars=False,
+        shard_strategy: Optional[BaseShardStrategy] = None,
+        state_dict_func=None,
+        module_to_load=None,
+        sharded_params=[],
+        process_group=None,
+    ) -> "OrderedDict[str, torch.Tensor]":
         if len(sharded_params) == 0:
             for param in self.parameters():
                 if param.colo_attr.param_is_sharded:
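Because state_dict() above gathers the shards into full tensors and load_state_dict() casts and re-shards them, a checkpoint round-trip looks like ordinary PyTorch usage. A hedged sketch, where sharded_model is an existing ShardedModelV2 instance and the file name is a placeholder:

import torch
import torch.distributed as dist

# gathering happens inside state_dict(), so call it on every rank, then save once
full_state = sharded_model.state_dict()
if dist.get_rank() == 0:
    torch.save(full_state, "checkpoint.pt")
dist.barrier()

# feed full tensors back in; strict=True raises on missing keys, as implemented above
sharded_model.load_state_dict(torch.load("checkpoint.pt", map_location="cpu"), strict=True)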
@@ -460,15 +472,9 @@ class ShardedModelV2(nn.Module):
             p.colo_attr.set_data_none()
         return gathered_state_dict
 
-    def _colo_load_from_state_dict(self,
-                                   state_dict,
-                                   prefix,
-                                   local_metadata,
-                                   strict,
-                                   missing_keys,
-                                   unexpected_keys,
-                                   error_msgs,
-                                   shard_strategy=None):
+    def _colo_load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs, shard_strategy=None
+    ):
         r"""Copies parameters and buffers from :attr:`state_dict` into only
         this module, but not its descendants. This is called on every submodule
         in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
@@ -512,10 +518,12 @@ class ShardedModelV2(nn.Module):
             key = prefix + name
             if key in state_dict:
                 input_param = state_dict[key]
-                if hasattr(param, 'colo_attr'):
+                if hasattr(param, "colo_attr"):
                     param.colo_attr.data_payload_reset(
-                        input_param.to(dtype=param.colo_attr.data_payload.dtype,
-                                       device=param.colo_attr.data_payload.device))
+                        input_param.to(
+                            dtype=param.colo_attr.data_payload.dtype, device=param.colo_attr.data_payload.device
+                        )
+                    )
                     if shard_strategy is not None:
                         # Force re-shard
                         param.colo_attr.sharded_data_tensor.is_sharded = False
@@ -531,19 +539,21 @@ class ShardedModelV2(nn.Module):
 
                 if not is_param_lazy and input_param.shape != param.shape:
                     # local shape should match the one in checkpoint
-                    error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
-                                      'the shape in current model is {}.'.format(
-                                          key, input_param.shape, param.shape))
+                    error_msgs.append(
+                        "size mismatch for {}: copying a param with shape {} from checkpoint, "
+                        "the shape in current model is {}.".format(key, input_param.shape, param.shape)
+                    )
                     continue
                 try:
                     with torch.no_grad():
                         param.copy_(input_param)
                 except Exception as ex:
-                    error_msgs.append('While copying the parameter named "{}", '
-                                      'whose dimensions in the model are {} and '
-                                      'whose dimensions in the checkpoint are {}, '
-                                      'an exception occurred : {}.'.format(key, param.size(), input_param.size(),
-                                                                           ex.args))
+                    error_msgs.append(
+                        'While copying the parameter named "{}", '
+                        "whose dimensions in the model are {} and "
+                        "whose dimensions in the checkpoint are {}, "
+                        "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
+                    )
             elif strict:
                 missing_keys.append(key)
 
@@ -559,8 +569,8 @@ class ShardedModelV2(nn.Module):
         if strict:
             for key in state_dict.keys():
                 if key.startswith(prefix) and key != extra_state_key:
-                    input_name = key[len(prefix):]
-                    input_name = input_name.split('.', 1)[0]  # get the name of param/buffer/child
+                    input_name = key[len(prefix) :]
+                    input_name = input_name.split(".", 1)[0]  # get the name of param/buffer/child
                     if input_name not in self._modules and input_name not in local_state:
                         unexpected_keys.append(key)
 
@@ -11,7 +11,7 @@ def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Modu
     Note the other_model has to be the same as self.
     """
     for zero_param, param in zip(sharded_model.parameters(), other_model.parameters()):
-        assert hasattr(zero_param, 'colo_attr')
+        assert hasattr(zero_param, "colo_attr")
         shard_flag = zero_param.colo_attr.sharded_data_tensor.is_sharded
         if shard_flag:
             sharded_model.shard_strategy.gather([zero_param.colo_attr.sharded_data_tensor])
@@ -20,11 +20,13 @@ class ZeroHook(BaseOpHook):
     Warning: this class has been deprecated after version 0.1.12
     """
 
-    def __init__(self,
-                 shard_strategy: BaseShardStrategy,
-                 memstarts_collector: Optional[MemStatsCollector] = None,
-                 stateful_tensor_mgr: Optional[StatefulTensorMgr] = None,
-                 process_group: Optional[dist.ProcessGroup] = None):
+    def __init__(
+        self,
+        shard_strategy: BaseShardStrategy,
+        memstarts_collector: Optional[MemStatsCollector] = None,
+        stateful_tensor_mgr: Optional[StatefulTensorMgr] = None,
+        process_group: Optional[dist.ProcessGroup] = None,
+    ):
         super().__init__()
         self.logger = get_dist_logger("ZeROHook")
         self.shard_strategy = shard_strategy
@@ -41,7 +43,7 @@ class ZeroHook(BaseOpHook):
         if module.param_is_sharded:
             tensor_list = []
             for param in module.parameters(recurse=False):
-                assert hasattr(param, 'colo_attr')
+                assert hasattr(param, "colo_attr")
                 tensor_list.append(param.colo_attr.sharded_data_tensor)
             self.shard_strategy.gather(tensor_list, self.process_group)
 
@@ -50,7 +52,7 @@ class ZeroHook(BaseOpHook):
         if module.param_is_sharded:
             tensor_list = []
             for param in module.parameters(recurse=False):
-                assert hasattr(param, 'colo_attr')
+                assert hasattr(param, "colo_attr")
                 tensor_list.append(param.colo_attr.sharded_data_tensor)
             self.shard_strategy.shard(tensor_list, self.process_group)
 
@@ -74,10 +76,9 @@ class ZeroHook(BaseOpHook):
         self.gather_parameters(module)
         for param in module.parameters(recurse=False):
             param.data = param.colo_attr.data_payload
-            assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA"
+            assert param.data.device.type == "cuda", f"PRE FWD param.data must be on CUDA"
 
     def post_fwd_exec(self, module: torch.nn.Module, *args):
-
         # change tensor state to HOLD_AFTER_FWD
         for param in module.parameters(recurse=False):
             param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)
@@ -93,10 +94,9 @@ class ZeroHook(BaseOpHook):
         self.gather_parameters(module)
         for param in module.parameters(recurse=False):
             param.data = param.colo_attr.data_payload
-            assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA"
+            assert param.data.device.type == "cuda", f"PRE BWD param.data must be on CUDA"
 
     def post_bwd_exec(self, module: torch.nn.Module, input):
-
         # change tensor state to HOLD_AFTER_BWD
         for param in module.parameters(recurse=False):
             param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
@@ -114,5 +114,6 @@ class ZeroHook(BaseOpHook):
         if self._stateful_tensor_mgr:
             self.logger.debug(
                 f"CPU-GPU data moving this iteration {self._stateful_tensor_mgr.cpu_gpu_move_volume/1e9} GB, get layout info time: {self._stateful_tensor_mgr._layout_time}, evict cpu time: {self._stateful_tensor_mgr._evict_time}",
-                ranks=[0])
+                ranks=[0],
+            )
             self._stateful_tensor_mgr.finish_iter()
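Taken together, the hook methods above implement a per-submodule lifecycle: gather the shards and point param.data at the gathered payload before a submodule's forward/backward, then re-shard and advance the tensor state afterwards. A stripped-down stand-in for that pattern using plain PyTorch forward hooks (an illustration of the idea, not the ColossalAI implementation, and covering only the forward pass):

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

def pre_fwd(module, inputs):
    # stand-in for gather_parameters(module) plus "param.data = param.colo_attr.data_payload"
    for p in module.parameters(recurse=False):
        p.data = p.data.to(device)

def post_fwd(module, inputs, output):
    # stand-in for re-sharding and trans_state(HOLD_AFTER_FWD): drop the full copy again
    for p in module.parameters(recurse=False):
        p.data = p.data.to("cpu")

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))  # parameters start on CPU
for m in model.modules():
    if any(True for _ in m.parameters(recurse=False)):
        m.register_forward_pre_hook(pre_fwd)
        m.register_forward_hook(post_fwd)

out = model(torch.randn(4, 8, device=device))  # each Linear is moved in just for its own forward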