Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-10 21:40:02 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -25,11 +25,13 @@ class RingQK(torch.autograd.Function):
         ctx.sub_seq_length = sub_seq_length
 
         # create local segment of attention score
-        attention_score = torch.empty(batch_size * num_attention_heads,
-                                      sub_seq_length,
-                                      sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE),
-                                      dtype=sub_q.dtype,
-                                      device=get_current_device())
+        attention_score = torch.empty(
+            batch_size * num_attention_heads,
+            sub_seq_length,
+            sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE),
+            dtype=sub_q.dtype,
+            device=get_current_device(),
+        )
 
         # compute local QK^T
         part_a = torch.matmul(sub_q, sub_k.transpose(2, 1))
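For readers skimming the hunk, here is a minimal single-process sketch of what these lines compute. The tensor sizes are arbitrary, and plain integers stand in for the distributed values (gpc.get_world_size(ParallelMode.SEQUENCE), get_current_device()), so this is an illustration of the shapes, not the library's API.

import torch

# Hypothetical sizes standing in for the values taken from gpc / ParallelMode.SEQUENCE.
batch_size, num_attention_heads, attention_head_size = 2, 4, 8
world_size, sub_seq_length = 4, 16

sub_q = torch.randn(batch_size * num_attention_heads, sub_seq_length, attention_head_size)
sub_k = torch.randn(batch_size * num_attention_heads, sub_seq_length, attention_head_size)

# Buffer for this rank's rows of the full attention-score matrix: each rank owns
# sub_seq_length query rows but all sub_seq_length * world_size key columns.
attention_score = torch.empty(
    batch_size * num_attention_heads,
    sub_seq_length,
    sub_seq_length * world_size,
    dtype=sub_q.dtype,
)

# Local QK^T block: this rank's queries against this rank's keys.
part_a = torch.matmul(sub_q, sub_k.transpose(2, 1))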
@@ -51,7 +53,10 @@ class RingQK(torch.autograd.Function):
     @staticmethod
     @custom_bwd
     def backward(ctx, grad_output):
-        sub_q, sub_k, = ctx.saved_tensors
+        (
+            sub_q,
+            sub_k,
+        ) = ctx.saved_tensors
         local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
         local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
 
@@ -59,7 +64,7 @@ class RingQK(torch.autograd.Function):
         grad_k = torch.matmul(grad_output.transpose(2, 1), sub_q)
 
         dist.all_reduce(grad_k, group=gpc.get_group(ParallelMode.SEQUENCE))
-        grad_k = grad_k[:, local_rank * ctx.sub_seq_length:(local_rank + 1) * ctx.sub_seq_length]
+        grad_k = grad_k[:, local_rank * ctx.sub_seq_length : (local_rank + 1) * ctx.sub_seq_length]
         grad_k /= local_world_size
 
         # calculate gradient for sub_q
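The slice reformatted in this hunk picks out the key positions owned by the current rank after the all-reduce. A rough single-process sketch of that indexing follows; local_rank, local_world_size, and the tensor sizes are chosen arbitrarily for illustration rather than taken from gpc.

import torch

# Hypothetical values standing in for gpc.get_local_rank / gpc.get_world_size(ParallelMode.SEQUENCE).
local_rank, local_world_size = 1, 4
sub_seq_length, heads_times_batch, head_size = 16, 8, 8

# grad_k as it would look after the all-reduce: gradients w.r.t. all
# sub_seq_length * local_world_size key positions.
grad_k = torch.randn(heads_times_batch, sub_seq_length * local_world_size, head_size)

# Keep only the slice of key positions owned by this rank, then apply the same
# division by local_world_size as in the diff.
grad_k = grad_k[:, local_rank * sub_seq_length : (local_rank + 1) * sub_seq_length]
grad_k /= local_world_size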
@@ -96,11 +101,13 @@ class RingAV(torch.autograd.Function):
         local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
         local_start_idx, local_end_idx = _calc_current_device_range(local_rank, sub_seq_length)
 
-        sub_attention_result = torch.zeros(batch_size * num_attention_heads,
-                                           sub_seq_length,
-                                           attention_head_size,
-                                           device=get_current_device(),
-                                           dtype=attention_score.dtype)
+        sub_attention_result = torch.zeros(
+            batch_size * num_attention_heads,
+            sub_seq_length,
+            attention_head_size,
+            device=get_current_device(),
+            dtype=attention_score.dtype,
+        )
 
         # save tensors for backward
         ctx.save_for_backward(attention_score, sub_v)
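Finally, a sketch of the buffer set up in this last hunk: sub_attention_result accumulates this rank's slice of the attention output (scores times values). The sizes, local_rank, and the start/end index arithmetic are assumptions for illustration; in the real code they come from gpc and _calc_current_device_range.

import torch

# Hypothetical sizes and rank; the index range below assumes a simple
# rank * sub_seq_length partition, approximating _calc_current_device_range.
batch_size, num_attention_heads, attention_head_size = 2, 4, 8
local_rank, local_world_size, sub_seq_length = 1, 4, 16

local_start_idx = local_rank * sub_seq_length
local_end_idx = (local_rank + 1) * sub_seq_length

attention_score = torch.randn(
    batch_size * num_attention_heads, sub_seq_length, sub_seq_length * local_world_size
)
sub_v = torch.randn(batch_size * num_attention_heads, sub_seq_length, attention_head_size)

# Accumulator for this rank's slice of the attention output.
sub_attention_result = torch.zeros(
    batch_size * num_attention_heads,
    sub_seq_length,
    attention_head_size,
    dtype=attention_score.dtype,
)

# Contribution from the locally held value block; the full ring implementation
# presumably adds the other ranks' contributions as value blocks are passed around.
sub_attention_result += torch.matmul(
    attention_score[:, :, local_start_idx:local_end_idx], sub_v
)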