mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 12:01:39 +00:00
[Hotfix] Fix bugs in testing continuous batching (#5270)
* fix bug * fix bugs * fix bugs * fix bugs and add padding * add funcs and fix bugs * fix typos * fix bugs * add func
This commit is contained in:
@@ -29,6 +29,9 @@ class RequestStatus(enum.Enum):
|
||||
COMPLETED = enum.auto()
|
||||
LENGTH_CAPPED = enum.auto()
|
||||
|
||||
# recycle status
|
||||
RECYCLED = enum.auto()
|
||||
|
||||
@staticmethod
|
||||
def is_finished(status: "RequestStatus") -> bool:
|
||||
return status in [
|
||||
@@ -119,7 +122,9 @@ class Sequence:
|
||||
"""
|
||||
Set status for prefill reqs.
|
||||
"""
|
||||
assert self.status == RequestStatus.WAITING, "Sequence is not in WAITTING STATUS"
|
||||
assert (
|
||||
self.status == RequestStatus.WAITING or RequestStatus.RECYCLED
|
||||
), "Sequence is not in WAITTING/RECYCLED STATUS"
|
||||
self.status = RequestStatus.RUNNING
|
||||
|
||||
def mark_finished(self) -> None:
|
||||
@@ -139,10 +144,10 @@ class Sequence:
|
||||
Recycle a running sequnce to waiitting list
|
||||
"""
|
||||
assert (
|
||||
not self.status.is_finished and not self.status == RequestStatus.ABORTED
|
||||
not self.check_finish() and not self.status == RequestStatus.ABORTED
|
||||
), "The running sequence \
|
||||
is already done but it still in running list"
|
||||
self.status = RequestStatus.WAITING
|
||||
self.status = RequestStatus.RECYCLED
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
@@ -162,7 +167,7 @@ class BatchInfo:
|
||||
Information to be passed and used for a batch of sequences.
|
||||
"""
|
||||
|
||||
sequences_set: OrderedSet["Sequence"] = None
|
||||
sequences_set: OrderedSet[Sequence] = None
|
||||
is_prompts: bool = True
|
||||
device: torch.device = None
|
||||
|
||||
@@ -207,12 +212,20 @@ class BatchInfo:
|
||||
|
||||
def clear_batch(self) -> None:
|
||||
"""
|
||||
Clear sequence set and block table.
|
||||
Clear sequence set and block table if we need to abort this batch.
|
||||
Prefill: clear sequence set and move them to running batch(external)
|
||||
Decoding: mark unfinished sequences as aborted.
|
||||
"""
|
||||
for seq in self.sequences_set:
|
||||
if not seq.check_finish():
|
||||
seq.status = RequestStatus.ABORTED
|
||||
self.sequences_set.clear()
|
||||
if self.is_prompts:
|
||||
self.sequences_set.clear()
|
||||
|
||||
else:
|
||||
for seq in self.sequences_set:
|
||||
seq.mark_aborted()
|
||||
if seq.check_finish():
|
||||
seq.mark_finished()
|
||||
|
||||
self.sequences_set.clear()
|
||||
|
||||
def fliter_batch(self) -> List["Sequence"]:
|
||||
"""
|
||||
@@ -255,6 +268,12 @@ class BatchInfo:
|
||||
continue
|
||||
self.sequences_set.add(seq)
|
||||
|
||||
def del_seq(self, seq: Sequence) -> Sequence:
|
||||
"""
|
||||
Delete sequence in batch
|
||||
"""
|
||||
self.sequences_set.discard(seq)
|
||||
|
||||
@property
|
||||
def is_empty(self) -> None:
|
||||
"""
|
||||
@@ -297,11 +316,19 @@ class BatchInfo:
|
||||
|
||||
for seq in self.sequences_set:
|
||||
if self.is_prompts:
|
||||
input_list.append(seq.input_token_id)
|
||||
if seq.output_len > 0:
|
||||
print(seq.output_token_id)
|
||||
seq_data = seq.input_token_id + seq.output_token_id
|
||||
print(seq_data)
|
||||
input_list.append(seq.input_token_id + seq.output_token_id)
|
||||
else:
|
||||
input_list.append(seq.input_token_id)
|
||||
else:
|
||||
input_list.append([seq.output_token_id[-1]])
|
||||
|
||||
return torch.tensor(input_list, dtype=torch.long, device=self.device)
|
||||
max_seq_len = max(len(sub_list) for sub_list in input_list)
|
||||
|
||||
return _make_tensor_with_pad(input_list, max_seq_len, 0, dtype=torch.int)
|
||||
|
||||
def get_1D_inputs(self) -> Tuple[torch.LongTensor, torch.Tensor]:
|
||||
"""
|
||||
@@ -340,12 +367,27 @@ class BatchInfo:
|
||||
for seq in self.sequences_set:
|
||||
past_values.append(seq.input_token_id + seq.output_token_id)
|
||||
|
||||
attn_mask = torch.tensor(past_values, dtype=torch.int, device=self.device).ne(padding_id).long()
|
||||
max_seq_len = max(len(sub_list) for sub_list in past_values)
|
||||
attn_mask = _make_tensor_with_pad(past_values, max_seq_len, 0, dtype=torch.int, device=self.device)
|
||||
|
||||
if torch.any(attn_mask == 0):
|
||||
return attn_mask
|
||||
else:
|
||||
return None
|
||||
return attn_mask.ne(padding_id).long()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"(sequences_set={self.sequences_set}, " f"is_prompts={self.is_prompts})"
|
||||
|
||||
|
||||
def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
|
||||
assert len(x) <= max_len
|
||||
return x + [pad] * (max_len - len(x))
|
||||
|
||||
|
||||
def _make_tensor_with_pad(
|
||||
x: Union[List[List[int]], List[int]],
|
||||
max_len: int,
|
||||
pad: int,
|
||||
dtype: torch.dtype,
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
pin_memory: bool = False,
|
||||
):
|
||||
padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
|
||||
return torch.tensor(padded_x, dtype=dtype, device=device, pin_memory=pin_memory and str(device) == "cpu")
|
||||
|
Reference in New Issue
Block a user