[zero] Update sharded model v2 using sharded param v2 (#323)

This commit is contained in:
ver217 2022-03-08 18:18:06 +08:00 committed by Frank Lee
parent 799d105bb4
commit 1388671699
16 changed files with 403 additions and 202 deletions

View File

@ -15,8 +15,7 @@ def _apply_to_tensors_only(module, functional, backward_function, outputs):
if type(outputs) is tuple:
touched_outputs = []
for output in outputs:
touched_output = _apply_to_tensors_only(module, functional,
backward_function, output)
touched_output = _apply_to_tensors_only(module, functional, backward_function, output)
touched_outputs.append(touched_output)
return tuple(touched_outputs)
elif type(outputs) is torch.Tensor:
@ -26,6 +25,7 @@ def _apply_to_tensors_only(module, functional, backward_function, outputs):
class PreBackwardFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, module, pre_backward_function, outputs):
ctx.module = module
@ -41,6 +41,7 @@ class PreBackwardFunction(torch.autograd.Function):
class PostBackwardFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, module, pre_backward_function, output):
ctx.module = module
@ -60,9 +61,7 @@ class PostBackwardFunction(torch.autograd.Function):
return (None, None) + args
def register_ophooks_recursively(module: torch.nn.Module,
ophook_list: List[BaseOpHook] = None,
name: str = ""):
def register_ophooks_recursively(module: torch.nn.Module, ophook_list: List[BaseOpHook] = None, name: str = ""):
r"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD."""
assert isinstance(module, torch.nn.Module)
has_children = False
@ -72,8 +71,7 @@ def register_ophooks_recursively(module: torch.nn.Module,
# Early return on modules with no parameters or buffers that
# are not in their children.
if (len(list(module.named_parameters(recurse=False))) == 0
and len(list(module.named_buffers(recurse=False))) == 0):
if (len(list(module.named_parameters(recurse=False))) == 0 and len(list(module.named_buffers(recurse=False))) == 0):
return
# return if the module has not childern.
@ -95,22 +93,22 @@ def register_ophooks_recursively(module: torch.nn.Module,
hook.post_fwd_exec(submodule, *args)
def _pre_backward_module_hook(submodule, inputs, output):
def _run_before_backward_function(submodule):
for hook in ophook_list:
assert isinstance(submodule, torch.nn.Module)
hook.pre_bwd_exec(submodule, inputs, output)
return _apply_to_tensors_only(submodule, PreBackwardFunction,
_run_before_backward_function, output)
return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output)
def _post_backward_module_hook(submodule, inputs):
def _run_after_backward_function(submodule):
for hook in ophook_list:
assert isinstance(submodule, torch.nn.Module)
hook.post_bwd_exec(submodule, inputs)
return _apply_to_tensors_only(submodule, PostBackwardFunction,
_run_after_backward_function, inputs)
return _apply_to_tensors_only(submodule, PostBackwardFunction, _run_after_backward_function, inputs)
module.register_forward_pre_hook(_pre_forward_module_hook)
module.register_forward_hook(_post_forward_module_hook)

View File

@ -0,0 +1,58 @@
import torch
from colossalai.registry import OPHOOKS
from colossalai.zero.shard_utils import BaseShardStrategy
from ._base_ophook import BaseOpHook
@OPHOOKS.register_module
class ZeroHook(BaseOpHook):
"""
A hook to process sharded param for ZeRO method.
"""
def __init__(self, shard_strategy: BaseShardStrategy):
super().__init__()
self.shard_strategy = shard_strategy
def pre_fwd_exec(self, module: torch.nn.Module, *args):
for param in module.parameters():
assert hasattr(param, 'col_attr')
self.shard_strategy.gather([param.col_attr.data])
param.data = param.col_attr.data.payload
def post_fwd_exec(self, module: torch.nn.Module, *args):
for param in module.parameters():
assert hasattr(param, 'col_attr')
self.shard_strategy.shard([param.col_attr.data])
param.data = torch.empty([], dtype=param.col_attr.data.dtype, device=param.col_attr.data.payload.device)
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
for param in module.parameters():
assert hasattr(param, 'col_attr')
self.shard_strategy.gather([param.col_attr.data])
param.data = param.col_attr.data.payload
# Store local accumulated grad shard
if param.grad is not None:
if param.col_attr.bwd_count == 0:
# We haven't stored local accumulated grad yet
assert param.col_attr.grad is None
param.col_attr.grad = param.grad.data
param.grad = None
else:
# We have stored local accumulated grad
# The grad here must be locally computed full grad in this backward pass
assert param.grad.shape == param.col_attr.data.origin_shape
param.col_attr.bwd_count += 1
def post_bwd_exec(self, module: torch.nn.Module, input):
for param in module.parameters():
assert hasattr(param, 'col_attr')
self.shard_strategy.shard([param.col_attr.data])
param.data = torch.empty([], dtype=param.col_attr.data.dtype, device=param.col_attr.data.payload.device)
def pre_iter(self):
pass
def post_iter(self):
pass

View File

@ -1,6 +1,7 @@
import functools
from colossalai.utils.cuda import get_current_device
import torch
from colossalai.utils.cuda import get_current_device
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_param import ShardedParamV2
@ -103,8 +104,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
"""
if not self.rm_torch_payload_on_the_fly:
for param in self.initialized_param_list:
assert hasattr(param, 'ca_attr')
param.ca_attr.remove_torch_payload()
assert hasattr(param, 'col_attr')
param.col_attr.remove_torch_payload()
del self.initialized_param_list
@ -113,7 +114,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
"""
for param in module.parameters():
# avoid adapting a param to ShardedParam twice
if hasattr(param, 'ca_attr'):
if hasattr(param, 'col_attr'):
continue
if self.convert_cuda:
@ -127,11 +128,11 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
if param.grad is not None:
param.grad = param.grad.to(torch.half).to(target_device)
param.ca_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)
param.col_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)
self.initialized_param_list.append(param)
if self.shard_param:
self.shard_strategy.shard(tensor_list=[param.ca_attr._data_sharded_tensor])
if param.ca_attr.grad and self.shard_grad:
self.shard_strategy.shard(tensor_list=[param.ca_attr._grad_sharded_tensor])
self.shard_strategy.shard(tensor_list=[param.col_attr._data_sharded_tensor])
if param.col_attr.grad and self.shard_grad:
self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])

