[Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)

* Diffusion Model Inference support

* Stable Diffusion 3 Support

* PixArtAlpha support
Runyu Lu
2024-07-08 16:02:07 +08:00
committed by GitHub
parent 8ec24b6a4d
commit cba20525a8
16 changed files with 1860 additions and 740 deletions


@@ -8,7 +8,7 @@ from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager
from colossalai.inference.struct import RequestStatus, Sequence
from colossalai.inference.struct import DiffusionSequence, RequestStatus, Sequence
from colossalai.logging import get_dist_logger
logger = get_dist_logger(__name__)
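
The new import pulls DiffusionSequence into the request handler. Its definition lives in colossalai.inference.struct and is not shown in this hunk; as a rough orientation, a minimal sketch of the shape the handler below relies on (only request_id is actually accessed here, the other fields are illustrative assumptions) could look like:

from dataclasses import dataclass
from typing import Any

@dataclass
class DiffusionSequenceSketch:
    # Illustrative stand-in, not the actual colossalai definition.
    request_id: int          # unique id used by add_sequence / _find_sequence
    prompt: str              # text prompt forwarded to the diffusion model
    generation_config: Any   # e.g. image size, number of denoising steps, guidance scale
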
@@ -98,7 +98,46 @@ class RunningList:
self._decoding[seq_id] = self._prefill.pop(seq_id)
class RequestHandler:
class NaiveRequestHandler:
def __init__(self) -> None:
        self.running_list: List[DiffusionSequence] = []  # requests currently being processed
        self.waiting_list: List[DiffusionSequence] = []  # requests queued via add_sequence, popped by schedule()
    def _has_waiting(self) -> bool:
        return len(self.waiting_list) > 0
    def _has_running(self) -> bool:
        return len(self.running_list) > 0
def check_unfinished_reqs(self):
return self._has_waiting() or self._has_running()
def add_sequence(self, seq: DiffusionSequence):
"""
        Add the request to the waiting list.
"""
assert not self._find_sequence(seq.request_id), f"Sequence {seq.request_id} already exists."
self.waiting_list.append(seq)
def _find_sequence(self, request_id: int) -> DiffusionSequence:
"""
Find the request by request_id.
"""
        for seq in self.waiting_list + self.running_list:
            if seq.request_id == request_id:
                return seq
        return None
def schedule(self):
ret = None
        if self._has_waiting():
            ret = self.waiting_list[0]
            self.waiting_list = self.waiting_list[1:]
return ret
class RequestHandler(NaiveRequestHandler):
"""
    RequestHandler is the core component for handling existing requests and updating the current batch.
    During the generation process, the schedule function is called at each iteration to update the current batch.
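
For orientation, here is a hedged sketch of how an engine-side loop might drive the new NaiveRequestHandler; the engine and model names are illustrative and do not appear in this diff. Requests are queued with add_sequence, popped one at a time by schedule, and check_unfinished_reqs tells the caller when to stop:

# Hypothetical driver code, not part of this commit.
handler = NaiveRequestHandler()
handler.add_sequence(DiffusionSequence(request_id=0, prompt="an astronaut riding a horse", generation_config=diffusion_cfg))
while handler.check_unfinished_reqs():
    seq = handler.schedule()               # pops the oldest waiting request, or returns None
    if seq is None:
        break
    image = diffusion_model.generate(seq)  # illustrative model call

Note that schedule only pops from waiting_list and nothing in this class appends to running_list, so in practice a diffusion request is handed off for a full denoising run as soon as it is scheduled.
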
@@ -176,12 +215,12 @@ class RequestHandler:
generated_token_size=inference_config.generated_token_size,
)
def _has_running(self) -> bool:
return not self.running_bb.is_empty()
def _init_cache(self, model_config):
self.cache_manager = KVCacheManager(self.inference_config, model_config)
def _has_waiting(self) -> bool:
return any(lst for lst in self.waiting_list)
def get_kvcache(self):
return self.cache_manager.get_kv_cache()
@@ -318,7 +357,7 @@ class RequestHandler:
if seq.output_token_id[-1] == generation_config.eos_token_id or seq.output_len >= max_new_tokens:
seq.mark_finished()
def check_unfinished_seqs(self) -> bool:
def check_unfinished_reqs(self) -> bool:
return self._has_waiting() or not self.running_list.is_empty()
def total_requests_in_batch_bucket(self) -> int:
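
The rename from check_unfinished_seqs to check_unfinished_reqs gives the KV-cache-based RequestHandler the same entry point as NaiveRequestHandler, so a generation loop can be written against either handler. A hedged sketch of what such a shared loop might look like on the engine side (function and variable names here are illustrative, not taken from this commit):

# Hypothetical engine loop that relies only on the shared handler interface.
def run_until_done(handler, step_fn):
    results = []
    while handler.check_unfinished_reqs():    # same method name on both handler classes
        scheduled = handler.schedule()        # output type differs per handler; treated opaquely here
        results.append(step_fn(scheduled))
    return results
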