Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 19:40:28 +00:00)
[shardformer] support inplace sharding (#4251)
* [shardformer] embedding support inplace sharding
* [shardformer] linear support inplace sharding
* [shardformer] layernorm support inplace sharding
* [shardformer] qkv support inplace sharding
* [test] update shardformer layer test
* [shardformer] fix shared param sharding
* [shardformer] fix bert policy
* [shardformer] fix bloom policy
* [shardformer] fix llama policy
* [shardformer] fix opt policy
* [shardformer] fix t5 policy
* [shardformer] fix fused qkv linear
* [shardformer] fix bugs
* force sync
* [test] fix bugs
* [test] fix transformer version
@@ -235,6 +235,14 @@ def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):
     return param
 
 
+def sharded_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter) -> None:
+    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
+    param.data = dtensor
+    # make it distributed as well
+    param.dist_layout = dtensor.dist_layout
+    _hijack_detach_and_clone(param)
+
+
 def compute_global_numel(dtensor: torch.Tensor) -> int:
     """
     Compute the global number of elements in the distributed tensor.
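For reference, here is a minimal plain-PyTorch sketch of the in-place pattern the new helper follows. This is not the library API: the tensor below is an ordinary local tensor standing in for a real distributed tensor, and local_shard is an illustrative name. The point is that swapping param.data keeps the same torch.nn.Parameter object, so references held elsewhere (module attributes, optimizer param groups) still point at the now-sharded weight.

import torch
import torch.nn as nn

# Toy layer whose weight we want to shard "in place".
linear = nn.Linear(8, 8, bias=False)
opt = torch.optim.SGD(linear.parameters(), lr=0.1)

weight_before = linear.weight  # keep a reference to the Parameter object

# Stand-in for a distributed tensor: just the first half of the rows here.
# In the real code this would come from the d_tensor sharding helpers.
local_shard = linear.weight.detach()[:4].clone()

# The in-place conversion pattern from the diff: mutate param.data instead of
# building a new Parameter, so object identity is preserved.
linear.weight.data = local_shard

assert linear.weight is weight_before                     # same Parameter object
assert opt.param_groups[0]['params'][0] is linear.weight  # optimizer still tracks it
print(linear.weight.shape)                                # torch.Size([4, 8])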
@@ -432,3 +440,15 @@ def customized_distributed_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):
     param.gather_fn = dtensor.gather_fn
     _hijack_detach_and_clone_for_customized_distributed_tensor(param)
     return param
+
+
+def customized_distributed_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter):
+    """
+    Convert the given customized distributed tensor to an existing parameter.
+    """
+    assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.'
+
+    param.data = dtensor.data
+    param.shard_fn = dtensor.shard_fn
+    param.gather_fn = dtensor.gather_fn
+    _hijack_detach_and_clone_for_customized_distributed_tensor(param)
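A similar hedged sketch for the customized variant: the shard and gather callables below are made-up placeholders (real ones would scatter and all-gather over a process group), and the steps simply mirror what customized_distributed_tensor_to_existing_param does, namely swap the storage and attach shard_fn / gather_fn onto the existing Parameter.

import torch
import torch.nn as nn

# Placeholder shard/gather callables for a row-wise split across 2 ranks;
# real ones would use the process group to scatter and all-gather.
def shard_fn(t: torch.Tensor) -> torch.Tensor:
    return t.chunk(2, dim=0)[0]  # pretend we are rank 0

def gather_fn(t: torch.Tensor) -> torch.Tensor:
    return torch.cat([t, torch.zeros_like(t)], dim=0)  # placeholder full tensor

param = nn.Parameter(torch.randn(8, 8))
holder = [param]  # something else that keeps a reference to the Parameter

# Mirror the body of customized_distributed_tensor_to_existing_param:
# swap the storage, then attach the callables to the existing Parameter.
param.data = shard_fn(param.detach())
param.shard_fn = shard_fn
param.gather_fn = gather_fn

assert holder[0] is param  # identity preserved, as with the plain variant
print(param.shape, param.gather_fn(param.detach()).shape)  # (4, 8) and (8, 8)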