View File

@ -1,11 +1,10 @@
import torch
import torch.distributed as dist
from typing import List, Optional
import torch
import torch.distributed as dist
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.zero.sharded_model._zero3_utils import get_shard
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
class TensorShardStrategy(BaseShardStrategy):
@ -38,7 +37,7 @@ class TensorShardStrategy(BaseShardStrategy):
if i == self.local_rank:
buffer_list.append(t.payload.cuda())
else:
buffer_list.append(torch.zeros(payload_numel).cuda())
buffer_list.append(torch.zeros(payload_numel, dtype=t.dtype).cuda())
torch.distributed.all_gather(buffer_list,
buffer_list[self.local_rank],

View File

@ -1,4 +1,3 @@
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Tuple, Union
@ -42,27 +41,21 @@ def free_storage(data: torch.Tensor) -> None:
@torch.no_grad()
def alloc_storage(data: torch.Tensor, size: torch.Size) -> None:
"""Allocate storage for a tensor."""
if data.storage().size() == size.numel(): # no need to reallocate
if data.storage().size() == size.numel(): # no need to reallocate
return
assert data.storage().size() == 0
data.storage().resize_(size.numel())
def cast_trensor_to_fp16(tensor: torch.Tensor) -> torch.Tensor:
if tensor.dtype is torch.float32:
out = tensor.half()
if tensor.is_leaf:
out.requires_grad = tensor.requires_grad
return out
def cast_tensor_to_fp16(tensor: torch.Tensor) -> torch.Tensor:
if torch.is_floating_point(tensor) and tensor.dtype is torch.float32:
return tensor.half()
return tensor
def cast_trensor_to_fp32(tensor: torch.Tensor) -> torch.Tensor:
if tensor.dtype is torch.float16:
out = tensor.float()
if tensor.is_leaf:
out.requires_grad = tensor.requires_grad
return out
def cast_tensor_to_fp32(tensor: torch.Tensor) -> torch.Tensor:
if torch.is_floating_point(tensor) and tensor.dtype is torch.float16:
return tensor.float()
return tensor
@ -102,9 +95,8 @@ def assert_in_engine(cond: Any, s: Any) -> None:
raise AssertionError
def replace_state_dict_prefix(
state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], old_prefix: str, new_prefix: str
) -> None:
def replace_state_dict_prefix(state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
old_prefix: str, new_prefix: str) -> None:
"""
Replace all keys that match a given old_prefix with a new_prefix (in-place).

View File

@ -5,8 +5,7 @@ import os
import traceback
from collections import OrderedDict
from enum import Enum, auto
from typing import (Any, Callable, Dict, Generator, List, NamedTuple, Optional,
Set, Union)
from typing import (Any, Callable, Dict, Generator, List, NamedTuple, Optional, Set, Union)
import torch
import torch.distributed as dist
@ -15,16 +14,14 @@ from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
from .param_manager import Zero3ParameterManager
from torch.autograd import Variable
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from ._zero3_utils import (apply_to_tensors, assert_in_engine,
cast_float_arguments, cast_trensor_to_fp16,
cast_trensor_to_fp32, chunk_and_pad, free_storage,
get_gradient_predivide_factor, get_shard,
from ._zero3_utils import (apply_to_tensors, assert_in_engine, cast_float_arguments, cast_tensor_to_fp16,
cast_tensor_to_fp32, chunk_and_pad, free_storage, get_gradient_predivide_factor, get_shard,
replace_state_dict_prefix)
from .param_manager import Zero3ParameterManager
from .reduce_scatter import ReduceScatterBucketer
# TODO: Remove the toggle-enable_nccl_base_collectives in the future
@ -41,11 +38,13 @@ class TrainingState(Enum):
POST_BACKWARD = auto()
GATHER_FULL_PARAMS = auto()
# TODO: Add clip_grad_norm_
# TODO: Add gather_full_optim_state_dict and get_shard_from_optim_state_dict
class ShardedModel(nn.Module):
def __init__(self,
module: nn.Module,
process_group: Optional[ProcessGroup] = None,
@ -96,8 +95,10 @@ class ShardedModel(nn.Module):
# We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
# So we use 1.0 as the default gradient_predivide_factor
# However, if you set gradient_predivide_factor to None, we will set gradient_predivide_factor to a value >= 1.0 automatically
self.gradient_predivide_factor: float = gradient_predivide_factor if gradient_predivide_factor is not None else \
# However, if you set gradient_predivide_factor to None
# we will set gradient_predivide_factor to a value >= 1.0 automatically
self.gradient_predivide_factor: float = gradient_predivide_factor if \
gradient_predivide_factor is not None else \
get_gradient_predivide_factor(self.world_size)
self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
@ -111,8 +112,12 @@ class ShardedModel(nn.Module):
self.module = module
self.param_manager = Zero3ParameterManager(module, process_group=self.process_group, mixed_precision=self.mixed_precision,
flatten_parameters=flatten_parameters, compute_dtype=self.compute_dtype, compute_device=self.compute_device,
self.param_manager = Zero3ParameterManager(module,
process_group=self.process_group,
mixed_precision=self.mixed_precision,
flatten_parameters=flatten_parameters,
compute_dtype=self.compute_dtype,
compute_device=self.compute_device,
offload_config=offload_config)
self._reset_lazy_init_info()
@ -145,13 +150,13 @@ class ShardedModel(nn.Module):
# For root and mixed precision, we convert the input to FP16 (no_grad is needed for
# the conversion).
if self._is_root and self.mixed_precision:
args, kwargs = cast_float_arguments(cast_trensor_to_fp16, *args, **kwargs)
args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
# If enabled, convert the input to FP32 if we are in full precision.
# no_grad is not used because the input might be for a non-root instance,
# which mean autograd needs to go through the conversion.
if self.force_input_to_fp32 and not self.mixed_precision:
args, kwargs = cast_float_arguments(cast_trensor_to_fp32, *args, **kwargs)
args, kwargs = cast_float_arguments(cast_tensor_to_fp32, *args, **kwargs)
# All-gather full parameters. This will also transfer FP32 parameters to
# ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``).
@ -201,10 +206,9 @@ class ShardedModel(nn.Module):
input_tensor = torch.ones(1).to(self.compute_device)
output = list(torch.zeros(self.world_size).to(self.compute_device).chunk(self.world_size))
dist.all_gather(output, input_tensor, group=self.process_group)
assert torch.cat(output).sum() == float(self.world_size), (
f"found {torch.cat(output).sum()} devices in process group but "
f"world_size={self.world_size}. Check torch.cuda.set_device is called properly"
)
assert torch.cat(output).sum() == float(
self.world_size), (f"found {torch.cat(output).sum()} devices in process group but "
f"world_size={self.world_size}. Check torch.cuda.set_device is called properly")
def _reset_lazy_init_info(self) -> None:
self._is_root: Optional[bool] = None
@ -277,9 +281,10 @@ class ShardedModel(nn.Module):
# if child instance in its own (smaller) world, that was probably an attempt to avoid OOM.
# Therefore gathering this child's optim state will probably cause OOM, so we won't do it.
m.no_broadcast_optim_state = m.no_broadcast_optim_state or (
(m.world_size == 1) and (m.world_size < self.world_size) and (m.process_group != self.process_group)
)
m.no_broadcast_optim_state = m.no_broadcast_optim_state or \
((m.world_size == 1)
and (m.world_size < self.world_size)
and (m.process_group != self.process_group))
def _setup_streams(self) -> None:
"""Create streams to overlap data transfer and computation."""
@ -330,9 +335,10 @@ class ShardedModel(nn.Module):
else:
self._streams["all_gather"].wait_stream(torch.cuda.current_stream())
def _cast_buffers(
self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, memo: Optional[Set] = None
) -> None:
def _cast_buffers(self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
memo: Optional[Set] = None) -> None:
"""Move all buffers to the given *device* and *dtype*.
If *device* or *dtype* are not given, then they will default to
@ -398,7 +404,7 @@ class ShardedModel(nn.Module):
outputs: new outputs with hooks registered if they requires gradient.
"""
if not torch.is_grad_enabled():
return outputs # don't register hooks if grad isn't enabled
return outputs # don't register hooks if grad isn't enabled
if self._is_root:
# This actually means that only root instance has
@ -523,7 +529,7 @@ class ShardedModel(nn.Module):
a new hook, which is needed for a new forward pass.
"""
if not torch.is_grad_enabled():
return # don't register grad hooks if grad isn't enabled
return # don't register grad hooks if grad isn't enabled
for p in self.params:
if p.requires_grad:
if hasattr(p, "zero_shard_bwd_hook"):
@ -612,7 +618,8 @@ class ShardedModel(nn.Module):
if param.zero_is_sharded:
assert self._reducer is not None
# Save the unsharded grad for reduction. We will asynchronously accumulate the reduced gradient into
# param.zero_saved_grad_shard. If this ShardedModel module was called multiple times it's possible that multiple
# param.zero_saved_grad_shard. If this ShardedModel module was called multiple times
# it's possible that multiple
# gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
# matter, neglecting rounding.
# Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
@ -628,9 +635,9 @@ class ShardedModel(nn.Module):
# unsharded gradients allocated; one for a pending reduction, and one for gradient computation.
callback_fn = functools.partial(self._reduce_scatter_callback, param)
grad_chunks = chunk_and_pad(orig_grad_data, self.reduce_scatter_process_group.size())
self._reducer.reduce_scatter_async(
grad_chunks, group=self.reduce_scatter_process_group, callback_fn=callback_fn
)
self._reducer.reduce_scatter_async(grad_chunks,
group=self.reduce_scatter_process_group,
callback_fn=callback_fn)
else:
# Currently the only way for _is_sharded to be False is if
# world_size == 1. This could be relaxed in the future, in which
@ -667,8 +674,9 @@ class ShardedModel(nn.Module):
param.zero_saved_grad_shard = reduced_grad.data
else:
assert (
param.zero_saved_grad_shard.shape == reduced_grad.shape
), f"{param.zero_saved_grad_shard.shape} vs {reduced_grad.shape}"
param.zero_saved_grad_shard.shape == reduced_grad.shape), f"{param.zero_saved_grad_shard.shape} \
vs {reduced_grad.shape}"
param.zero_saved_grad_shard.data += reduced_grad.data
reduced_grad = param.zero_saved_grad_shard.data
else:
@ -717,7 +725,7 @@ class ShardedModel(nn.Module):
# Flush any unreduced buckets in the post_backward stream.
with torch.cuda.stream(self._streams["post_backward"]):
assert_in_engine(self._reducer is not None, "FinalBackwardHook: reducer is None")
assert self._reducer is not None # make mypy happy
assert self._reducer is not None # make mypy happy
self._reducer.flush()
torch.cuda.current_stream().wait_stream(self._streams["post_backward"])
if self._cpu_offload:
@ -753,7 +761,8 @@ class ShardedModel(nn.Module):
elif hasattr(p, "zero_saved_grad_shard"):
assert_in_engine(
p.device == p.zero_saved_grad_shard.device,
f"FinalBackwardHook: incorrect saved_grad_shard device {p.device} vs {p.zero_saved_grad_shard.device}",
f"FinalBackwardHook: incorrect saved_grad_shard device \
{p.device} vs {p.zero_saved_grad_shard.device}",
)
p.grad = p.zero_saved_grad_shard
elif hasattr(p, 'zero_saved_grad'):
@ -765,7 +774,7 @@ class ShardedModel(nn.Module):
delattr(p, "zero_saved_grad")
# Update root and nested ShardedModel's hooks and flags.
for m in self.modules(): # includes self
for m in self.modules(): # includes self
if isinstance(m, ShardedModel):
_finalize_parameters(m)
m._pre_backward_hook_has_run = False
@ -796,7 +805,7 @@ class ShardedModel(nn.Module):
self._output_pre_backward_hook_registered is not None,
"FinalBackwardHook: self._output_pre_backward_hook_registered should not be None",
)
assert self._output_pre_backward_hook_registered is not None # make mypy happy
assert self._output_pre_backward_hook_registered is not None # make mypy happy
self._output_pre_backward_hook_registered.clear()
@contextlib.contextmanager
@ -908,9 +917,9 @@ class ShardedModel(nn.Module):
state["is_sharded"] = [p.zero_is_sharded for p in self.params]
state["orig_sizes"] = [p.zero_orig_size for p in self.params]
if state["process_group"] is not None:
state["process_group"] = "MISSING" # process_group isn't pickleable
state["process_group"] = "MISSING" # process_group isn't pickleable
if state["process_group_reduce_scatter"] is not None:
state["process_group_reduce_scatter"] = "MISSING" # process_group_reduce_scatter isn't pickleable
state["process_group_reduce_scatter"] = "MISSING" # process_group_reduce_scatter isn't pickleable
self._reset_lazy_init_info()
return state
@ -920,7 +929,7 @@ class ShardedModel(nn.Module):
def fixup(p: Parameter, is_sharded: bool, size: torch.Size) -> Parameter:
assert isinstance(p, Parameter)
p.data = p.data.clone() # move tensors out of shared memory
p.data = p.data.clone() # move tensors out of shared memory
p.zero_is_sharded = is_sharded
p.zero_orig_size = size
return p
@ -958,7 +967,7 @@ class ShardedModel(nn.Module):
# This instance may wrap other ShardedModel instances and we
# need to set all of them to accumulate gradients.
old_flags = []
for m in self.modules(): # includes self
for m in self.modules(): # includes self
if isinstance(m, ShardedModel):
old_flags.append((m, m._require_backward_grad_sync))
m._require_backward_grad_sync = False
@ -986,22 +995,18 @@ class ShardedModel(nn.Module):
raise ValueError(msg)
def extra_repr(self) -> str:
repr = (
f"world_size={self.world_size}, "
f"mixed_precision={self.mixed_precision}, "
)
repr = (f"world_size={self.world_size}, "
f"mixed_precision={self.mixed_precision}, ")
if self.verbose:
repr = (
f"rank={self.rank}, " + repr + f"reshard_after_forward={self.reshard_after_forward}, "
f"compute_dtype={self.compute_dtype}, "
f"buffer_dtype={self.buffer_dtype}, "
f"fp32_reduce_scatter={self.fp32_reduce_scatter}, "
f"compute_device={self.compute_device}"
f"reduce_scatter_bucket_size_mb={self.reduce_scatter_bucket_size_mb}, "
f"clear_autocast_cache={self.clear_autocast_cache}"
f"force_input_to_fp32={self.force_input_to_fp32}"
f"offload_config={self.offload_config}"
)
repr = (f"rank={self.rank}, " + repr + f"reshard_after_forward={self.reshard_after_forward}, "
f"compute_dtype={self.compute_dtype}, "
f"buffer_dtype={self.buffer_dtype}, "
f"fp32_reduce_scatter={self.fp32_reduce_scatter}, "
f"compute_device={self.compute_device}"
f"reduce_scatter_bucket_size_mb={self.reduce_scatter_bucket_size_mb}, "
f"clear_autocast_cache={self.clear_autocast_cache}"
f"force_input_to_fp32={self.force_input_to_fp32}"
f"offload_config={self.offload_config}")
return repr
def state_dict(self, destination=None, prefix='', keep_vars=False):
@ -1039,9 +1044,9 @@ class ShardedModel(nn.Module):
maybe_cast_buffers()
return state_dict
def load_state_dict(
self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
) -> NamedTuple:
def load_state_dict(self,
state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
strict: bool = True) -> NamedTuple:
"""
Load a whole (unsharded) state_dict.
@ -1094,7 +1099,6 @@ def _post_state_dict_hook(
return state_dict
def _pre_load_state_dict_hook(
state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], prefix: str, *args: Any
) -> None:
def _pre_load_state_dict_hook(state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], prefix: str,
*args: Any) -> None:
replace_state_dict_prefix(state_dict, prefix, prefix + "_zero3_module.")

View File

@ -1,4 +1,5 @@
import functools
from collections import OrderedDict
from typing import Any, Optional
import torch
@ -6,32 +7,32 @@ import torch.distributed as dist
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook, register_ophooks_recursively)
from colossalai.engine.ophooks import register_ophooks_recursively
from colossalai.engine.ophooks.zero_hook import ZeroHook
from colossalai.engine.paramhooks import BaseParamHookMgr
from colossalai.logging import get_dist_logger
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
from colossalai.zero.sharded_model.sharded_grad import ShardedGradient
from colossalai.zero.sharded_param import ShardedParam
from colossalai.zero.sharded_param import ShardedParamV2
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from ._zero3_utils import chunk_and_pad, get_gradient_predivide_factor
from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad,
get_gradient_predivide_factor)
class ShardedModelV2(nn.Module):
def __init__(
self,
module: nn.Module,
process_group: Optional[ProcessGroup] = None,
reduce_scatter_process_group: Optional[ProcessGroup] = None,
reduce_scatter_bucket_size_mb: int = 25,
reshard_after_forward: bool = True,
mixed_precision: bool = False,
fp32_reduce_scatter: bool = False,
offload_config: Optional[dict] = None,
gradient_predivide_factor: Optional[float] = 1.0,
):
def __init__(self,
module: nn.Module,
shard_strategy: BaseShardStrategy,
process_group: Optional[ProcessGroup] = None,
reduce_scatter_process_group: Optional[ProcessGroup] = None,
reduce_scatter_bucket_size_mb: int = 25,
fp32_reduce_scatter: bool = False,
offload_config: Optional[dict] = None,
gradient_predivide_factor: Optional[float] = 1.0,
shard_param: bool = True):
r"""
A demo to reconfigure zero1 shared_model.
Currently do not consider the Optimizer States.
@ -44,22 +45,24 @@ class ShardedModelV2(nn.Module):
self.world_size = dist.get_world_size(self.process_group)
self.rank = dist.get_rank(self.process_group)
# The module has to be placed on GPU
self.module = module.cuda()
# Cast module to fp16 and cuda, in case user didn't use ZeroInitContext
self.module = module.half().cuda()
# Shard the parameters at first
for _, param in self.module.named_parameters():
param.ca_attr = ShardedParam(param)
param.ca_attr.shard()
param._sharded_grad = ShardedGradient(param, self, offload_config)
self.shard_strategy = shard_strategy
self.shard_param = shard_param
# In case user didn't use ZeroInitContext
for param in self.module.parameters():
if not hasattr(param, 'col_attr'):
param.col_attr = ShardedParamV2(param, process_group)
if self.shard_param:
self.shard_strategy.shard([param.col_attr.data])
# Register hooks
register_ophooks_recursively(self.module, [ShardParamHook(), ShardGradHook()])
register_ophooks_recursively(self.module, [ZeroHook(self.shard_strategy)])
self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
self.reshard_after_forward = reshard_after_forward
self.mixed_precision = mixed_precision
self.fp32_reduce_scatter = fp32_reduce_scatter
self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
# We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
@ -76,6 +79,7 @@ class ShardedModelV2(nn.Module):
self._require_backward_grad_sync: bool = True
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
outputs = self.module(*args, **kwargs)
return outputs
@ -99,6 +103,7 @@ class ShardedModelV2(nn.Module):
torch.cuda.current_stream().synchronize()
self.reducer.free()
for p in self.module.parameters():
p.col_attr.bwd_count = 0
if not p.requires_grad:
continue
# Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
@ -107,11 +112,14 @@ class ShardedModelV2(nn.Module):
# sync passes, if desired.
if not self._require_backward_grad_sync:
continue
p._sharded_grad.write_back()
# Write grad back to p.grad and set p.col_attr.grad to None
p.grad.data = p.col_attr.grad
p.col_attr.grad = None
# In case some post bwd hook is not fired
for p in self.module.parameters():
if not p.ca_attr.is_sharded:
p.ca_attr.shard()
if self.shard_param:
for p in self.module.parameters():
if not p.col_attr.param_is_sharded:
self.shard_strategy.shard([p.col_attr.data])
@torch.no_grad()
def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
@ -119,7 +127,7 @@ class ShardedModelV2(nn.Module):
At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the
full gradient for the local batch. The reduce-scatter op will save
a single shard of the summed gradient across all
GPUs to param._sharded_grad. This shard will align with the current GPU rank. For example::
GPUs to param.col_attr.grad. This shard will align with the current GPU rank. For example::
before reduce_scatter:
param.grad (GPU #0): [1, 2, 3, 4]
@ -131,7 +139,7 @@ class ShardedModelV2(nn.Module):
The local GPU's ``optim.step`` is responsible for updating a single
shard of params, also corresponding to the current GPU's rank. This
alignment is created by `param._sharded_grad`, which ensures that
alignment is created by `param.col_attr.grad`, which ensures that
the local optimizer only sees the relevant parameter shard.
"""
if grad is None:
@ -142,7 +150,7 @@ class ShardedModelV2(nn.Module):
self.comm_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.comm_stream):
new_grad = grad.clone()
if self.mixed_precision and self.fp32_reduce_scatter:
if self.fp32_reduce_scatter:
new_grad.data = new_grad.data.to(param.dtype)
if self.gradient_predivide_factor > 1.0:
# Average grad by world_size for consistency with PyTorch DDP.
@ -161,13 +169,30 @@ class ShardedModelV2(nn.Module):
if self.gradient_postdivide_factor > 1:
# Average grad by world_size for consistency with PyTorch DDP.
reduced_grad.data.div_(self.gradient_postdivide_factor)
# Cast grad to param's dtype (typically FP32). Note: we do this
# before the cpu offload step so that this entire hook remains
# non-blocking. The downside is a bit more D2H transfer in that case.
if self.mixed_precision:
orig_param_grad_data = reduced_grad.data
reduced_grad.data = reduced_grad.data.to(dtype=param.ca_attr.origin_dtype)
# Don't let this memory get reused until after the transfer.
orig_param_grad_data.record_stream(torch.cuda.current_stream())
param._sharded_grad.reduce_scatter_callback(reduced_grad)
# Make sure we store fp32 grad
reduced_grad.data = cast_tensor_to_fp32(reduced_grad.data)
# Maybe offload
if self._cpu_offload:
reduced_grad.data = reduced_grad.data.cpu()
if param.col_attr.grad is None:
param.col_attr.grad = reduced_grad.data
else:
param.col_attr.grad.add_(reduced_grad.data)
def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
self.shard_strategy.gather([p.col_attr.data for p in self.module.parameters()])
prev_params = {}
for p in self.module.parameters():
prev_params[p] = p.data
p.data = p.col_attr.data.payload
gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
self.shard_strategy.shard([p.col_attr.data for p in self.module.parameters()])
for p in self.module.parameters():
p.data = prev_params[p]
return gathered_state_dict
def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
raise NotImplementedError

View File

@ -1,3 +1,5 @@
from typing import Optional, Tuple, Union
import numpy
import torch
import torch.distributed as dist
@ -5,7 +7,6 @@ from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.zero.sharded_model._zero3_utils import get_shard
from colossalai.zero.sharded_param import ShardedTensor
from typing import Union, Tuple, Optional
class ShardedParamV2(object):
@ -14,12 +15,8 @@ class ShardedParamV2(object):
param: torch.nn.Parameter,
process_group: Optional[dist.ProcessGroup] = None,
rm_torch_payload=False) -> None:
self._data_sharded_tensor = ShardedTensor(param.data, process_group)
if param.requires_grad and param.grad is not None:
self._grad_sharded_tensor = ShardedTensor(param.grad, process_group)
param.grad = None
else:
self._grad_sharded_tensor = None
self._data_sharded_tensor: ShardedTensor = ShardedTensor(param.data, process_group)
self._grad_sharded_tensor: Optional[torch.Tensor] = None
# make sure the shared param is the only owner of payload
# The param.data maybe used to init the other part of the model.
@ -30,27 +27,29 @@ class ShardedParamV2(object):
if rm_torch_payload:
self.remove_torch_payload()
# Backward count for handle local grad accumulation
# This value will increment by 1 in every pre-bwd hook
# And will be reset to 0 in every final-bwd hook
self.bwd_count = 0
def remove_torch_payload(self):
self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)
@property
def data(self):
return self._data_sharded_tensor.payload
@data.setter
def data(self, t: torch.Tensor):
self._data_sharded_tensor.payload = t
return self._data_sharded_tensor
@property
def grad(self):
if self._grad_sharded_tensor:
return self._grad_sharded_tensor.payload
else:
return None
return self._grad_sharded_tensor
@grad.setter
def grad(self, t: torch.Tensor):
self._grad_sharded_tensor.payload = t
self._grad_sharded_tensor = t
@property
def param_is_sharded(self):
return self._data_sharded_tensor.is_sharded
class ShardedParam(object):

0
tests/__init__.py Normal file
View File

View File

@ -45,16 +45,16 @@ class Net(nn.Module):
def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
if loose:
return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
return torch.allclose(tensor_a, tensor_b, atol=1e-2, rtol=1e-3)
return torch.allclose(tensor_a, tensor_b)
def check_grads(model, zero_model, loose=False):
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
zero_grad = zero_p.grad.clone().to(p.device)
assert p.grad.dtype == zero_grad.dtype
assert allclose(p.grad, zero_grad, loose=loose)
LOGGER.info(torch.sum(p.grad - zero_grad))
grad = p.grad.float()
assert grad.dtype == zero_grad.dtype
assert allclose(grad, zero_grad, loose=loose)
def check_params(model, zero_model, loose=False):
@ -71,11 +71,11 @@ def check_grads_padding(model, zero_model, loose=False):
chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
if rank >= len(chunks):
continue
grad = chunks[rank]
grad = chunks[rank].float()
if zero_grad.size(0) > grad.size(0):
zero_grad = zero_grad[:grad.size(0)]
assert grad.dtype == zero_grad.dtype
assert allclose(grad, zero_grad, loose=loose)
assert allclose(grad, zero_grad, loose=loose), f'{grad} vs {zero_grad}'
def check_params_padding(model, zero_model, loose=False):

View File

@ -7,12 +7,14 @@ import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
from colossalai.zero.init_ctx import ZeroInitContext
from common import CONFIG
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils.tensor_shard_strategy import \
TensorShardStrategy
from tests.components_to_test.registry import non_distributed_component_funcs
from common import CONFIG, Net
def run_dist(rank, world_size, port):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
@ -25,11 +27,11 @@ def run_dist(rank, world_size, port):
shard_param=True):
model = model_builder(checkpoint=True)
for param in model.parameters():
assert hasattr(param, 'ca_attr')
assert param.ca_attr.data.dtype == torch.half
assert param.ca_attr._data_sharded_tensor.is_sharded
assert param.ca_attr.data.device.type == 'cuda'
for param in model.parameters():
assert hasattr(param, 'col_attr')
assert param.col_attr.data.dtype == torch.half
assert param.col_attr.data.is_sharded
assert param.col_attr.data.payload.device.type == 'cuda'
@pytest.mark.dist

View File

@ -9,19 +9,21 @@ import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import free_port
from colossalai.zero.shard_utils.tensor_shard_strategy import \
TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from common import CONFIG, Net, check_grads, check_grads_padding
from common import CONFIG, check_grads, check_grads_padding
def run_fwd_bwd(model, x, enable_autocast=False):
def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
model.train()
with torch.cuda.amp.autocast(enabled=enable_autocast):
y = model(x)
loss = y.sum()
y = model(data)
loss = criterion(y, label)
loss = loss.float()
if isinstance(model, ShardedModelV2):
model.backward(loss)
@ -31,19 +33,26 @@ def run_fwd_bwd(model, x, enable_autocast=False):
def run_dist(rank, world_size, port):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = Net(checkpoint=True).cuda()
zero_model = copy.deepcopy(model)
zero_model = ShardedModelV2(zero_model, process_group=gpc.get_group(ParallelMode.DATA))
for _ in range(2):
x = torch.rand(2, 5).cuda()
run_fwd_bwd(zero_model, x, False)
run_fwd_bwd(model, x, False)
test_models = ['repeated_computed_layers', 'resnet18']
for model_name in test_models:
get_components_func = non_distributed_component_funcs.get_callable(model_name)
shard_strategy = TensorShardStrategy()
model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
model = model().half().cuda()
zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy)
if dist.get_world_size() > 1:
check_grads_padding(model, zero_model)
else:
check_grads(model, zero_model)
model = DDP(model)
for i, (data, label) in enumerate(train_dataloader):
if i > 2:
break
data, label = data.half().cuda(), label.cuda()
run_fwd_bwd(model, data, label, criterion, False)
run_fwd_bwd(zero_model, data, label, criterion, False)
if dist.get_world_size() > 1:
check_grads_padding(model, zero_model, loose=True)
else:
check_grads(model, zero_model, loose=True)
@pytest.mark.dist

View File

@ -4,18 +4,16 @@
from copy import deepcopy
from functools import partial
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
import colossalai
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_param import ShardedTensor, ShardedParam
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.utils import free_port
from colossalai.logging import get_dist_logger, disable_existing_loggers
from tests.test_zero_data_parallel.common import Net, CONFIG, allclose
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_param import ShardedParam, ShardedTensor
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from tests.test_zero_data_parallel.common import CONFIG, Net, allclose
def _run_shard_tensor(rank, world_size, port):
@ -47,7 +45,7 @@ def _run_shard_param_v2(rank, world_size, port):
param_ref = deepcopy(param)
sparam = ShardedParamV2(param=param, process_group=None)
allclose(sparam.data, param_ref.data)
allclose(sparam.data.payload, param_ref.data)
sparam.remove_torch_payload()
assert (param.data.numel() == 1)

View File

@ -0,0 +1,73 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import copy
from functools import partial
import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils.tensor_shard_strategy import \
TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from common import CONFIG, check_grads, check_grads_padding
def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
model.train()
with torch.cuda.amp.autocast(enabled=enable_autocast):
y = model(data)
loss = criterion(y, label)
loss = loss.float()
if isinstance(model, ShardedModelV2):
model.backward(loss)
else:
loss.backward()
def run_dist(rank, world_size, port):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
test_models = ['repeated_computed_layers', 'resnet18']
for model_name in test_models:
get_components_func = non_distributed_component_funcs.get_callable(model_name)
shard_strategy = TensorShardStrategy()
with ZeroInitContext(convert_fp16=True, convert_cuda=True, shard_strategy=shard_strategy, shard_param=True):
zero_model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
zero_model = zero_model()
model = copy.deepcopy(zero_model)
zero_model = ShardedModelV2(zero_model, shard_strategy)
model_state_dict = zero_model.state_dict()
for n, p in model.named_parameters():
p.data = model_state_dict[n]
model = model.half().cuda()
if dist.get_world_size() > 1:
model = DDP(model)
for i, (data, label) in enumerate(train_dataloader):
if i > 2:
break
data, label = data.half().cuda(), label.cuda()
run_fwd_bwd(model, data, label, criterion, False)
run_fwd_bwd(zero_model, data, label, criterion, False)
if dist.get_world_size() > 1:
check_grads_padding(model, zero_model, loose=True)
else:
check_grads(model, zero_model, loose=True)
@pytest.mark.dist
def test_shard_model_v2():
world_size = 2
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_shard_model_v2()

View File

@ -56,7 +56,7 @@ def run_dist(rank, world_size, port):
check_params(model, zero_model)
@pytest.mark.dist
@pytest.mark.skip
def test_sharded_optim_v2():
world_size = 2
run_func = partial(run_dist, world_size=world_size, port=free_port())

View File

@ -0,0 +1,43 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from copy import deepcopy
from functools import partial
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.utils import free_port
from colossalai.zero.shard_utils.tensor_shard_strategy import \
TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from tests.components_to_test.registry import non_distributed_component_funcs
from common import CONFIG
def run_dist(rank, world_size, port):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
test_models = ['repeated_computed_layers', 'resnet18']
for model_name in test_models:
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
model = model()
shard_strategy = TensorShardStrategy()
model = model.half().cuda()
zero_model = ShardedModelV2(deepcopy(model), shard_strategy)
zero_state_dict = zero_model.state_dict()
for key, val in model.state_dict().items():
assert torch.equal(val, zero_state_dict[key])
@pytest.mark.dist
def test_zero_state_dict():
world_size = 2
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_zero_state_dict()