[inference] removed redundancy init_batch (#5353)

This commit is contained in:
Frank Lee
2024-02-02 11:44:15 +08:00
committed by GitHub
parent 249644c23b
commit db1a763307
3 changed files with 6 additions and 25 deletions

View File

@@ -188,24 +188,6 @@ class BatchInfo:
if self.fd_inter_tensor is None:
self.fd_inter_tensor = FDIntermTensors()
def init_batch(self, seqs: List["Sequence"] = None):
"""
Initializes inference batches by input sentence list.
Args:
seqs (List["Sequence"]): List of input sequence.
"""
if seqs is not None:
if not isinstance(seqs, list):
seqs = [seqs]
for seq in seqs:
if seq in self.sequences_set:
logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
continue
self.sequences_set.add(seq)
def init_fd_tensors(self):
if not self.fd_inter_tensor.is_initialized:
self.fd_inter_tensor.initialize(
@@ -273,19 +255,19 @@ class BatchInfo:
self.sequences_set.discard(seq)
return seq
def add_seqs(self, seqs: List["Sequence"]) -> None:
def add_seqs(self, seqs: Union[Sequence, List[Sequence]]) -> None:
"""
Add new sequence to batch
Args:
seqs (List["Sequence"]): The list of new sequences.
"""
if not isinstance(seqs, list):
# covnert single sequence to list
if isinstance(seqs, Sequence):
seqs = [seqs]
for seq in seqs:
if self.sequences_set and seq in self.sequences_set:
if seq in self.sequences_set:
logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
continue
self.sequences_set.add(seq)