Mirror of https://github.com/hpcaitech/ColossalAI.git
[Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)
* Diffusion Model Inference support
* Stable Diffusion 3 support
* PixArtAlpha support
@@ -8,7 +8,7 @@ from colossalai.inference.batch_bucket import BatchBucket
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager
-from colossalai.inference.struct import RequestStatus, Sequence
+from colossalai.inference.struct import DiffusionSequence, RequestStatus, Sequence
 from colossalai.logging import get_dist_logger
 
 logger = get_dist_logger(__name__)
@@ -98,7 +98,46 @@ class RunningList:
         self._decoding[seq_id] = self._prefill.pop(seq_id)
 
 
-class RequestHandler:
+class NaiveRequestHandler:
+    def __init__(self) -> None:
+        self.running_list: List[DiffusionSequence] = []
+        self.waiting_list: List[DiffusionSequence] = []
+
+    def _has_waiting(self) -> bool:
+        return any(lst for lst in self.waiting_list)
+
+    def _has_running(self) -> bool:
+        return any(lst for lst in self.running_list)
+
+    def check_unfinished_reqs(self):
+        return self._has_waiting() or self._has_running()
+
+    def add_sequence(self, seq: DiffusionSequence):
+        """
+        Add the request to the waiting list.
+        """
+        assert not self._find_sequence(seq.request_id), f"Sequence {seq.request_id} already exists."
+        self.waiting_list.append(seq)
+
+    def _find_sequence(self, request_id: int) -> DiffusionSequence:
+        """
+        Find the request by request_id.
+        """
+        for lst in [self.waiting_list, self.running_list]:
+            for seq in lst:
+                if seq.request_id == request_id:
+                    return seq
+        return None
+
+    def schedule(self):
+        ret = None
+        if self._has_waiting():
+            ret = self.waiting_list[0]
+            self.waiting_list = self.waiting_list[1:]
+        return ret
+
+
+class RequestHandler(NaiveRequestHandler):
     """
     RequestHandler is the core for handling existing requests and updating current batch.
     During generation process, we call schedule function each iteration to update current batch.
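The NaiveRequestHandler added above is a plain FIFO scheduler for diffusion requests: a request waits in waiting_list until schedule() pops it, and check_unfinished_reqs() tells the caller whether any work remains. The following self-contained sketch mirrors that behavior; DemoSequence and DemoNaiveRequestHandler are hypothetical stand-ins for colossalai's DiffusionSequence and NaiveRequestHandler (whose exact constructors are not shown in this diff), so the snippet can be run on its own:

    from dataclasses import dataclass
    from typing import List, Optional


    @dataclass
    class DemoSequence:
        # Stand-in for colossalai.inference.struct.DiffusionSequence;
        # only the fields the handler actually touches are modeled.
        request_id: int
        prompt: str


    class DemoNaiveRequestHandler:
        """FIFO request handler mirroring the NaiveRequestHandler in the diff."""

        def __init__(self) -> None:
            self.running_list: List[DemoSequence] = []
            self.waiting_list: List[DemoSequence] = []

        def _find_sequence(self, request_id: int) -> Optional[DemoSequence]:
            for lst in (self.waiting_list, self.running_list):
                for seq in lst:
                    if seq.request_id == request_id:
                        return seq
            return None

        def add_sequence(self, seq: DemoSequence) -> None:
            # Reject duplicate request ids, like the assert in the diff.
            assert self._find_sequence(seq.request_id) is None, f"Sequence {seq.request_id} already exists."
            self.waiting_list.append(seq)

        def check_unfinished_reqs(self) -> bool:
            return bool(self.waiting_list) or bool(self.running_list)

        def schedule(self) -> Optional[DemoSequence]:
            # Pop exactly one request per call, oldest first (no batching).
            if self.waiting_list:
                return self.waiting_list.pop(0)
            return None


    handler = DemoNaiveRequestHandler()
    handler.add_sequence(DemoSequence(request_id=0, prompt="a photo of a cat"))
    handler.add_sequence(DemoSequence(request_id=1, prompt="an oil painting of a dog"))

    while handler.check_unfinished_reqs():
        req = handler.schedule()
        print(f"running diffusion request {req.request_id}: {req.prompt!r}")

Unlike the token-level RequestHandler below, there is no batch bucket or KV cache involved: each diffusion request is handed out whole, one per scheduling step.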
@@ -176,12 +215,12 @@ class RequestHandler:
             generated_token_size=inference_config.generated_token_size,
         )
 
+    def _has_running(self) -> bool:
+        return not self.running_bb.is_empty()
+
     def _init_cache(self, model_config):
         self.cache_manager = KVCacheManager(self.inference_config, model_config)
 
-    def _has_waiting(self) -> bool:
-        return any(lst for lst in self.waiting_list)
-
     def get_kvcache(self):
         return self.cache_manager.get_kv_cache()
@@ -318,7 +357,7 @@ class RequestHandler:
             if seq.output_token_id[-1] == generation_config.eos_token_id or seq.output_len >= max_new_tokens:
                 seq.mark_finished()
 
-    def check_unfinished_seqs(self) -> bool:
+    def check_unfinished_reqs(self) -> bool:
        return self._has_waiting() or not self.running_list.is_empty()
 
     def total_requests_in_batch_bucket(self) -> int:
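The rename in the last hunk (check_unfinished_seqs to check_unfinished_reqs) lines the LLM-side RequestHandler up with the new NaiveRequestHandler, so one generation loop can drive either handler. Below is a minimal, hedged sketch of that calling pattern only; SchedulableHandler, generation_loop, and run_step are hypothetical names introduced here for illustration, not engine code, and the real per-iteration loop additionally builds batches and manages the KV cache:

    from typing import Callable, Protocol


    class SchedulableHandler(Protocol):
        """The handler surface the generation loop relies on after this commit."""

        def check_unfinished_reqs(self) -> bool: ...

        def schedule(self): ...


    def generation_loop(handler: SchedulableHandler, run_step: Callable) -> None:
        # Called once per iteration, as the RequestHandler docstring describes:
        # schedule() picks what to work on next; run_step() stands in for one
        # model forward pass (an LLM decoding step or a full diffusion run).
        while handler.check_unfinished_reqs():
            scheduled = handler.schedule()
            if scheduled is not None:
                run_step(scheduled)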