add paged-attention v2: support seq length split across thread blocks (#5707)

Author: Steve Luo
Date:   2024-05-14 12:46:54 +08:00
Committed by: GitHub
Parent: 18d67d0e8e
Commit: 7806842f2d

8 changed files with 704 additions and 249 deletions


@@ -16,6 +16,8 @@ class FDIntermTensors(metaclass=SingletonMeta):
         self._tensors_initialized = False
         del self._mid_output
         del self._mid_output_lse
+        del self._exp_sums
+        del self._max_logits
 
     @property
     def is_initialized(self):
@@ -31,6 +33,16 @@ class FDIntermTensors(metaclass=SingletonMeta):
         assert self.is_initialized, "Intermediate tensors not initialized yet"
         return self._mid_output_lse
 
+    @property
+    def exp_sums(self):
+        assert self.is_initialized, "Intermediate tensors not initialized yet"
+        return self._exp_sums
+
+    @property
+    def max_logits(self):
+        assert self.is_initialized, "Intermediate tensors not initialized yet"
+        return self._max_logits
+
     def initialize(
         self,
         max_batch_size: int,
@@ -60,5 +72,11 @@ class FDIntermTensors(metaclass=SingletonMeta):
         self._mid_output_lse = torch.empty(
             size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
         )
+        self._exp_sums = torch.empty(
+            size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
+        )
+        self._max_logits = torch.empty(
+            size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
+        )
 
         self._tensors_initialized = True
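
What the two new buffers are for: in paged-attention v2, a long sequence's KV cache is split across several thread blocks (partitions). Each partition computes a partial attention output together with its local max logit and its local sum of exponentials, stored per (batch, head, partition) in max_logits and exp_sums, and a final reduction combines the partials into the exact full-sequence result. Below is a minimal PyTorch sketch of that reduction step, not the actual CUDA kernel: the function name reduce_partitions and the mid_output layout (batch, heads, kv_max_split_num, head_dim) are assumptions for illustration, and each partition's partial output is assumed to be normalized by its own local exp sum.

```python
import torch

def reduce_partitions(
    mid_output: torch.Tensor,   # (B, H, P, head_dim): per-partition partial outputs (assumed layout)
    exp_sums: torch.Tensor,     # (B, H, P): sum of exp(logit - local_max) per partition
    max_logits: torch.Tensor,   # (B, H, P): local max logit per partition
    num_partitions: int,        # partitions actually used for this sequence length
) -> torch.Tensor:
    m = max_logits[..., :num_partitions]                      # (B, H, P)
    l = exp_sums[..., :num_partitions]                        # (B, H, P)
    # Recover global softmax statistics: rescale each partition's
    # exp sum by exp(m_i - M), where M is the global max logit.
    global_max = m.max(dim=-1, keepdim=True).values           # (B, H, 1)
    rescaled = l * torch.exp(m - global_max)                  # (B, H, P)
    weights = rescaled / rescaled.sum(dim=-1, keepdim=True)   # (B, H, P)
    # A weighted sum of the locally normalized partial outputs
    # equals the attention output over the full sequence.
    out = (mid_output[..., :num_partitions, :] * weights.unsqueeze(-1)).sum(dim=-2)
    return out                                                # (B, H, head_dim)
```

This is the standard log-sum-exp trick applied across partitions: subtracting the global max before exponentiating keeps the rescaling numerically stable, which is exactly why the kernel must persist max_logits alongside exp_sums rather than the raw sums alone.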
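For reference, a brief usage sketch of the singleton itself; everything past max_batch_size in the initialize call is inferred from the tensor shapes in the diff, so treat the exact signature as an assumption:

```python
import torch

# FDIntermTensors is the singleton class patched in this commit.
# Parameter names after max_batch_size are inferred from the tensor
# shapes above; the real signature may take additional arguments.
tensors = FDIntermTensors()
tensors.initialize(
    max_batch_size=8,
    num_attn_heads=32,
    kv_max_split_num=16,  # upper bound on KV partitions per sequence
    dtype=torch.float32,
    device="cuda",
)

# The new properties assert is_initialized before handing out the buffers.
print(tensors.exp_sums.shape)    # torch.Size([8, 32, 16])
print(tensors.max_logits.shape)  # torch.Size([8, 32, 16])
```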