[shardformer] supported fused qkv checkpoint (#4073)

Frank Lee
2023-06-23 16:07:09 +08:00
parent 0803a61412
commit 70c58cfd4f
10 changed files with 420 additions and 88 deletions

@@ -305,3 +305,130 @@ def get_sharding_spec(dtensor: torch.Tensor) -> ShardingSpec:
"""
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
return dtensor.dist_layout.sharding_spec


# ======================================================
# Some sharding patterns do not obey the SPMD style,
# e.g. the fused QKV layer in GPT2,
# so we support customized sharding with the following APIs.
# ======================================================
def is_customized_distributed_tensor(tensor: torch.Tensor) -> bool:
    """
    Check whether the given tensor is a customized distributed tensor.

    Args:
        tensor (torch.Tensor): The tensor to be checked.

    Returns:
        bool: Whether the given tensor is a customized distributed tensor.
    """
    return hasattr(tensor, 'shard_fn') and hasattr(tensor, 'gather_fn')
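

# A minimal sketch of the duck-typed check above; the no-op shard and gather
# functions here are hypothetical and for illustration only.
def _example_customized_tensor_check():
    t = torch.rand(4, 4)
    assert not is_customized_distributed_tensor(t)

    t.shard_fn = lambda x: x    # hypothetical no-op shard
    t.gather_fn = lambda x: x   # hypothetical no-op gather
    assert is_customized_distributed_tensor(t)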


def _hijack_detach_and_clone_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor:
    """
    Hijack the detach and clone methods of the tensor to make sure the shard_fn and gather_fn are copied.

    Args:
        dtensor (torch.Tensor): The tensor to be hijacked.

    Returns:
        torch.Tensor: The hijacked tensor.
    """
    dtensor._old_detach = dtensor.detach
    dtensor._old_clone = dtensor.clone

    def new_detach(self):
        t_ = self._old_detach()
        t_.shard_fn = self.shard_fn
        t_.gather_fn = self.gather_fn
        return t_

    def new_clone(self, *args, **kwargs):
        t_ = self._old_clone(*args, **kwargs)
        t_.shard_fn = self.shard_fn
        t_.gather_fn = self.gather_fn
        return t_

    # bind the new methods to the tensor
    dtensor.detach = new_detach.__get__(dtensor)
    dtensor.clone = new_clone.__get__(dtensor)
    return dtensor
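

# A short sketch of why the hijack matters: detach() and clone() normally
# return plain tensors, which would silently drop shard_fn and gather_fn
# (e.g. when checkpoint logic detaches a parameter). The no-op functions
# below are hypothetical and for illustration only.
def _example_detach_preserves_customization():
    t = torch.rand(4, 4)
    t.shard_fn = lambda x: x    # hypothetical no-op shard
    t.gather_fn = lambda x: x   # hypothetical no-op gather
    _hijack_detach_and_clone_for_customized_distributed_tensor(t)

    # both results still carry the customization after the hijack
    assert is_customized_distributed_tensor(t.detach())
    assert is_customized_distributed_tensor(t.clone())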


def distribute_tensor_with_customization(tensor: torch.Tensor, shard_fn: callable, gather_fn: callable):
    """
    Distribute the given tensor with the given shard_fn and gather_fn.

    Example:

    ```python
    # define shard and gather functions
    def shard_fn(tensor):
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
        return tensor.chunk(world_size, dim=0)[rank]

    def gather_fn(tensor):
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
        shard_list = [torch.zeros_like(tensor) for _ in range(world_size)]
        torch.distributed.all_gather(shard_list, tensor)
        return torch.cat(shard_list, dim=0)

    # create a distributed tensor
    tensor = torch.rand(4, 4)
    dtensor = distribute_tensor_with_customization(tensor, shard_fn, gather_fn)
    ```

    Args:
        tensor (torch.Tensor): The tensor to be distributed.
        shard_fn (callable): The function to shard the tensor.
        gather_fn (callable): The function to gather the tensor.

    Returns:
        torch.Tensor: The distributed tensor.
    """
    assert callable(shard_fn), 'The shard_fn must be callable.'
    assert callable(gather_fn), 'The gather_fn must be callable.'
    assert not is_distributed_tensor(tensor), 'The input tensor is already a distributed tensor.'

    sharded_tensor = shard_fn(tensor)

    # set the shard_fn and gather_fn as attributes of the distributed tensor
    sharded_tensor.shard_fn = shard_fn
    sharded_tensor.gather_fn = gather_fn

    # hijack the detach and clone methods so the customization is preserved
    _hijack_detach_and_clone_for_customized_distributed_tensor(sharded_tensor)
    return sharded_tensor


def to_global_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor:
    """
    Gather the given tensor to the global tensor.

    Args:
        dtensor (torch.Tensor): The distributed tensor.

    Returns:
        torch.Tensor: The global tensor.
    """
    assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.'
    return dtensor.gather_fn(dtensor)
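

# A hedged sketch of the fused-QKV case mentioned at the top of this block:
# the last dim holds Q, K and V concatenated, so each rank takes its slice
# from each of the three segments instead of one contiguous chunk. The shapes
# and function bodies are illustrative assumptions rather than the actual GPT2
# policy, and a process group must already be initialized.
def _example_fused_qkv_sharding():
    import torch.distributed as dist

    def shard_fn(tensor):
        rank, world_size = dist.get_rank(), dist.get_world_size()
        q, k, v = tensor.chunk(3, dim=-1)
        # each rank keeps its slice of Q, K and V, concatenated back together
        return torch.cat([seg.chunk(world_size, dim=-1)[rank] for seg in (q, k, v)], dim=-1)

    def gather_fn(tensor):
        world_size = dist.get_world_size()
        shard_list = [torch.zeros_like(tensor) for _ in range(world_size)]
        dist.all_gather(shard_list, tensor.contiguous())
        # regroup the per-rank (q, k, v) slices into full Q, K and V segments
        q_list, k_list, v_list = zip(*(s.chunk(3, dim=-1) for s in shard_list))
        return torch.cat(list(q_list) + list(k_list) + list(v_list), dim=-1)

    weight = torch.rand(768, 3 * 768)
    dweight = distribute_tensor_with_customization(weight, shard_fn, gather_fn)
    assert to_global_for_customized_distributed_tensor(dweight).shape == weight.shape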


def customized_distributed_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):
    """
    Convert the given customized distributed tensor to a parameter.

    Args:
        dtensor (torch.Tensor): The customized distributed tensor to be converted.
        requires_grad (bool): Whether the parameter requires gradient.

    Returns:
        torch.nn.Parameter: The converted parameter.
    """
    assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.'

    param = torch.nn.Parameter(dtensor, requires_grad=requires_grad)

    # make it distributed as well
    param.shard_fn = dtensor.shard_fn
    param.gather_fn = dtensor.gather_fn
    _hijack_detach_and_clone_for_customized_distributed_tensor(param)
    return param
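

# A hypothetical end-to-end usage of the conversion above: register the
# sharded weight as a module parameter, then rebuild the full weight when
# saving a checkpoint. shard_fn and gather_fn are assumed to be defined as
# in the sketch above, with a process group already initialized.
def _example_param_checkpoint_roundtrip(shard_fn, gather_fn):
    layer = torch.nn.Linear(8, 8)
    dweight = distribute_tensor_with_customization(layer.weight.detach(), shard_fn, gather_fn)
    layer.weight = customized_distributed_tensor_to_param(dweight)

    # each rank can now reconstruct the global weight for the checkpoint
    return to_global_for_customized_distributed_tensor(layer.weight)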