Mirror of https://github.com/hpcaitech/ColossalAI.git
[shardformer] supported fused qkv checkpoint (#4073)
@@ -305,3 +305,130 @@ def get_sharding_spec(dtensor: torch.Tensor) -> ShardingSpec:
    """
    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
    return dtensor.dist_layout.sharding_spec

# ======================================================
# Some sharding does not obey the SPMD style
# e.g. Fused QKV layer in GPT2
# we support customized sharding with the following APIs
# ======================================================
def is_customized_distributed_tensor(tensor: torch.Tensor):
    """
    Check whether the given tensor is a customized distributed tensor.

    Args:
        tensor (torch.Tensor): The tensor to be checked.

    Returns:
        bool: Whether the given tensor is a customized distributed tensor.
    """
    return hasattr(tensor, 'shard_fn') and hasattr(tensor, 'gather_fn')

def _hijack_detach_and_clone_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor:
    """
    Hijack the detach and clone methods of the tensor so that shard_fn and gather_fn are carried over.

    Args:
        dtensor (torch.Tensor): The tensor to be hijacked.

    Returns:
        torch.Tensor: The hijacked tensor.
    """
    dtensor._old_detach = dtensor.detach
    dtensor._old_clone = dtensor.clone

    def new_detach(self):
        t_ = self._old_detach()
        t_.shard_fn = self.shard_fn
        t_.gather_fn = self.gather_fn
        return t_

    def new_clone(self, *args, **kwargs):
        t_ = self._old_clone(*args, **kwargs)
        t_.shard_fn = self.shard_fn
        t_.gather_fn = self.gather_fn
        return t_

    # bind the new methods to the tensor so detached/cloned copies remain customized distributed tensors
    dtensor.detach = new_detach.__get__(dtensor)
    dtensor.clone = new_clone.__get__(dtensor)
    return dtensor

def distribute_tensor_with_customization(tensor: torch.Tensor, shard_fn: callable, gather_fn: callable):
    """
    Distribute the given tensor with the given shard_fn and gather_fn.

    Example:

    ```python
    # define shard and gather functions
    def shard_fn(tensor):
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
        return tensor.chunk(world_size, dim=0)[rank]

    def gather_fn(tensor):
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
        shard_list = [torch.zeros_like(tensor) for _ in range(world_size)]
        torch.distributed.all_gather(shard_list, tensor)
        return torch.cat(shard_list, dim=0)

    # create a distributed tensor
    tensor = torch.rand(4, 4)
    dtensor = distribute_tensor_with_customization(tensor, shard_fn, gather_fn)
    ```

    Args:
        tensor (torch.Tensor): The tensor to be distributed.
        shard_fn (callable): The function to shard the tensor.
        gather_fn (callable): The function to gather the tensor.

    Returns:
        torch.Tensor: The distributed tensor.
    """
    assert callable(shard_fn), 'The shard_fn must be callable.'
    assert callable(gather_fn), 'The gather_fn must be callable.'
    assert not is_distributed_tensor(tensor), 'The input tensor is already a distributed tensor.'

    sharded_tensor = shard_fn(tensor)

    # set the shard_fn and gather_fn as attributes of the distributed tensor
    sharded_tensor.shard_fn = shard_fn
    sharded_tensor.gather_fn = gather_fn

    # hijack detach and clone so the shard_fn and gather_fn are preserved on copies
    _hijack_detach_and_clone_for_customized_distributed_tensor(sharded_tensor)

    return sharded_tensor

def to_global_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor:
    """
    Gather the given tensor to the global tensor.

    Args:
        dtensor (torch.Tensor): The distributed tensor.

    Returns:
        torch.Tensor: The global tensor.
    """
    assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.'
    return dtensor.gather_fn(dtensor)

def customized_distributed_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):
    """
    Convert the given customized distributed tensor to a parameter.

    Args:
        dtensor (torch.Tensor): The customized distributed tensor to be converted.
        requires_grad (bool): Whether the parameter requires gradient. Defaults to True.

    Returns:
        torch.nn.Parameter: The converted parameter carrying the same shard_fn and gather_fn.
    """
    assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.'

    param = torch.nn.Parameter(dtensor, requires_grad=requires_grad)

    # make it distributed as well
    param.shard_fn = dtensor.shard_fn
    param.gather_fn = dtensor.gather_fn
    _hijack_detach_and_clone_for_customized_distributed_tensor(param)
    return param
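Taken together, these APIs are what allow checkpointing layers whose sharding does not follow the plain SPMD style, such as GPT2's fused QKV projection mentioned in the banner comment. The sketch below is not part of this commit: `shard_fused_qkv`, `gather_fused_qkv`, the 768 hidden size, and the row-wise [Q | K | V] layout are illustrative assumptions, and it assumes an initialized `torch.distributed` process group plus the functions defined above being importable.

```python
# A minimal sketch (not part of this commit) of custom sharding for a fused QKV weight.
import torch
import torch.distributed as dist


def shard_fused_qkv(weight: torch.Tensor, n_fused: int = 3) -> torch.Tensor:
    # split the fused weight into its Q/K/V blocks, then keep this rank's slice of each block
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    blocks = weight.chunk(n_fused, dim=0)
    return torch.cat([blk.chunk(world_size, dim=0)[rank] for blk in blocks], dim=0)


def gather_fused_qkv(weight: torch.Tensor, n_fused: int = 3) -> torch.Tensor:
    # collect every rank's shard, then re-interleave the per-rank Q/K/V slices
    # back into a single [Q | K | V] weight
    world_size = dist.get_world_size()
    shard_list = [torch.zeros_like(weight) for _ in range(world_size)]
    dist.all_gather(shard_list, weight.contiguous())
    blocks_per_rank = [shard.chunk(n_fused, dim=0) for shard in shard_list]
    return torch.cat(
        [torch.cat([blocks[i] for blocks in blocks_per_rank], dim=0) for i in range(n_fused)],
        dim=0,
    )


# shard the fused weight, wrap it as a parameter, and gather it back when saving a checkpoint
fused_qkv_weight = torch.rand(3 * 768, 768)    # hypothetical fused QKV weight, rows laid out as [Q | K | V]
dtensor = distribute_tensor_with_customization(fused_qkv_weight, shard_fused_qkv, gather_fused_qkv)
param = customized_distributed_tensor_to_param(dtensor)
full_weight = to_global_for_customized_distributed_tensor(param)    # same shape as fused_qkv_weight
```

Because `customized_distributed_tensor_to_param` and the detach/clone hijack keep `shard_fn` and `gather_fn` attached to the parameter, the checkpoint path only needs `to_global_for_customized_distributed_tensor` to recover the full fused weight on every rank.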