mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-04-27 11:31:58 +00:00
[zero] Update sharded model v2 using sharded param v2 (#323)
This commit is contained in:
parent
799d105bb4
commit
1388671699
@ -15,8 +15,7 @@ def _apply_to_tensors_only(module, functional, backward_function, outputs):
|
||||
if type(outputs) is tuple:
|
||||
touched_outputs = []
|
||||
for output in outputs:
|
||||
touched_output = _apply_to_tensors_only(module, functional,
|
||||
backward_function, output)
|
||||
touched_output = _apply_to_tensors_only(module, functional, backward_function, output)
|
||||
touched_outputs.append(touched_output)
|
||||
return tuple(touched_outputs)
|
||||
elif type(outputs) is torch.Tensor:
|
||||
@ -26,6 +25,7 @@ def _apply_to_tensors_only(module, functional, backward_function, outputs):
|
||||
|
||||
|
||||
class PreBackwardFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, module, pre_backward_function, outputs):
|
||||
ctx.module = module
|
||||
@ -41,6 +41,7 @@ class PreBackwardFunction(torch.autograd.Function):
|
||||
|
||||
|
||||
class PostBackwardFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, module, pre_backward_function, output):
|
||||
ctx.module = module
|
||||
@ -60,9 +61,7 @@ class PostBackwardFunction(torch.autograd.Function):
|
||||
return (None, None) + args
|
||||
|
||||
|
||||
def register_ophooks_recursively(module: torch.nn.Module,
|
||||
ophook_list: List[BaseOpHook] = None,
|
||||
name: str = ""):
|
||||
def register_ophooks_recursively(module: torch.nn.Module, ophook_list: List[BaseOpHook] = None, name: str = ""):
|
||||
r"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD."""
|
||||
assert isinstance(module, torch.nn.Module)
|
||||
has_children = False
|
||||
@ -72,8 +71,7 @@ def register_ophooks_recursively(module: torch.nn.Module,
|
||||
|
||||
# Early return on modules with no parameters or buffers that
|
||||
# are not in their children.
|
||||
if (len(list(module.named_parameters(recurse=False))) == 0
|
||||
and len(list(module.named_buffers(recurse=False))) == 0):
|
||||
if (len(list(module.named_parameters(recurse=False))) == 0 and len(list(module.named_buffers(recurse=False))) == 0):
|
||||
return
|
||||
|
||||
# return if the module has not childern.
|
||||
@ -95,22 +93,22 @@ def register_ophooks_recursively(module: torch.nn.Module,
|
||||
hook.post_fwd_exec(submodule, *args)
|
||||
|
||||
def _pre_backward_module_hook(submodule, inputs, output):
|
||||
|
||||
def _run_before_backward_function(submodule):
|
||||
for hook in ophook_list:
|
||||
assert isinstance(submodule, torch.nn.Module)
|
||||
hook.pre_bwd_exec(submodule, inputs, output)
|
||||
|
||||
return _apply_to_tensors_only(submodule, PreBackwardFunction,
|
||||
_run_before_backward_function, output)
|
||||
return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output)
|
||||
|
||||
def _post_backward_module_hook(submodule, inputs):
|
||||
|
||||
def _run_after_backward_function(submodule):
|
||||
for hook in ophook_list:
|
||||
assert isinstance(submodule, torch.nn.Module)
|
||||
hook.post_bwd_exec(submodule, inputs)
|
||||
|
||||
return _apply_to_tensors_only(submodule, PostBackwardFunction,
|
||||
_run_after_backward_function, inputs)
|
||||
return _apply_to_tensors_only(submodule, PostBackwardFunction, _run_after_backward_function, inputs)
|
||||
|
||||
module.register_forward_pre_hook(_pre_forward_module_hook)
|
||||
module.register_forward_hook(_post_forward_module_hook)
|
||||
|
58
colossalai/engine/ophooks/zero_hook.py
Normal file
58
colossalai/engine/ophooks/zero_hook.py
Normal file
@ -0,0 +1,58 @@
|
||||
import torch
|
||||
from colossalai.registry import OPHOOKS
|
||||
from colossalai.zero.shard_utils import BaseShardStrategy
|
||||
|
||||
from ._base_ophook import BaseOpHook
|
||||
|
||||
|
||||
@OPHOOKS.register_module
|
||||
class ZeroHook(BaseOpHook):
|
||||
"""
|
||||
A hook to process sharded param for ZeRO method.
|
||||
"""
|
||||
|
||||
def __init__(self, shard_strategy: BaseShardStrategy):
|
||||
super().__init__()
|
||||
self.shard_strategy = shard_strategy
|
||||
|
||||
def pre_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
for param in module.parameters():
|
||||
assert hasattr(param, 'col_attr')
|
||||
self.shard_strategy.gather([param.col_attr.data])
|
||||
param.data = param.col_attr.data.payload
|
||||
|
||||
def post_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
for param in module.parameters():
|
||||
assert hasattr(param, 'col_attr')
|
||||
self.shard_strategy.shard([param.col_attr.data])
|
||||
param.data = torch.empty([], dtype=param.col_attr.data.dtype, device=param.col_attr.data.payload.device)
|
||||
|
||||
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
|
||||
for param in module.parameters():
|
||||
assert hasattr(param, 'col_attr')
|
||||
self.shard_strategy.gather([param.col_attr.data])
|
||||
param.data = param.col_attr.data.payload
|
||||
# Store local accumulated grad shard
|
||||
if param.grad is not None:
|
||||
if param.col_attr.bwd_count == 0:
|
||||
# We haven't stored local accumulated grad yet
|
||||
assert param.col_attr.grad is None
|
||||
param.col_attr.grad = param.grad.data
|
||||
param.grad = None
|
||||
else:
|
||||
# We have stored local accumulated grad
|
||||
# The grad here must be locally computed full grad in this backward pass
|
||||
assert param.grad.shape == param.col_attr.data.origin_shape
|
||||
param.col_attr.bwd_count += 1
|
||||
|
||||
def post_bwd_exec(self, module: torch.nn.Module, input):
|
||||
for param in module.parameters():
|
||||
assert hasattr(param, 'col_attr')
|
||||
self.shard_strategy.shard([param.col_attr.data])
|
||||
param.data = torch.empty([], dtype=param.col_attr.data.dtype, device=param.col_attr.data.payload.device)
|
||||
|
||||
def pre_iter(self):
|
||||
pass
|
||||
|
||||
def post_iter(self):
|
||||
pass
|
@ -1,6 +1,7 @@
|
||||
import functools
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
|
||||
import torch
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.zero.shard_utils import BaseShardStrategy
|
||||
from colossalai.zero.sharded_param import ShardedParamV2
|
||||
|
||||
@ -103,8 +104,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||
"""
|
||||
if not self.rm_torch_payload_on_the_fly:
|
||||
for param in self.initialized_param_list:
|
||||
assert hasattr(param, 'ca_attr')
|
||||
param.ca_attr.remove_torch_payload()
|
||||
assert hasattr(param, 'col_attr')
|
||||
param.col_attr.remove_torch_payload()
|
||||
|
||||
del self.initialized_param_list
|
||||
|
||||
@ -113,7 +114,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||
"""
|
||||
for param in module.parameters():
|
||||
# avoid adapting a param to ShardedParam twice
|
||||
if hasattr(param, 'ca_attr'):
|
||||
if hasattr(param, 'col_attr'):
|
||||
continue
|
||||
|
||||
if self.convert_cuda:
|
||||
@ -127,11 +128,11 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||
if param.grad is not None:
|
||||
param.grad = param.grad.to(torch.half).to(target_device)
|
||||
|
||||
param.ca_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)
|
||||
param.col_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)
|
||||
|
||||
self.initialized_param_list.append(param)
|
||||
|
||||
if self.shard_param:
|
||||
self.shard_strategy.shard(tensor_list=[param.ca_attr._data_sharded_tensor])
|
||||
if param.ca_attr.grad and self.shard_grad:
|
||||
self.shard_strategy.shard(tensor_list=[param.ca_attr._grad_sharded_tensor])
|
||||
self.shard_strategy.shard(tensor_list=[param.col_attr._data_sharded_tensor])
|
||||
if param.col_attr.grad and self.shard_grad:
|
||||
self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
|
||||
|
@ -1,11 +1,10 @@
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.zero.shard_utils import BaseShardStrategy
|
||||
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
|
||||
from colossalai.zero.sharded_model._zero3_utils import get_shard
|
||||
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
|
||||
|
||||
|
||||
class TensorShardStrategy(BaseShardStrategy):
|
||||
@ -38,7 +37,7 @@ class TensorShardStrategy(BaseShardStrategy):
|
||||
if i == self.local_rank:
|
||||
buffer_list.append(t.payload.cuda())
|
||||
else:
|
||||
buffer_list.append(torch.zeros(payload_numel).cuda())
|
||||
buffer_list.append(torch.zeros(payload_numel, dtype=t.dtype).cuda())
|
||||
|
||||
torch.distributed.all_gather(buffer_list,
|
||||
buffer_list[self.local_rank],
|
||||
|
@ -1,4 +1,3 @@
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Callable, Dict, List, Tuple, Union
|
||||
|
||||
@ -42,27 +41,21 @@ def free_storage(data: torch.Tensor) -> None:
|
||||
@torch.no_grad()
|
||||
def alloc_storage(data: torch.Tensor, size: torch.Size) -> None:
|
||||
"""Allocate storage for a tensor."""
|
||||
if data.storage().size() == size.numel(): # no need to reallocate
|
||||
if data.storage().size() == size.numel(): # no need to reallocate
|
||||
return
|
||||
assert data.storage().size() == 0
|
||||
data.storage().resize_(size.numel())
|
||||
|
||||
|
||||
def cast_trensor_to_fp16(tensor: torch.Tensor) -> torch.Tensor:
|
||||
if tensor.dtype is torch.float32:
|
||||
out = tensor.half()
|
||||
if tensor.is_leaf:
|
||||
out.requires_grad = tensor.requires_grad
|
||||
return out
|
||||
def cast_tensor_to_fp16(tensor: torch.Tensor) -> torch.Tensor:
|
||||
if torch.is_floating_point(tensor) and tensor.dtype is torch.float32:
|
||||
return tensor.half()
|
||||
return tensor
|
||||
|
||||
|
||||
def cast_trensor_to_fp32(tensor: torch.Tensor) -> torch.Tensor:
|
||||
if tensor.dtype is torch.float16:
|
||||
out = tensor.float()
|
||||
if tensor.is_leaf:
|
||||
out.requires_grad = tensor.requires_grad
|
||||
return out
|
||||
def cast_tensor_to_fp32(tensor: torch.Tensor) -> torch.Tensor:
|
||||
if torch.is_floating_point(tensor) and tensor.dtype is torch.float16:
|
||||
return tensor.float()
|
||||
return tensor
|
||||
|
||||
|
||||
@ -102,9 +95,8 @@ def assert_in_engine(cond: Any, s: Any) -> None:
|
||||
raise AssertionError
|
||||
|
||||
|
||||
def replace_state_dict_prefix(
|
||||
state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], old_prefix: str, new_prefix: str
|
||||
) -> None:
|
||||
def replace_state_dict_prefix(state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
|
||||
old_prefix: str, new_prefix: str) -> None:
|
||||
"""
|
||||
Replace all keys that match a given old_prefix with a new_prefix (in-place).
|
||||
|
||||
|
@ -5,8 +5,7 @@ import os
|
||||
import traceback
|
||||
from collections import OrderedDict
|
||||
from enum import Enum, auto
|
||||
from typing import (Any, Callable, Dict, Generator, List, NamedTuple, Optional,
|
||||
Set, Union)
|
||||
from typing import (Any, Callable, Dict, Generator, List, NamedTuple, Optional, Set, Union)
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -15,16 +14,14 @@ from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import get_current_device
|
||||
from .param_manager import Zero3ParameterManager
|
||||
from torch.autograd import Variable
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from ._zero3_utils import (apply_to_tensors, assert_in_engine,
|
||||
cast_float_arguments, cast_trensor_to_fp16,
|
||||
cast_trensor_to_fp32, chunk_and_pad, free_storage,
|
||||
get_gradient_predivide_factor, get_shard,
|
||||
from ._zero3_utils import (apply_to_tensors, assert_in_engine, cast_float_arguments, cast_tensor_to_fp16,
|
||||
cast_tensor_to_fp32, chunk_and_pad, free_storage, get_gradient_predivide_factor, get_shard,
|
||||
replace_state_dict_prefix)
|
||||
from .param_manager import Zero3ParameterManager
|
||||
from .reduce_scatter import ReduceScatterBucketer
|
||||
|
||||
# TODO: Remove the toggle-enable_nccl_base_collectives in the future
|
||||
@ -41,11 +38,13 @@ class TrainingState(Enum):
|
||||
POST_BACKWARD = auto()
|
||||
GATHER_FULL_PARAMS = auto()
|
||||
|
||||
|
||||
# TODO: Add clip_grad_norm_
|
||||
# TODO: Add gather_full_optim_state_dict and get_shard_from_optim_state_dict
|
||||
|
||||
|
||||
class ShardedModel(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
module: nn.Module,
|
||||
process_group: Optional[ProcessGroup] = None,
|
||||
@ -96,8 +95,10 @@ class ShardedModel(nn.Module):
|
||||
|
||||
# We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
|
||||
# So we use 1.0 as the default gradient_predivide_factor
|
||||
# However, if you set gradient_predivide_factor to None, we will set gradient_predivide_factor to a value >= 1.0 automatically
|
||||
self.gradient_predivide_factor: float = gradient_predivide_factor if gradient_predivide_factor is not None else \
|
||||
# However, if you set gradient_predivide_factor to None
|
||||
# we will set gradient_predivide_factor to a value >= 1.0 automatically
|
||||
self.gradient_predivide_factor: float = gradient_predivide_factor if \
|
||||
gradient_predivide_factor is not None else \
|
||||
get_gradient_predivide_factor(self.world_size)
|
||||
self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
|
||||
|
||||
@ -111,8 +112,12 @@ class ShardedModel(nn.Module):
|
||||
|
||||
self.module = module
|
||||
|
||||
self.param_manager = Zero3ParameterManager(module, process_group=self.process_group, mixed_precision=self.mixed_precision,
|
||||
flatten_parameters=flatten_parameters, compute_dtype=self.compute_dtype, compute_device=self.compute_device,
|
||||
self.param_manager = Zero3ParameterManager(module,
|
||||
process_group=self.process_group,
|
||||
mixed_precision=self.mixed_precision,
|
||||
flatten_parameters=flatten_parameters,
|
||||
compute_dtype=self.compute_dtype,
|
||||
compute_device=self.compute_device,
|
||||
offload_config=offload_config)
|
||||
|
||||
self._reset_lazy_init_info()
|
||||
@ -145,13 +150,13 @@ class ShardedModel(nn.Module):
|
||||
# For root and mixed precision, we convert the input to FP16 (no_grad is needed for
|
||||
# the conversion).
|
||||
if self._is_root and self.mixed_precision:
|
||||
args, kwargs = cast_float_arguments(cast_trensor_to_fp16, *args, **kwargs)
|
||||
args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
|
||||
|
||||
# If enabled, convert the input to FP32 if we are in full precision.
|
||||
# no_grad is not used because the input might be for a non-root instance,
|
||||
# which mean autograd needs to go through the conversion.
|
||||
if self.force_input_to_fp32 and not self.mixed_precision:
|
||||
args, kwargs = cast_float_arguments(cast_trensor_to_fp32, *args, **kwargs)
|
||||
args, kwargs = cast_float_arguments(cast_tensor_to_fp32, *args, **kwargs)
|
||||
|
||||
# All-gather full parameters. This will also transfer FP32 parameters to
|
||||
# ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``).
|
||||
@ -201,10 +206,9 @@ class ShardedModel(nn.Module):
|
||||
input_tensor = torch.ones(1).to(self.compute_device)
|
||||
output = list(torch.zeros(self.world_size).to(self.compute_device).chunk(self.world_size))
|
||||
dist.all_gather(output, input_tensor, group=self.process_group)
|
||||
assert torch.cat(output).sum() == float(self.world_size), (
|
||||
f"found {torch.cat(output).sum()} devices in process group but "
|
||||
f"world_size={self.world_size}. Check torch.cuda.set_device is called properly"
|
||||
)
|
||||
assert torch.cat(output).sum() == float(
|
||||
self.world_size), (f"found {torch.cat(output).sum()} devices in process group but "
|
||||
f"world_size={self.world_size}. Check torch.cuda.set_device is called properly")
|
||||
|
||||
def _reset_lazy_init_info(self) -> None:
|
||||
self._is_root: Optional[bool] = None
|
||||
@ -277,9 +281,10 @@ class ShardedModel(nn.Module):
|
||||
|
||||
# if child instance in its own (smaller) world, that was probably an attempt to avoid OOM.
|
||||
# Therefore gathering this child's optim state will probably cause OOM, so we won't do it.
|
||||
m.no_broadcast_optim_state = m.no_broadcast_optim_state or (
|
||||
(m.world_size == 1) and (m.world_size < self.world_size) and (m.process_group != self.process_group)
|
||||
)
|
||||
m.no_broadcast_optim_state = m.no_broadcast_optim_state or \
|
||||
((m.world_size == 1)
|
||||
and (m.world_size < self.world_size)
|
||||
and (m.process_group != self.process_group))
|
||||
|
||||
def _setup_streams(self) -> None:
|
||||
"""Create streams to overlap data transfer and computation."""
|
||||
@ -330,9 +335,10 @@ class ShardedModel(nn.Module):
|
||||
else:
|
||||
self._streams["all_gather"].wait_stream(torch.cuda.current_stream())
|
||||
|
||||
def _cast_buffers(
|
||||
self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, memo: Optional[Set] = None
|
||||
) -> None:
|
||||
def _cast_buffers(self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
memo: Optional[Set] = None) -> None:
|
||||
"""Move all buffers to the given *device* and *dtype*.
|
||||
|
||||
If *device* or *dtype* are not given, then they will default to
|
||||
@ -398,7 +404,7 @@ class ShardedModel(nn.Module):
|
||||
outputs: new outputs with hooks registered if they requires gradient.
|
||||
"""
|
||||
if not torch.is_grad_enabled():
|
||||
return outputs # don't register hooks if grad isn't enabled
|
||||
return outputs # don't register hooks if grad isn't enabled
|
||||
|
||||
if self._is_root:
|
||||
# This actually means that only root instance has
|
||||
@ -523,7 +529,7 @@ class ShardedModel(nn.Module):
|
||||
a new hook, which is needed for a new forward pass.
|
||||
"""
|
||||
if not torch.is_grad_enabled():
|
||||
return # don't register grad hooks if grad isn't enabled
|
||||
return # don't register grad hooks if grad isn't enabled
|
||||
for p in self.params:
|
||||
if p.requires_grad:
|
||||
if hasattr(p, "zero_shard_bwd_hook"):
|
||||
@ -612,7 +618,8 @@ class ShardedModel(nn.Module):
|
||||
if param.zero_is_sharded:
|
||||
assert self._reducer is not None
|
||||
# Save the unsharded grad for reduction. We will asynchronously accumulate the reduced gradient into
|
||||
# param.zero_saved_grad_shard. If this ShardedModel module was called multiple times it's possible that multiple
|
||||
# param.zero_saved_grad_shard. If this ShardedModel module was called multiple times
|
||||
# it's possible that multiple
|
||||
# gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
|
||||
# matter, neglecting rounding.
|
||||
# Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
|
||||
@ -628,9 +635,9 @@ class ShardedModel(nn.Module):
|
||||
# unsharded gradients allocated; one for a pending reduction, and one for gradient computation.
|
||||
callback_fn = functools.partial(self._reduce_scatter_callback, param)
|
||||
grad_chunks = chunk_and_pad(orig_grad_data, self.reduce_scatter_process_group.size())
|
||||
self._reducer.reduce_scatter_async(
|
||||
grad_chunks, group=self.reduce_scatter_process_group, callback_fn=callback_fn
|
||||
)
|
||||
self._reducer.reduce_scatter_async(grad_chunks,
|
||||
group=self.reduce_scatter_process_group,
|
||||
callback_fn=callback_fn)
|
||||
else:
|
||||
# Currently the only way for _is_sharded to be False is if
|
||||
# world_size == 1. This could be relaxed in the future, in which
|
||||
@ -667,8 +674,9 @@ class ShardedModel(nn.Module):
|
||||
param.zero_saved_grad_shard = reduced_grad.data
|
||||
else:
|
||||
assert (
|
||||
param.zero_saved_grad_shard.shape == reduced_grad.shape
|
||||
), f"{param.zero_saved_grad_shard.shape} vs {reduced_grad.shape}"
|
||||
param.zero_saved_grad_shard.shape == reduced_grad.shape), f"{param.zero_saved_grad_shard.shape} \
|
||||
vs {reduced_grad.shape}"
|
||||
|
||||
param.zero_saved_grad_shard.data += reduced_grad.data
|
||||
reduced_grad = param.zero_saved_grad_shard.data
|
||||
else:
|
||||
@ -717,7 +725,7 @@ class ShardedModel(nn.Module):
|
||||
# Flush any unreduced buckets in the post_backward stream.
|
||||
with torch.cuda.stream(self._streams["post_backward"]):
|
||||
assert_in_engine(self._reducer is not None, "FinalBackwardHook: reducer is None")
|
||||
assert self._reducer is not None # make mypy happy
|
||||
assert self._reducer is not None # make mypy happy
|
||||
self._reducer.flush()
|
||||
torch.cuda.current_stream().wait_stream(self._streams["post_backward"])
|
||||
if self._cpu_offload:
|
||||
@ -753,7 +761,8 @@ class ShardedModel(nn.Module):
|
||||
elif hasattr(p, "zero_saved_grad_shard"):
|
||||
assert_in_engine(
|
||||
p.device == p.zero_saved_grad_shard.device,
|
||||
f"FinalBackwardHook: incorrect saved_grad_shard device {p.device} vs {p.zero_saved_grad_shard.device}",
|
||||
f"FinalBackwardHook: incorrect saved_grad_shard device \
|
||||
{p.device} vs {p.zero_saved_grad_shard.device}",
|
||||
)
|
||||
p.grad = p.zero_saved_grad_shard
|
||||
elif hasattr(p, 'zero_saved_grad'):
|
||||
@ -765,7 +774,7 @@ class ShardedModel(nn.Module):
|
||||
delattr(p, "zero_saved_grad")
|
||||
|
||||
# Update root and nested ShardedModel's hooks and flags.
|
||||
for m in self.modules(): # includes self
|
||||
for m in self.modules(): # includes self
|
||||
if isinstance(m, ShardedModel):
|
||||
_finalize_parameters(m)
|
||||
m._pre_backward_hook_has_run = False
|
||||
@ -796,7 +805,7 @@ class ShardedModel(nn.Module):
|
||||
self._output_pre_backward_hook_registered is not None,
|
||||
"FinalBackwardHook: self._output_pre_backward_hook_registered should not be None",
|
||||
)
|
||||
assert self._output_pre_backward_hook_registered is not None # make mypy happy
|
||||
assert self._output_pre_backward_hook_registered is not None # make mypy happy
|
||||
self._output_pre_backward_hook_registered.clear()
|
||||
|
||||
@contextlib.contextmanager
|
||||
@ -908,9 +917,9 @@ class ShardedModel(nn.Module):
|
||||
state["is_sharded"] = [p.zero_is_sharded for p in self.params]
|
||||
state["orig_sizes"] = [p.zero_orig_size for p in self.params]
|
||||
if state["process_group"] is not None:
|
||||
state["process_group"] = "MISSING" # process_group isn't pickleable
|
||||
state["process_group"] = "MISSING" # process_group isn't pickleable
|
||||
if state["process_group_reduce_scatter"] is not None:
|
||||
state["process_group_reduce_scatter"] = "MISSING" # process_group_reduce_scatter isn't pickleable
|
||||
state["process_group_reduce_scatter"] = "MISSING" # process_group_reduce_scatter isn't pickleable
|
||||
self._reset_lazy_init_info()
|
||||
return state
|
||||
|
||||
@ -920,7 +929,7 @@ class ShardedModel(nn.Module):
|
||||
|
||||
def fixup(p: Parameter, is_sharded: bool, size: torch.Size) -> Parameter:
|
||||
assert isinstance(p, Parameter)
|
||||
p.data = p.data.clone() # move tensors out of shared memory
|
||||
p.data = p.data.clone() # move tensors out of shared memory
|
||||
p.zero_is_sharded = is_sharded
|
||||
p.zero_orig_size = size
|
||||
return p
|
||||
@ -958,7 +967,7 @@ class ShardedModel(nn.Module):
|
||||
# This instance may wrap other ShardedModel instances and we
|
||||
# need to set all of them to accumulate gradients.
|
||||
old_flags = []
|
||||
for m in self.modules(): # includes self
|
||||
for m in self.modules(): # includes self
|
||||
if isinstance(m, ShardedModel):
|
||||
old_flags.append((m, m._require_backward_grad_sync))
|
||||
m._require_backward_grad_sync = False
|
||||
@ -986,22 +995,18 @@ class ShardedModel(nn.Module):
|
||||
raise ValueError(msg)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
repr = (
|
||||
f"world_size={self.world_size}, "
|
||||
f"mixed_precision={self.mixed_precision}, "
|
||||
)
|
||||
repr = (f"world_size={self.world_size}, "
|
||||
f"mixed_precision={self.mixed_precision}, ")
|
||||
if self.verbose:
|
||||
repr = (
|
||||
f"rank={self.rank}, " + repr + f"reshard_after_forward={self.reshard_after_forward}, "
|
||||
f"compute_dtype={self.compute_dtype}, "
|
||||
f"buffer_dtype={self.buffer_dtype}, "
|
||||
f"fp32_reduce_scatter={self.fp32_reduce_scatter}, "
|
||||
f"compute_device={self.compute_device}"
|
||||
f"reduce_scatter_bucket_size_mb={self.reduce_scatter_bucket_size_mb}, "
|
||||
f"clear_autocast_cache={self.clear_autocast_cache}"
|
||||
f"force_input_to_fp32={self.force_input_to_fp32}"
|
||||
f"offload_config={self.offload_config}"
|
||||
)
|
||||
repr = (f"rank={self.rank}, " + repr + f"reshard_after_forward={self.reshard_after_forward}, "
|
||||
f"compute_dtype={self.compute_dtype}, "
|
||||
f"buffer_dtype={self.buffer_dtype}, "
|
||||
f"fp32_reduce_scatter={self.fp32_reduce_scatter}, "
|
||||
f"compute_device={self.compute_device}"
|
||||
f"reduce_scatter_bucket_size_mb={self.reduce_scatter_bucket_size_mb}, "
|
||||
f"clear_autocast_cache={self.clear_autocast_cache}"
|
||||
f"force_input_to_fp32={self.force_input_to_fp32}"
|
||||
f"offload_config={self.offload_config}")
|
||||
return repr
|
||||
|
||||
def state_dict(self, destination=None, prefix='', keep_vars=False):
|
||||
@ -1039,9 +1044,9 @@ class ShardedModel(nn.Module):
|
||||
maybe_cast_buffers()
|
||||
return state_dict
|
||||
|
||||
def load_state_dict(
|
||||
self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
|
||||
) -> NamedTuple:
|
||||
def load_state_dict(self,
|
||||
state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
|
||||
strict: bool = True) -> NamedTuple:
|
||||
"""
|
||||
Load a whole (unsharded) state_dict.
|
||||
|
||||
@ -1094,7 +1099,6 @@ def _post_state_dict_hook(
|
||||
return state_dict
|
||||
|
||||
|
||||
def _pre_load_state_dict_hook(
|
||||
state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], prefix: str, *args: Any
|
||||
) -> None:
|
||||
def _pre_load_state_dict_hook(state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], prefix: str,
|
||||
*args: Any) -> None:
|
||||
replace_state_dict_prefix(state_dict, prefix, prefix + "_zero3_module.")
|
||||
|
@ -1,4 +1,5 @@
|
||||
import functools
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
@ -6,32 +7,32 @@ import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook, register_ophooks_recursively)
|
||||
from colossalai.engine.ophooks import register_ophooks_recursively
|
||||
from colossalai.engine.ophooks.zero_hook import ZeroHook
|
||||
from colossalai.engine.paramhooks import BaseParamHookMgr
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.zero.shard_utils import BaseShardStrategy
|
||||
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
|
||||
from colossalai.zero.sharded_model.sharded_grad import ShardedGradient
|
||||
from colossalai.zero.sharded_param import ShardedParam
|
||||
from colossalai.zero.sharded_param import ShardedParamV2
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from ._zero3_utils import chunk_and_pad, get_gradient_predivide_factor
|
||||
from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad,
|
||||
get_gradient_predivide_factor)
|
||||
|
||||
|
||||
class ShardedModelV2(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
module: nn.Module,
|
||||
process_group: Optional[ProcessGroup] = None,
|
||||
reduce_scatter_process_group: Optional[ProcessGroup] = None,
|
||||
reduce_scatter_bucket_size_mb: int = 25,
|
||||
reshard_after_forward: bool = True,
|
||||
mixed_precision: bool = False,
|
||||
fp32_reduce_scatter: bool = False,
|
||||
offload_config: Optional[dict] = None,
|
||||
gradient_predivide_factor: Optional[float] = 1.0,
|
||||
):
|
||||
def __init__(self,
|
||||
module: nn.Module,
|
||||
shard_strategy: BaseShardStrategy,
|
||||
process_group: Optional[ProcessGroup] = None,
|
||||
reduce_scatter_process_group: Optional[ProcessGroup] = None,
|
||||
reduce_scatter_bucket_size_mb: int = 25,
|
||||
fp32_reduce_scatter: bool = False,
|
||||
offload_config: Optional[dict] = None,
|
||||
gradient_predivide_factor: Optional[float] = 1.0,
|
||||
shard_param: bool = True):
|
||||
r"""
|
||||
A demo to reconfigure zero1 shared_model.
|
||||
Currently do not consider the Optimizer States.
|
||||
@ -44,22 +45,24 @@ class ShardedModelV2(nn.Module):
|
||||
self.world_size = dist.get_world_size(self.process_group)
|
||||
self.rank = dist.get_rank(self.process_group)
|
||||
|
||||
# The module has to be placed on GPU
|
||||
self.module = module.cuda()
|
||||
# Cast module to fp16 and cuda, in case user didn't use ZeroInitContext
|
||||
self.module = module.half().cuda()
|
||||
|
||||
# Shard the parameters at first
|
||||
for _, param in self.module.named_parameters():
|
||||
param.ca_attr = ShardedParam(param)
|
||||
param.ca_attr.shard()
|
||||
param._sharded_grad = ShardedGradient(param, self, offload_config)
|
||||
self.shard_strategy = shard_strategy
|
||||
self.shard_param = shard_param
|
||||
|
||||
# In case user didn't use ZeroInitContext
|
||||
for param in self.module.parameters():
|
||||
if not hasattr(param, 'col_attr'):
|
||||
param.col_attr = ShardedParamV2(param, process_group)
|
||||
if self.shard_param:
|
||||
self.shard_strategy.shard([param.col_attr.data])
|
||||
|
||||
# Register hooks
|
||||
register_ophooks_recursively(self.module, [ShardParamHook(), ShardGradHook()])
|
||||
register_ophooks_recursively(self.module, [ZeroHook(self.shard_strategy)])
|
||||
self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
|
||||
self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
|
||||
|
||||
self.reshard_after_forward = reshard_after_forward
|
||||
self.mixed_precision = mixed_precision
|
||||
self.fp32_reduce_scatter = fp32_reduce_scatter
|
||||
self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
|
||||
# We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
|
||||
@ -76,6 +79,7 @@ class ShardedModelV2(nn.Module):
|
||||
self._require_backward_grad_sync: bool = True
|
||||
|
||||
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
|
||||
outputs = self.module(*args, **kwargs)
|
||||
return outputs
|
||||
|
||||
@ -99,6 +103,7 @@ class ShardedModelV2(nn.Module):
|
||||
torch.cuda.current_stream().synchronize()
|
||||
self.reducer.free()
|
||||
for p in self.module.parameters():
|
||||
p.col_attr.bwd_count = 0
|
||||
if not p.requires_grad:
|
||||
continue
|
||||
# Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
|
||||
@ -107,11 +112,14 @@ class ShardedModelV2(nn.Module):
|
||||
# sync passes, if desired.
|
||||
if not self._require_backward_grad_sync:
|
||||
continue
|
||||
p._sharded_grad.write_back()
|
||||
# Write grad back to p.grad and set p.col_attr.grad to None
|
||||
p.grad.data = p.col_attr.grad
|
||||
p.col_attr.grad = None
|
||||
# In case some post bwd hook is not fired
|
||||
for p in self.module.parameters():
|
||||
if not p.ca_attr.is_sharded:
|
||||
p.ca_attr.shard()
|
||||
if self.shard_param:
|
||||
for p in self.module.parameters():
|
||||
if not p.col_attr.param_is_sharded:
|
||||
self.shard_strategy.shard([p.col_attr.data])
|
||||
|
||||
@torch.no_grad()
|
||||
def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
@ -119,7 +127,7 @@ class ShardedModelV2(nn.Module):
|
||||
At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the
|
||||
full gradient for the local batch. The reduce-scatter op will save
|
||||
a single shard of the summed gradient across all
|
||||
GPUs to param._sharded_grad. This shard will align with the current GPU rank. For example::
|
||||
GPUs to param.col_attr.grad. This shard will align with the current GPU rank. For example::
|
||||
|
||||
before reduce_scatter:
|
||||
param.grad (GPU #0): [1, 2, 3, 4]
|
||||
@ -131,7 +139,7 @@ class ShardedModelV2(nn.Module):
|
||||
|
||||
The local GPU's ``optim.step`` is responsible for updating a single
|
||||
shard of params, also corresponding to the current GPU's rank. This
|
||||
alignment is created by `param._sharded_grad`, which ensures that
|
||||
alignment is created by `param.col_attr.grad`, which ensures that
|
||||
the local optimizer only sees the relevant parameter shard.
|
||||
"""
|
||||
if grad is None:
|
||||
@ -142,7 +150,7 @@ class ShardedModelV2(nn.Module):
|
||||
self.comm_stream.wait_stream(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(self.comm_stream):
|
||||
new_grad = grad.clone()
|
||||
if self.mixed_precision and self.fp32_reduce_scatter:
|
||||
if self.fp32_reduce_scatter:
|
||||
new_grad.data = new_grad.data.to(param.dtype)
|
||||
if self.gradient_predivide_factor > 1.0:
|
||||
# Average grad by world_size for consistency with PyTorch DDP.
|
||||
@ -161,13 +169,30 @@ class ShardedModelV2(nn.Module):
|
||||
if self.gradient_postdivide_factor > 1:
|
||||
# Average grad by world_size for consistency with PyTorch DDP.
|
||||
reduced_grad.data.div_(self.gradient_postdivide_factor)
|
||||
# Cast grad to param's dtype (typically FP32). Note: we do this
|
||||
# before the cpu offload step so that this entire hook remains
|
||||
# non-blocking. The downside is a bit more D2H transfer in that case.
|
||||
if self.mixed_precision:
|
||||
orig_param_grad_data = reduced_grad.data
|
||||
reduced_grad.data = reduced_grad.data.to(dtype=param.ca_attr.origin_dtype)
|
||||
# Don't let this memory get reused until after the transfer.
|
||||
orig_param_grad_data.record_stream(torch.cuda.current_stream())
|
||||
|
||||
param._sharded_grad.reduce_scatter_callback(reduced_grad)
|
||||
# Make sure we store fp32 grad
|
||||
reduced_grad.data = cast_tensor_to_fp32(reduced_grad.data)
|
||||
|
||||
# Maybe offload
|
||||
if self._cpu_offload:
|
||||
reduced_grad.data = reduced_grad.data.cpu()
|
||||
|
||||
if param.col_attr.grad is None:
|
||||
param.col_attr.grad = reduced_grad.data
|
||||
else:
|
||||
param.col_attr.grad.add_(reduced_grad.data)
|
||||
|
||||
def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
|
||||
self.shard_strategy.gather([p.col_attr.data for p in self.module.parameters()])
|
||||
prev_params = {}
|
||||
for p in self.module.parameters():
|
||||
prev_params[p] = p.data
|
||||
p.data = p.col_attr.data.payload
|
||||
gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
|
||||
self.shard_strategy.shard([p.col_attr.data for p in self.module.parameters()])
|
||||
for p in self.module.parameters():
|
||||
p.data = prev_params[p]
|
||||
return gathered_state_dict
|
||||
|
||||
def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
|
||||
raise NotImplementedError
|
||||
|
@ -1,3 +1,5 @@
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -5,7 +7,6 @@ from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.zero.sharded_model._zero3_utils import get_shard
|
||||
from colossalai.zero.sharded_param import ShardedTensor
|
||||
from typing import Union, Tuple, Optional
|
||||
|
||||
|
||||
class ShardedParamV2(object):
|
||||
@ -14,12 +15,8 @@ class ShardedParamV2(object):
|
||||
param: torch.nn.Parameter,
|
||||
process_group: Optional[dist.ProcessGroup] = None,
|
||||
rm_torch_payload=False) -> None:
|
||||
self._data_sharded_tensor = ShardedTensor(param.data, process_group)
|
||||
if param.requires_grad and param.grad is not None:
|
||||
self._grad_sharded_tensor = ShardedTensor(param.grad, process_group)
|
||||
param.grad = None
|
||||
else:
|
||||
self._grad_sharded_tensor = None
|
||||
self._data_sharded_tensor: ShardedTensor = ShardedTensor(param.data, process_group)
|
||||
self._grad_sharded_tensor: Optional[torch.Tensor] = None
|
||||
|
||||
# make sure the shared param is the only owner of payload
|
||||
# The param.data maybe used to init the other part of the model.
|
||||
@ -30,27 +27,29 @@ class ShardedParamV2(object):
|
||||
if rm_torch_payload:
|
||||
self.remove_torch_payload()
|
||||
|
||||
# Backward count for handle local grad accumulation
|
||||
# This value will increment by 1 in every pre-bwd hook
|
||||
# And will be reset to 0 in every final-bwd hook
|
||||
self.bwd_count = 0
|
||||
|
||||
def remove_torch_payload(self):
|
||||
self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return self._data_sharded_tensor.payload
|
||||
|
||||
@data.setter
|
||||
def data(self, t: torch.Tensor):
|
||||
self._data_sharded_tensor.payload = t
|
||||
return self._data_sharded_tensor
|
||||
|
||||
@property
|
||||
def grad(self):
|
||||
if self._grad_sharded_tensor:
|
||||
return self._grad_sharded_tensor.payload
|
||||
else:
|
||||
return None
|
||||
return self._grad_sharded_tensor
|
||||
|
||||
@grad.setter
|
||||
def grad(self, t: torch.Tensor):
|
||||
self._grad_sharded_tensor.payload = t
|
||||
self._grad_sharded_tensor = t
|
||||
|
||||
@property
|
||||
def param_is_sharded(self):
|
||||
return self._data_sharded_tensor.is_sharded
|
||||
|
||||
|
||||
class ShardedParam(object):
|
||||
|
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
@ -45,16 +45,16 @@ class Net(nn.Module):
|
||||
|
||||
def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
|
||||
if loose:
|
||||
return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
|
||||
return torch.allclose(tensor_a, tensor_b, atol=1e-2, rtol=1e-3)
|
||||
return torch.allclose(tensor_a, tensor_b)
|
||||
|
||||
|
||||
def check_grads(model, zero_model, loose=False):
|
||||
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
|
||||
zero_grad = zero_p.grad.clone().to(p.device)
|
||||
assert p.grad.dtype == zero_grad.dtype
|
||||
assert allclose(p.grad, zero_grad, loose=loose)
|
||||
LOGGER.info(torch.sum(p.grad - zero_grad))
|
||||
grad = p.grad.float()
|
||||
assert grad.dtype == zero_grad.dtype
|
||||
assert allclose(grad, zero_grad, loose=loose)
|
||||
|
||||
|
||||
def check_params(model, zero_model, loose=False):
|
||||
@ -71,11 +71,11 @@ def check_grads_padding(model, zero_model, loose=False):
|
||||
chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
|
||||
if rank >= len(chunks):
|
||||
continue
|
||||
grad = chunks[rank]
|
||||
grad = chunks[rank].float()
|
||||
if zero_grad.size(0) > grad.size(0):
|
||||
zero_grad = zero_grad[:grad.size(0)]
|
||||
assert grad.dtype == zero_grad.dtype
|
||||
assert allclose(grad, zero_grad, loose=loose)
|
||||
assert allclose(grad, zero_grad, loose=loose), f'{grad} vs {zero_grad}'
|
||||
|
||||
|
||||
def check_params_padding(model, zero_model, loose=False):
|
||||
|
@ -7,12 +7,14 @@ import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
|
||||
from colossalai.zero.init_ctx import ZeroInitContext
|
||||
from common import CONFIG
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.zero.init_ctx import ZeroInitContext
|
||||
from colossalai.zero.shard_utils.tensor_shard_strategy import \
|
||||
TensorShardStrategy
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
||||
from common import CONFIG, Net
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
@ -25,11 +27,11 @@ def run_dist(rank, world_size, port):
|
||||
shard_param=True):
|
||||
model = model_builder(checkpoint=True)
|
||||
|
||||
for param in model.parameters():
|
||||
assert hasattr(param, 'ca_attr')
|
||||
assert param.ca_attr.data.dtype == torch.half
|
||||
assert param.ca_attr._data_sharded_tensor.is_sharded
|
||||
assert param.ca_attr.data.device.type == 'cuda'
|
||||
for param in model.parameters():
|
||||
assert hasattr(param, 'col_attr')
|
||||
assert param.col_attr.data.dtype == torch.half
|
||||
assert param.col_attr.data.is_sharded
|
||||
assert param.col_attr.data.payload.device.type == 'cuda'
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
|
@ -9,19 +9,21 @@ import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.zero.shard_utils.tensor_shard_strategy import \
|
||||
TensorShardStrategy
|
||||
from colossalai.zero.sharded_model import ShardedModelV2
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from common import CONFIG, Net, check_grads, check_grads_padding
|
||||
from common import CONFIG, check_grads, check_grads_padding
|
||||
|
||||
|
||||
def run_fwd_bwd(model, x, enable_autocast=False):
|
||||
def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
|
||||
model.train()
|
||||
with torch.cuda.amp.autocast(enabled=enable_autocast):
|
||||
y = model(x)
|
||||
loss = y.sum()
|
||||
y = model(data)
|
||||
loss = criterion(y, label)
|
||||
loss = loss.float()
|
||||
if isinstance(model, ShardedModelV2):
|
||||
model.backward(loss)
|
||||
@ -31,19 +33,26 @@ def run_fwd_bwd(model, x, enable_autocast=False):
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
model = Net(checkpoint=True).cuda()
|
||||
zero_model = copy.deepcopy(model)
|
||||
zero_model = ShardedModelV2(zero_model, process_group=gpc.get_group(ParallelMode.DATA))
|
||||
|
||||
for _ in range(2):
|
||||
x = torch.rand(2, 5).cuda()
|
||||
run_fwd_bwd(zero_model, x, False)
|
||||
run_fwd_bwd(model, x, False)
|
||||
test_models = ['repeated_computed_layers', 'resnet18']
|
||||
for model_name in test_models:
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
shard_strategy = TensorShardStrategy()
|
||||
model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
|
||||
model = model().half().cuda()
|
||||
zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy)
|
||||
if dist.get_world_size() > 1:
|
||||
check_grads_padding(model, zero_model)
|
||||
else:
|
||||
check_grads(model, zero_model)
|
||||
model = DDP(model)
|
||||
|
||||
for i, (data, label) in enumerate(train_dataloader):
|
||||
if i > 2:
|
||||
break
|
||||
data, label = data.half().cuda(), label.cuda()
|
||||
run_fwd_bwd(model, data, label, criterion, False)
|
||||
run_fwd_bwd(zero_model, data, label, criterion, False)
|
||||
if dist.get_world_size() > 1:
|
||||
check_grads_padding(model, zero_model, loose=True)
|
||||
else:
|
||||
check_grads(model, zero_model, loose=True)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
|
@ -4,18 +4,16 @@
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
import colossalai
|
||||
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
|
||||
from colossalai.zero.shard_utils import TensorShardStrategy
|
||||
from colossalai.zero.sharded_param import ShardedTensor, ShardedParam
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.logging import get_dist_logger, disable_existing_loggers
|
||||
|
||||
from tests.test_zero_data_parallel.common import Net, CONFIG, allclose
|
||||
from colossalai.zero.shard_utils import TensorShardStrategy
|
||||
from colossalai.zero.sharded_param import ShardedParam, ShardedTensor
|
||||
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
|
||||
from tests.test_zero_data_parallel.common import CONFIG, Net, allclose
|
||||
|
||||
|
||||
def _run_shard_tensor(rank, world_size, port):
|
||||
@ -47,7 +45,7 @@ def _run_shard_param_v2(rank, world_size, port):
|
||||
param_ref = deepcopy(param)
|
||||
sparam = ShardedParamV2(param=param, process_group=None)
|
||||
|
||||
allclose(sparam.data, param_ref.data)
|
||||
allclose(sparam.data.payload, param_ref.data)
|
||||
|
||||
sparam.remove_torch_payload()
|
||||
assert (param.data.numel() == 1)
|
||||
|
73
tests/test_zero_data_parallel/test_sharded_model_with_ctx.py
Normal file
73
tests/test_zero_data_parallel/test_sharded_model_with_ctx.py
Normal file
@ -0,0 +1,73 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import copy
|
||||
from functools import partial
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.zero.init_ctx import ZeroInitContext
|
||||
from colossalai.zero.shard_utils.tensor_shard_strategy import \
|
||||
TensorShardStrategy
|
||||
from colossalai.zero.sharded_model import ShardedModelV2
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from common import CONFIG, check_grads, check_grads_padding
|
||||
|
||||
|
||||
def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
|
||||
model.train()
|
||||
with torch.cuda.amp.autocast(enabled=enable_autocast):
|
||||
y = model(data)
|
||||
loss = criterion(y, label)
|
||||
loss = loss.float()
|
||||
if isinstance(model, ShardedModelV2):
|
||||
model.backward(loss)
|
||||
else:
|
||||
loss.backward()
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
test_models = ['repeated_computed_layers', 'resnet18']
|
||||
for model_name in test_models:
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
shard_strategy = TensorShardStrategy()
|
||||
with ZeroInitContext(convert_fp16=True, convert_cuda=True, shard_strategy=shard_strategy, shard_param=True):
|
||||
zero_model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
|
||||
zero_model = zero_model()
|
||||
model = copy.deepcopy(zero_model)
|
||||
zero_model = ShardedModelV2(zero_model, shard_strategy)
|
||||
model_state_dict = zero_model.state_dict()
|
||||
for n, p in model.named_parameters():
|
||||
p.data = model_state_dict[n]
|
||||
model = model.half().cuda()
|
||||
if dist.get_world_size() > 1:
|
||||
model = DDP(model)
|
||||
|
||||
for i, (data, label) in enumerate(train_dataloader):
|
||||
if i > 2:
|
||||
break
|
||||
data, label = data.half().cuda(), label.cuda()
|
||||
run_fwd_bwd(model, data, label, criterion, False)
|
||||
run_fwd_bwd(zero_model, data, label, criterion, False)
|
||||
if dist.get_world_size() > 1:
|
||||
check_grads_padding(model, zero_model, loose=True)
|
||||
else:
|
||||
check_grads(model, zero_model, loose=True)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
def test_shard_model_v2():
|
||||
world_size = 2
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_shard_model_v2()
|
@ -56,7 +56,7 @@ def run_dist(rank, world_size, port):
|
||||
check_params(model, zero_model)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.skip
|
||||
def test_sharded_optim_v2():
|
||||
world_size = 2
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
|
43
tests/test_zero_data_parallel/test_state_dict.py
Normal file
43
tests/test_zero_data_parallel/test_state_dict.py
Normal file
@ -0,0 +1,43 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.zero.shard_utils.tensor_shard_strategy import \
|
||||
TensorShardStrategy
|
||||
from colossalai.zero.sharded_model import ShardedModelV2
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
||||
from common import CONFIG
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
test_models = ['repeated_computed_layers', 'resnet18']
|
||||
for model_name in test_models:
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
|
||||
model = model()
|
||||
shard_strategy = TensorShardStrategy()
|
||||
model = model.half().cuda()
|
||||
zero_model = ShardedModelV2(deepcopy(model), shard_strategy)
|
||||
zero_state_dict = zero_model.state_dict()
|
||||
for key, val in model.state_dict().items():
|
||||
assert torch.equal(val, zero_state_dict[key])
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
def test_zero_state_dict():
|
||||
world_size = 2
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_zero_state_dict()
|
Loading…
Reference in New Issue
Block a user