[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

@@ -25,11 +25,13 @@ class RingQK(torch.autograd.Function):
         ctx.sub_seq_length = sub_seq_length

         # create local segment of attention score
-        attention_score = torch.empty(batch_size * num_attention_heads,
-                                      sub_seq_length,
-                                      sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE),
-                                      dtype=sub_q.dtype,
-                                      device=get_current_device())
+        attention_score = torch.empty(
+            batch_size * num_attention_heads,
+            sub_seq_length,
+            sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE),
+            dtype=sub_q.dtype,
+            device=get_current_device(),
+        )

         # compute local QK^T
         part_a = torch.matmul(sub_q, sub_k.transpose(2, 1))
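Note: the buffer reformatted above is sized [batch * heads, sub_seq_length, sub_seq_length * world_size] because each sequence-parallel rank holds sub_seq_length queries but scores them against the keys of the whole sequence, which arrive one shard per ring step. A minimal single-process sketch of that pattern, where a list of key shards stands in for the ring communication (ring_qk_reference and sub_k_shards are illustrative names, not ColossalAI API):

import torch

def ring_qk_reference(sub_q, sub_k_shards):
    # sub_q: [batch*heads, sub_seq, head_size] -- this rank's query shard
    # sub_k_shards: every rank's key shard, standing in for the ring passes
    batch_heads, sub_seq, _ = sub_q.shape
    world_size = len(sub_k_shards)
    score = torch.empty(batch_heads, sub_seq, sub_seq * world_size,
                        dtype=sub_q.dtype, device=sub_q.device)
    for src, sub_k in enumerate(sub_k_shards):
        # each ring step fills the score slice owned by rank `src`,
        # mirroring `part_a = torch.matmul(sub_q, sub_k.transpose(2, 1))`
        score[:, :, src * sub_seq:(src + 1) * sub_seq] = torch.matmul(
            sub_q, sub_k.transpose(2, 1))
    return score
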
@@ -51,7 +53,10 @@ class RingQK(torch.autograd.Function):
     @staticmethod
     @custom_bwd
     def backward(ctx, grad_output):
-        sub_q, sub_k, = ctx.saved_tensors
+        (
+            sub_q,
+            sub_k,
+        ) = ctx.saved_tensors
         local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
         local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)

@@ -59,7 +64,7 @@ class RingQK(torch.autograd.Function):
         grad_k = torch.matmul(grad_output.transpose(2, 1), sub_q)
         dist.all_reduce(grad_k, group=gpc.get_group(ParallelMode.SEQUENCE))
-        grad_k = grad_k[:, local_rank * ctx.sub_seq_length:(local_rank + 1) * ctx.sub_seq_length]
+        grad_k = grad_k[:, local_rank * ctx.sub_seq_length : (local_rank + 1) * ctx.sub_seq_length]
         grad_k /= local_world_size

         # calculate gradient for sub_q
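Note: the slicing change above is whitespace-only (pre-commit's formatter style), but the surrounding logic is worth spelling out: every rank's query shard contributes to the gradient of every key position, so the per-rank partial grad_k tensors are summed across the sequence group (the all_reduce) and each rank then keeps only the slice for the keys it owns. A single-process sketch, with a list of per-rank tensors standing in for the collective and the production code's world-size scaling left out (names are illustrative):

import torch

def ring_qk_backward_k(grad_outputs, sub_qs, rank, sub_seq):
    # grad_outputs[r]: [batch*heads, sub_seq, full_seq] score grads on rank r
    # sub_qs[r]: [batch*heads, sub_seq, head_size] query shard on rank r
    # summing the per-rank partials plays the role of dist.all_reduce(grad_k)
    grad_k_full = sum(torch.matmul(g.transpose(2, 1), q)
                      for g, q in zip(grad_outputs, sub_qs))
    # keep only the rows for the keys this rank owns
    return grad_k_full[:, rank * sub_seq:(rank + 1) * sub_seq]
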
@@ -96,11 +101,13 @@ class RingAV(torch.autograd.Function):
         local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
         local_start_idx, local_end_idx = _calc_current_device_range(local_rank, sub_seq_length)
-        sub_attention_result = torch.zeros(batch_size * num_attention_heads,
-                                           sub_seq_length,
-                                           attention_head_size,
-                                           device=get_current_device(),
-                                           dtype=attention_score.dtype)
+        sub_attention_result = torch.zeros(
+            batch_size * num_attention_heads,
+            sub_seq_length,
+            attention_head_size,
+            device=get_current_device(),
+            dtype=attention_score.dtype,
+        )

         # save tensors for backward
         ctx.save_for_backward(attention_score, sub_v)
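Note: RingAV is the second half of the ring pass: the attention scores produced by RingQK are multiplied with value shards as they travel around the ring, and each shard's product is accumulated into the zero-initialized sub_attention_result above. A single-process sketch of that accumulation, again with a shard list in place of the ring communication (ring_av_reference and sub_v_shards are illustrative names, not ColossalAI API):

import torch

def ring_av_reference(attention_score, sub_v_shards):
    # attention_score: [batch*heads, sub_seq, full_seq] scores from RingQK
    # sub_v_shards[r]: [batch*heads, sub_seq, head_size] value shard of rank r
    bh, sub_seq, _ = attention_score.shape
    out = torch.zeros(bh, sub_seq, sub_v_shards[0].shape[-1],
                      device=attention_score.device, dtype=attention_score.dtype)
    for src, sub_v in enumerate(sub_v_shards):
        # accumulate the contribution of the score columns owned by rank `src`
        out += torch.matmul(
            attention_score[:, :, src * sub_seq:(src + 1) * sub_seq], sub_v)
    return out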