mirror of https://github.com/hpcaitech/ColossalAI.git
add paged-attention v2: support seq length split across thread blocks (#5707)
@@ -16,6 +16,8 @@ class FDIntermTensors(metaclass=SingletonMeta):
         self._tensors_initialized = False
         del self._mid_output
         del self._mid_output_lse
+        del self._exp_sums
+        del self._max_logits
 
     @property
     def is_initialized(self):
@@ -31,6 +33,16 @@ class FDIntermTensors(metaclass=SingletonMeta):
         assert self.is_initialized, "Intermediate tensors not initialized yet"
         return self._mid_output_lse
 
+    @property
+    def exp_sums(self):
+        assert self.is_initialized, "Intermediate tensors not initialized yet"
+        return self._exp_sums
+
+    @property
+    def max_logits(self):
+        assert self.is_initialized, "Intermediate tensors not initialized yet"
+        return self._max_logits
+
     def initialize(
         self,
         max_batch_size: int,
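The new buffers are shaped (max_batch_size, num_attn_heads, kv_max_split_num): one slot per sequence, per attention head, per KV split. A minimal usage sketch follows; the initialize() signature is truncated in this hunk, so the parameter names beyond max_batch_size (num_attn_heads, kv_max_split_num, dtype, device) are assumed from the torch.empty calls in the last hunk, and the concrete values are placeholders.

import torch
# assuming FDIntermTensors is importable from its module in colossalai

# FDIntermTensors uses SingletonMeta, so every call site shares the same
# preallocated workspace.
tensors = FDIntermTensors()
tensors.initialize(
    max_batch_size=8,    # placeholder values
    num_attn_heads=32,
    kv_max_split_num=16,
    dtype=torch.float32,
    device="cuda",
)
assert tensors.is_initialized
exp_sums = tensors.exp_sums      # (8, 32, 16): per-split softmax denominators
max_logits = tensors.max_logits  # (8, 32, 16): per-split logit maxima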
@@ -60,5 +72,11 @@ class FDIntermTensors(metaclass=SingletonMeta):
         self._mid_output_lse = torch.empty(
             size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
         )
+        self._exp_sums = torch.empty(
+            size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
+        )
+        self._max_logits = torch.empty(
+            size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
+        )
 
         self._tensors_initialized = True
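For context on why exp_sums and max_logits are needed: when the KV sequence is split across thread blocks (paged-attention v2), each block produces a partial output normalized with its own split-local max and exp-sum, and a final reduction rescales all splits to a common base before summing. Below is a minimal PyTorch sketch of that reduction under the standard log-sum-exp rescaling scheme; the function name and shapes are illustrative, not the kernel's actual interface.

import torch

def reduce_splits(partial_out: torch.Tensor, exp_sums: torch.Tensor, max_logits: torch.Tensor) -> torch.Tensor:
    """Combine per-split partial attention outputs for one (sequence, head).

    partial_out: (num_splits, head_dim), each row normalized by its split-local exp-sum
    exp_sums:    (num_splits,), split-local sums of exp(logit - split_max)
    max_logits:  (num_splits,), split-local maxima of the attention logits
    """
    global_max = max_logits.max()
    # Rebase each split's denominator onto the global max for numerical stability.
    rescaled = exp_sums * torch.exp(max_logits - global_max)
    weights = rescaled / rescaled.sum()  # each split's share of the global softmax mass
    return (weights[:, None] * partial_out).sum(dim=0)

This is why the kernel must persist one exp-sum and one max-logit per (sequence, head, split): the reduction cannot recover them from the partial outputs alone.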