Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-06-21 21:22:04 +00:00
[Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)

* Diffusion model inference support
* Stable Diffusion 3 support
* PixArt-Alpha support

This commit is contained in:
parent 8ec24b6a4d
commit cba20525a8
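For orientation, here is a minimal usage sketch of the feature this commit introduces, based only on the APIs visible in the diff below. The model identifier, the InferenceConfig fields, and the final save call are illustrative assumptions rather than part of the commit.

```python
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

# (distributed initialization, e.g. via ColossalAI's launch utilities, may be required first)
inference_config = InferenceConfig(dtype="fp16")  # hypothetical config values
engine = InferenceEngine(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # any diffusers pipeline id/path; illustrative
    inference_config=inference_config,
)

# Diffusion-specific arguments travel through DiffusionGenerationConfig / **kwargs.
images = engine.generate(
    prompts=["a red panda reading a book"],
    generation_config=DiffusionGenerationConfig(num_inference_steps=28, guidance_scale=7.0),
)
images[0].save("out.png")  # assuming the pipeline returns PIL images for the default output type
```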
colossalai/inference/config.py
@@ -5,7 +5,7 @@ Our config contains various options for inference optimization, it is a unified
 import logging
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, fields
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 import torch
 from transformers.generation import GenerationConfig
@@ -396,3 +396,49 @@ class ModelShardInferenceConfig:
     use_cuda_kernel: bool = False
     use_spec_dec: bool = False
     use_flash_attn: bool = False
+
+
+@dataclass
+class DiffusionGenerationConfig:
+    """
+    Param for diffusion model forward
+    """
+
+    prompt_2: Optional[Union[str, List[str]]] = None
+    prompt_3: Optional[Union[str, List[str]]] = None
+    height: Optional[int] = None
+    width: Optional[int] = None
+    num_inference_steps: int = None
+    timesteps: List[int] = None
+    guidance_scale: float = None
+    negative_prompt: Optional[Union[str, List[str]]] = (
+        None  # NOTE(@lry89757) in pixart default to "", in sd3 default to None
+    )
+    negative_prompt_2: Optional[Union[str, List[str]]] = None
+    negative_prompt_3: Optional[Union[str, List[str]]] = None
+    num_images_per_prompt: Optional[int] = None
+    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None
+    latents: Optional[torch.FloatTensor] = None
+    prompt_embeds: Optional[torch.FloatTensor] = None
+    negative_prompt_embeds: Optional[torch.FloatTensor] = None
+    pooled_prompt_embeds: Optional[torch.FloatTensor] = None
+    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None
+    output_type: Optional[str] = None  # "pil"
+    return_dict: bool = None
+    joint_attention_kwargs: Optional[Dict[str, Any]] = None
+    clip_skip: Optional[int] = None
+    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None
+    callback_on_step_end_tensor_inputs: List[str] = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        # NOTE(@lry89757) Only return the dict that not the default value None
+        result = {}
+        for field in fields(self):
+            value = getattr(self, field.name)
+            if value is not None:
+                result[field.name] = value
+        return result
+
+    @classmethod
+    def from_kwargs(cls, **kwargs) -> "DiffusionGenerationConfig":
+        return cls(**kwargs)
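A short sketch of how the new config class is meant to be used (values here are illustrative): fields left at their default None are dropped by to_dict(), so the underlying diffusers pipeline falls back to its own per-pipeline defaults.

```python
from colossalai.inference.config import DiffusionGenerationConfig

cfg = DiffusionGenerationConfig(height=1024, width=1024, guidance_scale=7.0)
print(cfg.to_dict())  # {'height': 1024, 'width': 1024, 'guidance_scale': 7.0}

# from_kwargs() builds the same object from loose keyword arguments,
# which is how DiffusionEngine.add_request() collects per-request parameters.
cfg2 = DiffusionGenerationConfig.from_kwargs(num_inference_steps=20, clip_skip=1)
```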
colossalai/inference/core/base_engine.py (new file, 90 lines)
@@ -0,0 +1,90 @@
from abc import ABC, abstractmethod

import torch
import torch.nn as nn

from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy


class BaseEngine(ABC):
    @abstractmethod
    def __init__(self, model_or_path, inference_config=None, verbose=False, model_policy=None):
        pass

    @abstractmethod
    def init_model(self, model_or_path, model_policy=None, model_shard_infer_config=None):
        """
        Init Model for Engine
        """

    @abstractmethod
    def generate(self, request_ids=None, prompts=None, generation_config=None, **kwargs):
        """
        Generate output for coming requests
        """

    @abstractmethod
    def add_request(self, prompts, request_ids=None, **kwargs):
        """
        Add new request to Engine
        """

    @abstractmethod
    def step(self):
        """
        Perform one new step forward
        """

    @abstractmethod
    def _verify_args(self):
        """
        Verify the parameters and members of class
        """

    @torch.inference_mode()
    def capture_model(self):
        """
        Use cuda graph to capture model
        """
        return NotImplementedError("This method should be implemented by subclasses")

    def _shardformer(
        self,
        model: nn.Module,
        model_policy: Policy,
        model_shard_infer_config: ModelShardInferenceConfig = None,
        stage_manager: PipelineStageManager = None,
        tp_group: ProcessGroupMesh = None,
        **kwargs,
    ) -> nn.Module:
        """
        Initialize ShardConfig and replace the model with shardformer.

        Args:
            model (nn.Module): Path or nn.Module of this model.
            model_policy (Policy): The policy to shardformer model which is determined by the model type.
            stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
            tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.

        Returns:
            nn.Module: The model optimized by Shardformer.
        """

        shardconfig = ShardConfig(
            tensor_parallel_process_group=tp_group,
            pipeline_stage_manager=stage_manager,
            enable_tensor_parallelism=(self.inference_config.tp_size > 1),
            enable_fused_normalization=False,
            enable_all_optimization=False,
            enable_flash_attention=False,
            enable_jit_fused=False,
            enable_sequence_parallelism=False,
            extra_kwargs={"model_shard_infer_config": model_shard_infer_config, **kwargs},
        )
        shardformer = ShardFormer(shard_config=shardconfig)
        shard_model, _ = shardformer.optimize(model, model_policy)
        return shard_model
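To make the contract concrete, here is a minimal, hypothetical subclass; it is not part of the commit and skips sharding entirely, but it shows which methods a concrete engine must provide and how generate() is expected to drain queued requests through step().

```python
from colossalai.inference.core.base_engine import BaseEngine


class EchoEngine(BaseEngine):
    """Toy engine that just echoes prompts back; illustrates the abstract interface only."""

    def __init__(self, model_or_path, inference_config=None, verbose=False, model_policy=None):
        self.inference_config = inference_config
        self.queue = []

    def init_model(self, model_or_path, model_policy=None, model_shard_infer_config=None):
        self.model = model_or_path  # no sharding in this toy example

    def add_request(self, prompts, request_ids=None, **kwargs):
        self.queue.extend(prompts if isinstance(prompts, list) else [prompts])

    def step(self):
        return [self.queue.pop(0)] if self.queue else []

    def generate(self, request_ids=None, prompts=None, generation_config=None, **kwargs):
        if prompts is not None:
            self.add_request(prompts, request_ids)
        outputs = []
        while self.queue:
            outputs += self.step()
        return outputs

    def _verify_args(self):
        assert self.inference_config is not None
```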
colossalai/inference/core/diffusion_engine.py (new file, 200 lines)
@@ -0,0 +1,200 @@
from itertools import count
from typing import List, Tuple, Type, Union

import numpy as np
import PIL.Image
import torch
import torch.nn as nn
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from torch import distributed as dist

from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig, ModelShardInferenceConfig
from colossalai.inference.modeling.models.diffusion import DiffusionPipe
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.struct import DiffusionSequence
from colossalai.inference.utils import get_model_size, get_model_type
from colossalai.logging import get_dist_logger
from colossalai.shardformer.policies.base_policy import Policy

from .base_engine import BaseEngine
from .request_handler import NaiveRequestHandler

PP_AXIS, TP_AXIS = 0, 1


class DiffusionEngine(BaseEngine):
    def __init__(
        self,
        model_or_path: DiffusionPipeline | str,
        inference_config: InferenceConfig = None,
        verbose: bool = False,
        model_policy: Policy | type[Policy] = None,
    ) -> None:
        self.inference_config = inference_config
        self.dtype = inference_config.dtype
        self.high_precision = inference_config.high_precision

        self.verbose = verbose
        self.logger = get_dist_logger(__name__)
        self.model_shard_infer_config = inference_config.to_model_shard_inference_config()

        self.model_type = get_model_type(model_or_path=model_or_path)

        self.init_model(model_or_path, model_policy, self.model_shard_infer_config)

        self.request_handler = NaiveRequestHandler()

        self.counter = count()

        self._verify_args()

    def _verify_args(self) -> None:
        assert isinstance(self.model, DiffusionPipe), "model must be DiffusionPipe"

    def init_model(
        self,
        model_or_path: Union[str, nn.Module, DiffusionPipeline],
        model_policy: Union[Policy, Type[Policy]] = None,
        model_shard_infer_config: ModelShardInferenceConfig = None,
    ):
        """
        Shard model or/and Load weight

        Args:
            model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
            model_policy (Policy): the policy to replace the model.
            model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
        """
        if isinstance(model_or_path, str):
            model = DiffusionPipeline.from_pretrained(model_or_path, torch_dtype=self.dtype)
            policy_map_key = model.__class__.__name__
            model = DiffusionPipe(model)
        elif isinstance(model_or_path, DiffusionPipeline):
            policy_map_key = model_or_path.__class__.__name__
            model = DiffusionPipe(model_or_path)
        else:
            self.logger.error(f"model_or_path support only str or DiffusionPipeline currently!")

        torch.cuda.empty_cache()
        init_gpu_memory = torch.cuda.mem_get_info()[0]

        self.device = get_accelerator().get_current_device()
        if self.verbose:
            self.logger.info(f"the device is {self.device}")

        if self.verbose:
            self.logger.info(
                f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
            )

        if model_policy is None:
            model_policy = model_policy_map.get(policy_map_key)

        if not isinstance(model_policy, Policy):
            try:
                model_policy = model_policy()
            except Exception as e:
                raise ValueError(f"Unable to instantiate model policy: {e}")

        assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
        pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
        tp_group = pg_mesh.get_group_along_axis(TP_AXIS)

        self.model = self._shardformer(
            model,
            model_policy,
            model_shard_infer_config,
            None,
            tp_group=tp_group,
        )

        self.model = model.to(self.device)

        if self.verbose:
            self.logger.info(
                f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
            )

        free_gpu_memory, _ = torch.cuda.mem_get_info()
        peak_memory = init_gpu_memory - free_gpu_memory
        if self.verbose:
            self.logger.info(
                f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
            )

    def generate(
        self,
        request_ids: Union[List[int], int] = None,
        prompts: Union[List[str], str] = None,
        generation_config: DiffusionGenerationConfig = None,
        **kwargs,
    ) -> Union[List[Union[str, List[PIL.Image.Image], np.ndarray]], Tuple[List[str], List[List[int]]]]:
        """ """
        gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
        prompts = [prompts] if isinstance(prompts, str) else prompts
        request_ids = [request_ids] if isinstance(request_ids, int) else request_ids

        with torch.inference_mode():
            if prompts is not None:
                self.add_request(
                    request_ids=request_ids,
                    prompts=prompts,
                    **gen_config_dict,
                    **kwargs,
                )

            output_reqs_list = []

            # intuition: If user provide a generation config, we should replace the existing one.
            if generation_config is not None:
                self.generation_config = generation_config
                self.generation_config_dict = gen_config_dict

            while self.request_handler.check_unfinished_reqs():
                output_reqs_list += self.step()

            return output_reqs_list

    def add_request(
        self,
        prompts: Union[List[str], str],
        request_ids: Union[List[int], int] = None,
        **kwargs,
    ):
        if request_ids is not None and not isinstance(request_ids, list):
            request_ids = [request_ids]

        if not isinstance(prompts, list):
            prompts = [prompts]

        generation_config = DiffusionGenerationConfig.from_kwargs(**kwargs)
        prompts_num = len(prompts)
        for i in range(prompts_num):
            if request_ids:
                assert isinstance(
                    request_ids[0], int
                ), f"The request_id type must be int, but got {type(request_ids[0])}"
                assert len(request_ids) == prompts_num
                request_id = request_ids[i]
            else:
                request_id = next(self.counter)

            seq = DiffusionSequence(request_id=request_id, prompt=prompts[i], generation_config=generation_config)

            self.request_handler.add_sequence(seq)

    def step(self) -> List[PIL.Image.Image]:
        """
        In each step, do the follows:
            1. Run RequestHandler.schedule() and get the batch used for inference.
            2. Run forward to get List[Image]
        Returns:
            List[PIL.Image.Image]: Image Generated by one step.
        """

        input = self.request_handler.schedule()
        ret = self.model(prompt=input.prompt, **input.generation_config.to_dict())
        return ret
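A sketch of the request flow this engine implements (the model id and parameters are illustrative assumptions): add_request() packs any diffusion keyword arguments into a DiffusionGenerationConfig and enqueues a DiffusionSequence, and each step() schedules one request and runs it through the sharded DiffusionPipe.

```python
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.diffusion_engine import DiffusionEngine

engine = DiffusionEngine(
    "PixArt-alpha/PixArt-XL-2-1024-MS",   # illustrative diffusers pipeline id
    inference_config=InferenceConfig(),   # assumed: default config fields are sufficient here
)

# kwargs are converted into a DiffusionGenerationConfig via from_kwargs()
engine.add_request(prompts="an astronaut riding a horse", num_inference_steps=20, height=768, width=768)

while engine.request_handler.check_unfinished_reqs():
    images = engine.step()  # images produced for the scheduled request
```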
colossalai/inference/core/engine.py
@@ -1,57 +1,24 @@
-import time
-from itertools import count
-from typing import Dict, List, Optional, Tuple, Type, Union
+from typing import List, Tuple, Type, Union

 import numpy as np
-import torch
+import PIL.Image
 import torch.nn as nn
-from torch import distributed as dist
-from transformers import (
-    AutoConfig,
-    AutoModelForCausalLM,
-    GenerationConfig,
-    PreTrainedTokenizer,
-    PreTrainedTokenizerFast,
-)
-from transformers.models.llama.modeling_llama import LlamaForCausalLM
+from diffusers import DiffusionPipeline
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

-from colossalai.accelerator import get_accelerator
-from colossalai.cluster import ProcessGroupMesh
-from colossalai.inference.batch_bucket import BatchBucket
-from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig
-from colossalai.inference.graph_runner import CUDAGraphRunner
-from colossalai.inference.modeling.policy import model_policy_map
-from colossalai.inference.sampler import search_tokens
-from colossalai.inference.spec import Drafter, GlideInput
-from colossalai.inference.struct import Sequence
-from colossalai.inference.utils import get_model_size, has_index_file
-from colossalai.interface import ModelWrapper
-from colossalai.lazy import LazyInitContext
-from colossalai.logging import get_dist_logger
-from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.inference.config import InferenceConfig
+from colossalai.inference.utils import ModelType, get_model_type
 from colossalai.shardformer.policies.base_policy import Policy

-from .request_handler import RequestHandler
-
 __all__ = ["InferenceEngine"]

-PP_AXIS, TP_AXIS = 0, 1
-
-_supported_models = {
-    "LlamaForCausalLM": LlamaForCausalLM,
-    "BaichuanForCausalLM": AutoModelForCausalLM,
-}
-
-_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
-

 class InferenceEngine:
     """
     InferenceEngine which manages the inference process..

     Args:
-        model_or_path (nn.Module or str): Path or nn.Module of this model.
+        model_or_path (nn.Module or DiffusionPipeline or str): Path or nn.Module or DiffusionPipeline of this model.
         tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use.
         inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
         verbose (bool): Determine whether or not to log the generation process.
@@ -60,567 +27,68 @@ class InferenceEngine:

     def __init__(
         self,
-        model_or_path: Union[nn.Module, str],
-        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-        inference_config: InferenceConfig,
+        model_or_path: Union[nn.Module, str, DiffusionPipeline],
+        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
+        inference_config: InferenceConfig = None,
         verbose: bool = False,
         model_policy: Union[Policy, Type[Policy]] = None,
     ) -> None:
-        self.inference_config = inference_config
-        self.dtype = inference_config.dtype
-        self.high_precision = inference_config.high_precision
-
-        self.verbose = verbose
-        self.logger = get_dist_logger(__name__)
-        self.model_shard_infer_config = inference_config.to_model_shard_inference_config()
-
-        self.init_model(model_or_path, model_policy, self.model_shard_infer_config)
-
-        self.generation_config = inference_config.to_generation_config(self.model_config)
-        self.generation_config_dict = self.generation_config.to_dict()
-
-        self.tokenizer = tokenizer
-        self.tokenizer.pad_token = self.tokenizer.eos_token
-
-        self.request_handler = RequestHandler(self.inference_config, self.model_config)
-        self.k_cache, self.v_cache = self.request_handler.get_kvcache()
-        # DISCUSS maybe move this into batch info?
-
-        self.counter = count()
-
-        self.use_cuda_graph = self.inference_config.use_cuda_graph
-        if self.use_cuda_graph:
-            self.graph_runners: Dict[int, CUDAGraphRunner] = {}
-            self.graph_memory_pool = None  # Set during graph capture.
-            if verbose:
-                self.logger.info("Colossal AI CUDA Graph Capture on")
-
-            self.capture_model(self.k_cache, self.v_cache)
-
-        # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
-        self.use_spec_dec = self.inference_config.use_spec_dec
-
-        self.drafter_model = None
-        self.drafter = None
-        self.use_glide = False
-        self.n_spec_tokens = self.inference_config.max_n_spec_tokens
-
+        self.__dict__["_initialized"] = False  # use __dict__ directly to avoid calling __setattr__
+        self.model_type = get_model_type(model_or_path=model_or_path)
+        self.engine = None
+        if self.model_type == ModelType.LLM:
+            from .llm_engine import LLMEngine
+
+            self.engine = LLMEngine(
+                model_or_path=model_or_path,
+                tokenizer=tokenizer,
+                inference_config=inference_config,
+                verbose=verbose,
+                model_policy=model_policy,
+            )
+        elif self.model_type == ModelType.DIFFUSION_MODEL:
+            from .diffusion_engine import DiffusionEngine
+
+            self.engine = DiffusionEngine(
+                model_or_path=model_or_path,
+                inference_config=inference_config,
+                verbose=verbose,
+                model_policy=model_policy,
+            )
+        elif self.model_type == ModelType.UNKNOWN:
+            self.logger.error(f"Model Type either Diffusion or LLM!")
+
+        self._initialized = True
         self._verify_args()

-    def init_model(
-        self,
-        model_or_path: Union[nn.Module, str],
-        model_policy: Union[Policy, Type[Policy]] = None,
-        model_shard_infer_config: ModelShardInferenceConfig = None,
-    ):
-        """
-        Shard model or/and Load weight
-
-        Args:
-            model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
-            model_policy (Policy): the policy to replace the model.
-            model_inference_config: the configuration for modeling initialization when inference.
-            model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
-        """
-        pretrained_path = None
-        if isinstance(model_or_path, str):
-            import colossalai.interface.pretrained as pretrained_utils
-
-            try:
-                hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
-                arch = getattr(hf_config, "architectures")[0]
-                if arch in _supported_models.keys():
-                    if arch is "BaichuanForCausalLM":
-                        self.logger.warning(
-                            "Attention ! We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers"
-                        )
-                    ctx = LazyInitContext(default_device="cuda")
-                    with ctx:
-                        model = _supported_models[arch].from_pretrained(
-                            model_or_path, trust_remote_code=True, torch_dtype=self.dtype
-                        )
-                    pretrained_path = pretrained_utils.get_pretrained_path(model)
-                else:
-                    # TODO(char-1ee): if the model not supported, use transformers APIs to load and generate
-                    raise ValueError(f"Model {arch} is not supported.")
-
-            except Exception as e:
-                self.logger.error(
-                    f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
-                )
-        else:
-            model = model_or_path
-
-        self.model_config = model.config
-
-        torch.cuda.empty_cache()
-        init_gpu_memory = torch.cuda.mem_get_info()[0]
-
-        self.device = get_accelerator().get_current_device()
-        if self.verbose:
-            self.logger.info(f"the device is {self.device}")
-
-        model = model.to(self.dtype).eval()
-
-        if self.verbose:
-            self.logger.info(
-                f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
-            )
-
-        if model_policy is None:
-            prefix = "nopadding" if not self.inference_config.pad_input else "padding"
-            model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}"
-            model_policy = model_policy_map.get(model_policy_key)
-
-        if not isinstance(model_policy, Policy):
-            try:
-                model_policy = model_policy()
-            except Exception as e:
-                raise ValueError(f"Unable to instantiate model policy: {e}")
-
-        assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
-        pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
-        tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
-
-        self.model = self._shardformer(
-            model,
-            model_policy,
-            model_shard_infer_config,
-            None,
-            tp_group=tp_group,
-        )
-
-        self.model = ModelWrapper(model).to(self.device)
-
-        if self.verbose:
-            self.logger.info(
-                f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
-            )
-
-        if pretrained_path:
-            from colossalai.inference.core.plugin import InferCheckpoint_io
-
-            cpt_io = InferCheckpoint_io()
-            if_has_index_file, model_index_file = has_index_file(pretrained_path)
-            assert if_has_index_file, "the model path is invalid"
-            cpt_io.load_model(self.model, model_index_file)
-
-        free_gpu_memory, _ = torch.cuda.mem_get_info()
-        peak_memory = init_gpu_memory - free_gpu_memory
-        if self.verbose:
-            self.logger.info(
-                f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
-            )
-
-    @torch.inference_mode()
-    def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]):
-        assert self.use_cuda_graph, "please turn on the cuda graph"
-
-        if self.verbose:
-            self.logger.info("Colossal AI CUDA Graph Capture begin")
-
-        t_capture_begin = time.perf_counter()
-
-        block_size = self.inference_config.block_size
-        head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads
-
-        # Prepare dummy inputs. These will be reused for all batch sizes.
-        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
-        max_context_len_to_capture = self.inference_config.max_context_len_to_capture
-        max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
-        input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda()
-        # self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
-        self.graph_block_tables = np.full((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), -1, dtype=np.int32)
-        self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE))
-        self.graph_block_tables[0, :] = np.arange(
-            0, max_num_blocks
-        )  # NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len
-        block_tables = torch.from_numpy(self.graph_block_tables).cuda()
-        output_tensor = torch.zeros(
-            (max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
-        )
-        fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor
-
-        max_num_seqs = self.inference_config.max_batch_size
-        batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
-        sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda()
-        # NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len
-        sequence_lengths[0] = torch.tensor(
-            self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32
-        ).cuda()
-
-        # NOTE: Capturing the largest batch size first may help reduce the
-        # memory usage of CUDA graph.
-        for batch_size in reversed(batch_size_capture_list):
-            if self.verbose:
-                self.logger.info(f"batch size {batch_size} graph capturing")
-
-            input_meta_data = InputMetaData(
-                block_tables=block_tables[:batch_size],
-                sequence_lengths=sequence_lengths[:batch_size],
-                fd_inter_tensor=fd_inter_tensor,
-                batch_size=batch_size,
-                is_prompts=False,
-                use_cuda_graph=True,
-                high_precision=False,
-                kv_seq_len=sequence_lengths[:batch_size].max().item(),
-                head_dim=head_dim,
-                dtype=self.dtype,
-            )
-
-            graph_runner = CUDAGraphRunner(self.model)
-            graph_runner.capture(
-                input_tokens_ids[:batch_size],
-                output_tensor[:batch_size],
-                input_meta_data,
-                k_caches=k_cache,
-                v_caches=v_cache,
-                memory_pool=self.graph_memory_pool,
-            )
-            self.graph_memory_pool = graph_runner.graph.pool()
-            self.graph_runners[batch_size] = graph_runner
-
-        t_capture_end = time.perf_counter()
-
-        if self.verbose:
-            self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")
-
     def _verify_args(self) -> None:
         """Verify the input args"""
-        if not isinstance(self.inference_config, InferenceConfig):
-            raise TypeError("Invalid type of inference config provided.")
-        if not isinstance(self.model, nn.Module):
-            raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
-        if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
-            raise TypeError(
-                f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
-            )
-        if isinstance(self.model, ModelWrapper):
-            model = self.model.module
-        assert (
-            model.__class__.__name__ in _supported_models.keys()
-        ), f"Model {self.model.__class__.__name__} is not supported."
-
-    def _shardformer(
-        self,
-        model: nn.Module,
-        model_policy: Policy,
-        model_shard_infer_config: ModelShardInferenceConfig = None,
-        stage_manager: PipelineStageManager = None,
-        tp_group: ProcessGroupMesh = None,
-    ) -> nn.Module:
-        """
-        Initialize ShardConfig and replace the model with shardformer.
-
-        Args:
-            model (nn.Module): Path or nn.Module of this model.
-            model_policy (Policy): The policy to shardformer model which is determined by the model type.
-            stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
-            tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.
-
-        Returns:
-            nn.Module: The model optimized by Shardformer.
-        """
-
-        shardconfig = ShardConfig(
-            tensor_parallel_process_group=tp_group,
-            pipeline_stage_manager=stage_manager,
-            enable_tensor_parallelism=(self.inference_config.tp_size > 1),
-            enable_fused_normalization=False,
-            enable_all_optimization=False,
-            enable_flash_attention=False,
-            enable_jit_fused=False,
-            enable_sequence_parallelism=False,
-            extra_kwargs={"model_shard_infer_config": model_shard_infer_config},
-        )
-        shardformer = ShardFormer(shard_config=shardconfig)
-        shard_model, _ = shardformer.optimize(model, model_policy)
-        return shard_model
-
-    def enable_spec_dec(
-        self,
-        drafter_model: nn.Module = None,
-        n_spec_tokens: int = None,
-        use_glide_drafter: bool = False,
-    ) -> None:
-        """Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations.
-
-        Args:
-            drafter_model (nn.Module): The drafter model (small model) used to speculate tokens.
-                If provided, the previous drafter and drafter model, if exist, will be overwritten.
-            n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
-                If not provided, `max_n_spec_tokens` in InferenceConfig will be used.
-            use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False.
-                If True, the drafter model will be replaced by a glide model.
-
-            ```python
-            ...
-            engine = InferenceEngine(model, tokenizer, inference_config)
-
-            engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
-            engine.generate(...)  # Speculative Decoding
-
-            engine.disable_spec_dec()
-            engine.generate(...)  # Normal generation
-
-            engine.enable_spec_dec()
-            engine.generate(...)  # Speculative-Decoding using previously set drafter model and number of spec tokens
-            engine.clear_spec_dec()
-            ```
-        """
-
-        if drafter_model is None and self.drafter is None:
-            raise ValueError("Drafter not initialized. Please provide a Drafter Model")
-        if n_spec_tokens is not None:
-            assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens
-            self.n_spec_tokens = n_spec_tokens
-        if drafter_model is not None:
-            assert isinstance(drafter_model, nn.Module)
-            # overwrite the drafter, if exists
-            self.clear_spec_dec()
-            self.drafter_model = drafter_model
-            self.drafter = Drafter(
-                self.drafter_model,
-                self.tokenizer,
-                device=self.device,
-                dtype=self.dtype,
-            )
-
-            # check if the provided drafter model is compatible with GLIDE structure
-            # when `use_glide_drafter` is set to True
-            if (
-                use_glide_drafter
-                and hasattr(drafter_model, "model")
-                and hasattr(drafter_model.model, "layers")
-                and hasattr(drafter_model.model.layers[0], "cross_attn")
-            ):
-                self.use_glide = use_glide_drafter
-            elif use_glide_drafter:
-                self.logger.warning(
-                    f"`use_glide_drafter` is provided as {use_glide_drafter}, "
-                    f"but the provided drafter model is not compatible with GLIDE structure."
-                    f"Falling back to use the default drafter model (non-GLIDE)."
-                )
-        self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
-        # using speculative decoding for subsequent generations
-        self.use_spec_dec = True
-
-    def disable_spec_dec(self) -> None:
-        """Disable using speculative decoding for subsequent generations."""
-        self.request_handler.unset_spec_dec_mode()
-        # set back to the maximum number of tokens to speculate
-        self.n_spec_tokens = self.inference_config.max_n_spec_tokens
-        self.use_glide = False
-        self.use_spec_dec = False
-
-    def clear_spec_dec(self) -> None:
-        """Clear relatable structures of speculative decoding, if exist."""
-        if self.use_spec_dec:
-            self.disable_spec_dec()
-        if self.drafter_model or self.drafter:
-            self.drafter_model = None
-            self.drafter = None
-            torch.cuda.empty_cache()
-        self.use_glide = False
-        self.use_spec_dec = False
-
-    def steps_spec_dec(self) -> List[Sequence]:
-        """
-        Run Speculative Decoding steps. This is like retrieving a single batch and launch inference
-        with many steps of speculating by a drafter model as well as verifying by a main model.
-
-        Returns:
-            List[Sequence]: finished sequences generated by one step.
-        """
-        batch = self.request_handler.schedule()  # prefill batch
-        assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
-
-        input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
-
-        if input_meta_data.use_cuda_graph:
-            model_executable = self.graph_runners[input_meta_data.batch_size]
-        else:
-            model_executable = self.model
-
-        # 1. Prefill small model (Drafter) - fill past kv cache for drafter model
-        # NOTE For glide drafter models, we won't actually apply glide during prefill stage
-        drafter_out = self.drafter.speculate(input_token_ids, 1, None)
-        next_token_ids_spec = drafter_out.next_tokens
-        drafter_past_key_values = drafter_out.past_key_values
-
-        # 2. Prefill main model (Verifier) - fill past kv cache for main model
-        logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
-        next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
-        # append new inputs to the batch, temporarily
-        batch.append_batch_tokens(next_tokens)
-        self.request_handler.allocate_batch_spec_dec(batch, 1)
-        already_allocated_kv_len = batch.seq_lengths[0].item()
-        input_token_ids = batch.get_1D_inputs_spec_dec(1)
-
-        finished_sequences = self.request_handler.update()
-
-        while True:
-            # HACK Retrieve the running batch
-            #      Using RequestHandler.schedule here will re-allocate same kv cache for the batch
-            batch = self.request_handler.running_bb  # running batch
-            assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
-
-            # 3. Decoding - Drafter model speculates `n` tokens
-            glide_input = None
-            if self.use_glide:
-                glide_input = GlideInput(
-                    batch.get_block_table_tensor(),
-                    self.k_cache[-1],  # use kv cahces of the last layer
-                    self.v_cache[-1],
-                    batch.get_sequence_lengths(),
-                    n_spec_tokens=self.n_spec_tokens,
-                )
-
-            drafter_out = self.drafter.speculate(
-                input_token_ids,
-                self.n_spec_tokens,
-                drafter_past_key_values,
-                glide_input=glide_input,
-            )
-            next_token_ids_spec = drafter_out.next_tokens
-            drafter_past_key_values = drafter_out.past_key_values
-            drafter_spec_length = drafter_out.speculated_length
-
-            for next_token_id_spec in next_token_ids_spec:
-                self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0))
-                cur_length = batch.seq_lengths[0].item()
-                if already_allocated_kv_len < cur_length:
-                    self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len)
-                    already_allocated_kv_len = cur_length
-
-            # 4. Decoding - Main model verifies `n` tokens in parallel
-            if drafter_spec_length < batch.num_tokens_to_verify:
-                batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length)
-            input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
-            logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
-
-            next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
-
-            # 5. Compare and process the results
-            diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
-            n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
-
-            # revoke appended tokens for each Sequence in the current batch
-            batch.revoke_batch_tokens(drafter_spec_length - n_matches)  # revoke drafted tokens
-
-            # append the last correct token generated by the main model
-            self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))
-
-            # trim past key values of the drafter model
-            drafter_past_key_values = Drafter.trim_kv_cache(
-                drafter_past_key_values, drafter_spec_length - n_matches - 1
-            )
-
-            # prepare inputs for the next round of speculation
-            n = 1 if n_matches < drafter_spec_length else 2
-            input_token_ids = batch.get_1D_inputs_spec_dec(n)
-
-            self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
-            finished_sequences = self.request_handler.update()
-            if len(finished_sequences) > 0:
-                break
-
-        # Reset back the number of speculated tokens of the batch,
-        # this is used to handle the last round of speculation, in which case the number of speculated tokens
-        # by the drafter is less than the number of speculated tokens set to the engine.
-        batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens)
-
-        return finished_sequences

     def generate(
         self,
         request_ids: Union[List[int], int] = None,
         prompts: Union[List[str], str] = None,
-        prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
-        return_token_ids: bool = False,
-        generation_config: Optional[GenerationConfig] = None,
-    ) -> Union[List[str], Tuple[List[str], List[List[int]]]]:
+        *args,
+        **kwargs,
+    ) -> Union[List[Union[str, List[PIL.Image.Image], np.ndarray]], Tuple[List[str], List[List[int]]]]:
         """
         Executing the inference step.

         Args:
             request_ids (List[int], optional): The request ID. Defaults to None.
             prompts (Union[List[str], optional): Input prompts. Defaults to None.
-            prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None.
-            return_token_ids (bool, optional): Whether to return output token ids. Defaults to False.
-            generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None.
-
-        Returns:
-            Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation.
         """
-
-        gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
-        prompts = [prompts] if isinstance(prompts, str) else prompts
-        request_ids = [request_ids] if isinstance(request_ids, int) else request_ids
-
-        with torch.inference_mode():
-            if prompts is not None or prompts_token_ids is not None:
-                self.add_request(
-                    request_ids=request_ids,
-                    prompts=prompts,
-                    prompts_token_ids=prompts_token_ids,
-                    **gen_config_dict,
-                )
-
-            output_seqs_list = []
-            total_tokens_list = []
-
-            # intuition: If user provide a generation config, we should replace the existing one.
-            if generation_config is not None:
-                self.generation_config = generation_config
-                self.generation_config_dict = gen_config_dict
-
-            if self.use_spec_dec:
-                assert self.drafter is not None, "Drafter Model is not initialized."
-                while self.request_handler.check_unfinished_seqs():
-                    output_seqs_list += self.steps_spec_dec()
-            else:
-                while self.request_handler.check_unfinished_seqs():
-                    output_seqs_list += self.step()
-
-            output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))
-
-            for seq in output_seqs_list:
-                total_tokens_list.append(seq.input_token_id + seq.output_token_id)
-
-            output_str = self.tokenizer.batch_decode(total_tokens_list, skip_special_tokens=True)
-
-            if return_token_ids:
-                output_tokens_list = [seq.output_token_id for seq in output_seqs_list]
-                return output_str, output_tokens_list
-            else:
-                return output_str
-
-    @property
-    def has_prompt_template(self) -> bool:
-        """ """
-        return self.inference_config.prompt_template is not None
-
-    def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]:
-        """
-        This method will format the input prompt according to the prompt template given to the InferenceConfig.
-        """
-        assert (
-            self.has_prompt_template
-        ), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig."
-
-        if isinstance(prompts, (list, tuple)):
-            return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
-        elif isinstance(prompts, str):
-            return self.inference_config.prompt_template.format(input_text=prompts)
-        else:
-            raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")
+        assert self.engine is not None, "Please init Engine first"
+        return self.engine.generate(request_ids=request_ids, prompts=prompts, *args, **kwargs)

     def add_request(
         self,
         request_ids: Union[List[int], int] = None,
         prompts: Union[List[str], str] = None,
-        prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
+        *args,
         **kwargs,
     ) -> None:
         """
@@ -630,168 +98,36 @@ class InferenceEngine:
             request_ids (List[int], optional): The request ID. Defaults to None.
             prompts (Union[List[str], optional): Input prompts. Defaults to None.
             prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
+            kwargs: for an LLM, this can carry max_length, max_new_tokens, etc.;
+                for diffusion, it can carry prompt_2, prompt_3, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, negative_prompt_2, negative_prompt_3, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, clip_skip, which aligns with diffusers
         """
-
-        # apply the prompt template to the input prompts
-        if self.has_prompt_template and prompts is not None:
-            prompts = self.format_prompt(prompts)
-
-        block_size = self.inference_config.block_size
-
-        if request_ids is not None and not isinstance(request_ids, list):
-            request_ids = [request_ids]
-
-        if prompts is not None and not isinstance(prompts, list):
-            prompts = [prompts]
-
-        if prompts_token_ids is None:
-            assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
-            prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
-                "input_ids"
-            ]
-
-        # list of torch Tensor
-        if isinstance(prompts_token_ids, list):
-            if isinstance(prompts_token_ids[0], torch.Tensor):
-                prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
-        elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
-            prompts_token_ids = prompts_token_ids.tolist()
-        else:
-            raise TypeError(
-                f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, but got {type(prompts_token_ids)}."
-            )
-
-        assert (
-            len(prompts_token_ids[0]) <= self.inference_config.max_input_len
-        ), f"The length of input prompts {len(prompts_token_ids[0])} must be less than max_input_len {self.inference_config.max_input_len}."
-
-        prompts_num = len(prompts_token_ids)
-
-        for i in range(prompts_num):
-            if request_ids:
-                assert isinstance(
-                    request_ids[0], int
-                ), f"The request_id type must be int, but got {type(request_ids[0])}"
-                assert len(request_ids) == prompts_num
-                request_id = request_ids[i]
-            else:
-                request_id = next(self.counter)
-            if prompts == None:
-                prompt = None
-            else:
-                prompt = prompts[i]
-
-            max_length = kwargs.get("max_length", None)
-            max_new_tokens = kwargs.get("max_new_tokens", None)
-            if max_length is None and max_new_tokens is None:
-                max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len
-            elif max_length is not None:
-                max_new_tokens = max_length - len(prompts_token_ids[i])
-
-            if not self.inference_config.enable_streamingllm:
-                assert (
-                    self.inference_config.max_output_len >= max_new_tokens
-                ), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}."
-
-            sequence = Sequence(
-                request_id,
-                prompt,
-                prompts_token_ids[i],
-                block_size,
-                None,
-                self.tokenizer.eos_token_id,
-                self.tokenizer.pad_token_id,
-                max_output_len=max_new_tokens,
-                ignore_eos=self.inference_config.ignore_eos,
-            )
-            self.request_handler.add_sequence(sequence)
-
-    def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]:
-        input_ids = batch.get_1D_inputs()
-        sequence_lengths = batch.get_sequence_lengths()
-
-        if batch.is_prompts:
-            n_tokens = sequence_lengths.sum().item()
-        else:
-            n_tokens = batch.current_batch_size
-            if batch.use_spec_dec:
-                n_tokens = batch.num_tokens_to_verify + 1
-                assert n_tokens == input_ids.size(0)
-                n_tokens = n_tokens * batch.current_batch_size
-        output_tensor = torch.zeros(
-            (n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
-        )
-
-        batch_token_ids = None
-        if (
-            self.generation_config.repetition_penalty != 1.0
-            or self.generation_config.no_repeat_ngram_size > 0
-            or self.generation_config.forced_eos_token_id is not None
-        ):
-            batch_token_ids = batch.batch_token_ids
-
-        # only when we have the graph for specific decoding batch size can we use the cuda graph for inference
-        use_cuda_graph = False
-        if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
-            use_cuda_graph = True
-
-        input_meta_data = InputMetaData(
-            block_tables=batch.get_block_table_tensor(),
-            sequence_lengths=sequence_lengths,
-            fd_inter_tensor=batch.fd_inter_tensor,
-            batch_size=batch.current_batch_size,
-            is_prompts=batch.is_prompts,
-            use_cuda_kernel=self.inference_config.use_cuda_kernel,
-            use_cuda_graph=use_cuda_graph,
-            high_precision=self.high_precision,
-            kv_seq_len=sequence_lengths.max().item(),
-            head_dim=batch.head_dim,
-            dtype=batch.dtype,
-            use_spec_dec=batch.use_spec_dec,
-            num_tokens_to_verify=batch.num_tokens_to_verify,
-            batch_token_ids=batch_token_ids,
-        )
-
-        return input_ids, output_tensor, input_meta_data
-
-    def step(self) -> List[str]:
-        """
-        In each step, do the follows:
-            1. Run RequestHandler.schedule() and get the batch used for inference.
-            2. Get the input, inputinfo and output placeholder from the batchbucket
-            3. Run model to generate the next token
-            4. Update waiting list and running list in RequestHandler and get finished sequences.
-            5. Decode and return finished sequences.
-
-        Returns:
-            List[str]: Decoded finished sequences generated by one step.
-        """
-
-        batch = self.request_handler.schedule()
-
-        input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
-
-        if input_meta_data.use_cuda_graph:
-            model_executable = self.graph_runners[input_meta_data.batch_size]
-        else:
-            model_executable = self.model
-
-        # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
-        logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
-        if self.inference_config.pad_input:
-            logits = logits[:, -1, :]
-
-        if self.inference_config.enable_streamingllm:
-            updated_block_ids = batch.streamingllm_update_batch(
-                self.inference_config.start_token_size, self.inference_config.generated_token_size
-            )
-            self.request_handler.streamingllm_free_block_tables(updated_block_ids)
-
-        next_tokens = search_tokens(
-            self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
-        )
-        self.request_handler.append_next_tokens(next_tokens)
-        finished_sequences = self.request_handler.update()
-
-        return finished_sequences
+        assert self.engine is not None, "Please init Engine first"
+        self.engine.add_request(request_ids=request_ids, prompts=prompts, *args, **kwargs)
+
+    def step(self):
+        assert self.engine is not None, "Please init Engine first"
+        return self.engine.step()
+
+    def __getattr__(self, name):
+        """
+        The design logic of __getattr__ / __setattr__:
+        1. Since InferenceEngine is a wrapper for DiffusionEngine/LLMEngine, we want to invoke members of DiffusionEngine/LLMEngine as if they were members of InferenceEngine itself.
+        2. When calling __init__ of InferenceEngine, we don't want to set attributes via self.__dict__["xxx"] = xxx; we want to use the usual self.xxx = xxx.
+        So we set the attribute `_initialized`. After initialization, if a member cannot be found on InferenceEngine, we try to get it from self.engine (DiffusionEngine/LLMEngine).
+        """
+        if self.__dict__.get("_initialized", False):
+            if name in self.__dict__:
+                return self.__dict__[name]
+            else:
+                return getattr(self.engine, name)
+        else:
+            return self.__dict__[name]
+
+    def __setattr__(self, name, value):
+        if self.__dict__.get("_initialized", False):
+            if name in self.__dict__:
+                self.__dict__[name] = value
+            else:
+                setattr(self.engine, name, value)
+        else:
+            self.__dict__[name] = value
758  colossalai/inference/core/llm_engine.py  Normal file
@@ -0,0 +1,758 @@
import time
from itertools import count
from typing import Dict, List, Optional, Tuple, Type, Union

import numpy as np
import torch
import torch.nn as nn
from torch import distributed as dist
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    GenerationConfig,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
)
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.sampler import search_tokens
from colossalai.inference.spec import Drafter, GlideInput
from colossalai.inference.struct import Sequence
from colossalai.inference.utils import get_model_size, has_index_file
from colossalai.interface import ModelWrapper
from colossalai.lazy import LazyInitContext
from colossalai.logging import get_dist_logger
from colossalai.shardformer.policies.base_policy import Policy

from .base_engine import BaseEngine
from .request_handler import RequestHandler

PP_AXIS, TP_AXIS = 0, 1

_supported_models = {
    "LlamaForCausalLM": LlamaForCausalLM,
    "BaichuanForCausalLM": AutoModelForCausalLM,
}

_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
class LLMEngine(BaseEngine):
    """
    LLMEngine manages the text-generation inference process.

    Args:
        model_or_path (nn.Module or str): Path to the model or an nn.Module instance.
        tokenizer (Optional[Union[PreTrainedTokenizer, PreTrainedTokenizerFast]]): The tokenizer to use.
        inference_config (Optional[InferenceConfig], optional): Stores the configuration information related to inference.
        verbose (bool): Whether to log the generation process.
        model_policy ("Policy"): The policy used to shard the model with ShardFormer. It will be derived from the model type if not provided.
    """

    def __init__(
        self,
        model_or_path: nn.Module | str,
        tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast = None,
        inference_config: InferenceConfig = None,
        verbose: bool = False,
        model_policy: Policy | type[Policy] = None,
    ) -> None:
        self.inference_config = inference_config
        self.dtype = inference_config.dtype
        self.high_precision = inference_config.high_precision

        self.verbose = verbose
        self.logger = get_dist_logger(__name__)
        self.model_shard_infer_config = inference_config.to_model_shard_inference_config()

        self.init_model(model_or_path, model_policy, self.model_shard_infer_config)

        self.generation_config = inference_config.to_generation_config(self.model_config)
        self.generation_config_dict = self.generation_config.to_dict()

        self.tokenizer = tokenizer
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.request_handler = RequestHandler(self.inference_config, self.model_config)
        self.k_cache, self.v_cache = self.request_handler.get_kvcache()
        # DISCUSS maybe move this into batch info?

        self.counter = count()

        self.use_cuda_graph = self.inference_config.use_cuda_graph
        if self.use_cuda_graph:
            self.graph_runners: Dict[int, CUDAGraphRunner] = {}
            self.graph_memory_pool = None  # Set during graph capture.
            if verbose:
                self.logger.info("Colossal AI CUDA Graph Capture on")

            self.capture_model(self.k_cache, self.v_cache)

        # Model and related attrs of speculative decoding will be set by `enable_spec_dec`
        self.use_spec_dec = self.inference_config.use_spec_dec

        self.drafter_model = None
        self.drafter = None
        self.use_glide = False
        self.n_spec_tokens = self.inference_config.max_n_spec_tokens

        self._verify_args()
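# ----------------------------------------------------------------------------
# Illustrative usage sketch (not part of the source file). It assumes a
# single-GPU setup where the usual colossalai distributed launch has already
# been performed; the model path and config values below are placeholders.
# ----------------------------------------------------------------------------
# from transformers import AutoTokenizer
# from colossalai.inference.config import InferenceConfig
# from colossalai.inference.core.llm_engine import LLMEngine
#
# model_path = "meta-llama/Llama-2-7b-hf"                  # placeholder checkpoint
# tokenizer = AutoTokenizer.from_pretrained(model_path)
# inference_config = InferenceConfig(max_batch_size=4, max_input_len=256, max_output_len=128)
# engine = LLMEngine(model_path, tokenizer=tokenizer, inference_config=inference_config, verbose=True)
#
# # generate() adds the requests, runs step()/steps_spec_dec() until done, and decodes the outputs.
# outputs = engine.generate(prompts=["Introduce some landmarks in Beijing"])
# print(outputs[0])
# ----------------------------------------------------------------------------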
    def init_model(
        self,
        model_or_path: Union[nn.Module, str],
        model_policy: Union[Policy, Type[Policy]] = None,
        model_shard_infer_config: ModelShardInferenceConfig = None,
    ):
        """
        Shard the model and/or load its weights.

        Args:
            model_or_path (Union[nn.Module, str]): Path to the checkpoint or a model in transformers format.
            model_policy (Policy): The policy used to replace/shard model modules for inference.
            model_shard_infer_config (ModelShardInferenceConfig): The configuration for module initialization during inference.
        """
        pretrained_path = None
        if isinstance(model_or_path, str):
            import colossalai.interface.pretrained as pretrained_utils

            try:
                hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
                arch = getattr(hf_config, "architectures")[0]
                if arch in _supported_models.keys():
                    if arch == "BaichuanForCausalLM":
                        self.logger.warning(
                            "Attention! Lazy initialization is used by default, which can make model loading faster. For Baichuan models, the output may differ slightly from transformers."
                        )
                    ctx = LazyInitContext(default_device="cuda")
                    with ctx:
                        model = _supported_models[arch].from_pretrained(
                            model_or_path, trust_remote_code=True, torch_dtype=self.dtype
                        )
                    pretrained_path = pretrained_utils.get_pretrained_path(model)
                else:
                    # TODO(char-1ee): if the model is not supported, use transformers APIs to load and generate
                    raise ValueError(f"Model {arch} is not supported.")

            except Exception as e:
                self.logger.error(
                    f"An exception occurred during model loading: {e}; the model should be loaded by transformers\n"
                )
        else:
            model = model_or_path

        self.model_config = model.config

        torch.cuda.empty_cache()
        init_gpu_memory = torch.cuda.mem_get_info()[0]

        self.device = get_accelerator().get_current_device()
        if self.verbose:
            self.logger.info(f"the device is {self.device}")

        model = model.to(self.dtype).eval()

        if self.verbose:
            self.logger.info(
                f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
            )

        if model_policy is None:
            prefix = "nopadding" if not self.inference_config.pad_input else "padding"
            model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}"
            model_policy = model_policy_map.get(model_policy_key)

        if not isinstance(model_policy, Policy):
            try:
                model_policy = model_policy()
            except Exception as e:
                raise ValueError(f"Unable to instantiate model policy: {e}")

        assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
        pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
        tp_group = pg_mesh.get_group_along_axis(TP_AXIS)

        self.model = self._shardformer(
            model,
            model_policy,
            model_shard_infer_config,
            None,
            tp_group=tp_group,
        )

        self.model = ModelWrapper(model).to(self.device)

        if self.verbose:
            self.logger.info(
                f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
            )

        if pretrained_path:
            from colossalai.inference.core.plugin import InferCheckpoint_io

            cpt_io = InferCheckpoint_io()
            if_has_index_file, model_index_file = has_index_file(pretrained_path)
            assert if_has_index_file, "the model path is invalid"
            cpt_io.load_model(self.model, model_index_file)

        free_gpu_memory, _ = torch.cuda.mem_get_info()
        peak_memory = init_gpu_memory - free_gpu_memory
        if self.verbose:
            self.logger.info(
                f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
            )
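# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the source file): the policy lookup above
# builds a string key from the padding mode and the HuggingFace model_type.
# The concrete values below are hypothetical.
# ----------------------------------------------------------------------------
pad_input = False
model_type = "llama"                          # taken from the HuggingFace config in practice
prefix = "nopadding" if not pad_input else "padding"
print(f"{prefix}_{model_type}")               # "nopadding_llama" -> key into model_policy_map
# ----------------------------------------------------------------------------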
    @torch.inference_mode()
    def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]):
        assert self.use_cuda_graph, "please turn on the cuda graph"

        if self.verbose:
            self.logger.info("Colossal AI CUDA Graph Capture begin")

        t_capture_begin = time.perf_counter()

        block_size = self.inference_config.block_size
        head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads

        # Prepare dummy inputs. These will be reused for all batch sizes.
        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
        max_context_len_to_capture = self.inference_config.max_context_len_to_capture
        max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
        input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        # self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
        self.graph_block_tables = np.full((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), -1, dtype=np.int32)
        self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE))
        self.graph_block_tables[0, :] = np.arange(
            0, max_num_blocks
        )  # NOTE this is a hack to ensure the CUDA graph captures the fixed kernel grid used in flash decoding, by making the first seqlen the max_seq_len
        block_tables = torch.from_numpy(self.graph_block_tables).cuda()
        output_tensor = torch.zeros(
            (max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
        )
        fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor

        max_num_seqs = self.inference_config.max_batch_size
        batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
        sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda()
        # NOTE this is a hack to ensure the CUDA graph captures the fixed kernel grid used in flash decoding, by making the first seqlen the max_seq_len
        sequence_lengths[0] = torch.tensor(
            self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32
        ).cuda()

        # NOTE: Capturing the largest batch size first may help reduce the
        # memory usage of CUDA graph.
        for batch_size in reversed(batch_size_capture_list):
            if self.verbose:
                self.logger.info(f"batch size {batch_size} graph capturing")

            input_meta_data = InputMetaData(
                block_tables=block_tables[:batch_size],
                sequence_lengths=sequence_lengths[:batch_size],
                fd_inter_tensor=fd_inter_tensor,
                batch_size=batch_size,
                is_prompts=False,
                use_cuda_graph=True,
                high_precision=False,
                kv_seq_len=sequence_lengths[:batch_size].max().item(),
                head_dim=head_dim,
                dtype=self.dtype,
            )

            graph_runner = CUDAGraphRunner(self.model)
            graph_runner.capture(
                input_tokens_ids[:batch_size],
                output_tensor[:batch_size],
                input_meta_data,
                k_caches=k_cache,
                v_caches=v_cache,
                memory_pool=self.graph_memory_pool,
            )
            self.graph_memory_pool = graph_runner.graph.pool()
            self.graph_runners[batch_size] = graph_runner

        t_capture_end = time.perf_counter()

        if self.verbose:
            self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")
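# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the source file): which decoding batch sizes
# end up with a captured CUDA graph for an assumed max_batch_size.
# ----------------------------------------------------------------------------
_batch_sizes_to_capture = [1, 2, 4] + [8 * i for i in range(1, 33)]   # 1, 2, 4, 8, 16, ..., 256
max_num_seqs = 12                                                     # assumed InferenceConfig.max_batch_size
print([bs for bs in _batch_sizes_to_capture if bs <= max_num_seqs])   # [1, 2, 4, 8]
# Only these sizes can use graph replay at decode time; other batch sizes fall back to eager execution.
# ----------------------------------------------------------------------------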
    def _verify_args(self) -> None:
        """Verify the input args"""
        if not isinstance(self.inference_config, InferenceConfig):
            raise TypeError("Invalid type of inference config provided.")
        if not isinstance(self.model, nn.Module):
            raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
        if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
            raise TypeError(
                f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
            )
        if isinstance(self.model, ModelWrapper):
            model = self.model.module
        assert (
            model.__class__.__name__ in _supported_models.keys()
        ), f"Model {self.model.__class__.__name__} is not supported."

    def enable_spec_dec(
        self,
        drafter_model: nn.Module = None,
        n_spec_tokens: int = None,
        use_glide_drafter: bool = False,
    ) -> None:
        """Initialize the drafter (if it has not been initialized yet), and enable Speculative Decoding for subsequent generations.

        Args:
            drafter_model (nn.Module): The drafter model (small model) used to speculate tokens.
                If provided, the previous drafter and drafter model, if they exist, will be overwritten.
            n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
                If not provided, `max_n_spec_tokens` in InferenceConfig will be used.
            use_glide_drafter (bool): Whether to use a GLIDE model for speculative decoding. Defaults to False.
                If True, the drafter model will be replaced by a GLIDE model.

        ```python
        ...
        engine = InferenceEngine(model, tokenizer, inference_config)

        engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
        engine.generate(...)  # Speculative Decoding

        engine.disable_spec_dec()
        engine.generate(...)  # Normal generation

        engine.enable_spec_dec()
        engine.generate(...)  # Speculative Decoding using the previously set drafter model and number of spec tokens
        engine.clear_spec_dec()
        ```
        """

        if drafter_model is None and self.drafter is None:
            raise ValueError("Drafter not initialized. Please provide a Drafter Model")
        if n_spec_tokens is not None:
            assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens
            self.n_spec_tokens = n_spec_tokens
        if drafter_model is not None:
            assert isinstance(drafter_model, nn.Module)
            # overwrite the drafter, if it exists
            self.clear_spec_dec()
            self.drafter_model = drafter_model
            self.drafter = Drafter(
                self.drafter_model,
                self.tokenizer,
                device=self.device,
                dtype=self.dtype,
            )

            # check if the provided drafter model is compatible with the GLIDE structure
            # when `use_glide_drafter` is set to True
            if (
                use_glide_drafter
                and hasattr(drafter_model, "model")
                and hasattr(drafter_model.model, "layers")
                and hasattr(drafter_model.model.layers[0], "cross_attn")
            ):
                self.use_glide = use_glide_drafter
            elif use_glide_drafter:
                self.logger.warning(
                    f"`use_glide_drafter` is provided as {use_glide_drafter}, "
                    f"but the provided drafter model is not compatible with the GLIDE structure. "
                    f"Falling back to the default drafter model (non-GLIDE)."
                )
        self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
        # use speculative decoding for subsequent generations
        self.use_spec_dec = True

    def disable_spec_dec(self) -> None:
        """Disable using speculative decoding for subsequent generations."""
        self.request_handler.unset_spec_dec_mode()
        # set back to the maximum number of tokens to speculate
        self.n_spec_tokens = self.inference_config.max_n_spec_tokens
        self.use_glide = False
        self.use_spec_dec = False

    def clear_spec_dec(self) -> None:
        """Clear the related structures of speculative decoding, if they exist."""
        if self.use_spec_dec:
            self.disable_spec_dec()
        if self.drafter_model or self.drafter:
            self.drafter_model = None
            self.drafter = None
            torch.cuda.empty_cache()
        self.use_glide = False
        self.use_spec_dec = False
    def steps_spec_dec(self) -> List[Sequence]:
        """
        Run Speculative Decoding steps. This retrieves a single batch and launches inference
        with rounds of speculation by a drafter model followed by verification by the main model.

        Returns:
            List[Sequence]: finished sequences generated by one step.
        """
        batch = self.request_handler.schedule()  # prefill batch
        assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."

        input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)

        if input_meta_data.use_cuda_graph:
            model_executable = self.graph_runners[input_meta_data.batch_size]
        else:
            model_executable = self.model

        # 1. Prefill small model (Drafter) - fill past kv cache for drafter model
        # NOTE For glide drafter models, we won't actually apply glide during the prefill stage
        drafter_out = self.drafter.speculate(input_token_ids, 1, None)
        next_token_ids_spec = drafter_out.next_tokens
        drafter_past_key_values = drafter_out.past_key_values

        # 2. Prefill main model (Verifier) - fill past kv cache for main model
        logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
        next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
        # append new inputs to the batch, temporarily
        batch.append_batch_tokens(next_tokens)
        self.request_handler.allocate_batch_spec_dec(batch, 1)
        already_allocated_kv_len = batch.seq_lengths[0].item()
        input_token_ids = batch.get_1D_inputs_spec_dec(1)

        finished_sequences = self.request_handler.update()

        while True:
            # HACK Retrieve the running batch
            # Using RequestHandler.schedule here would re-allocate the same kv cache for the batch
            batch = self.request_handler.running_bb  # running batch
            assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."

            # 3. Decoding - Drafter model speculates `n` tokens
            glide_input = None
            if self.use_glide:
                glide_input = GlideInput(
                    batch.get_block_table_tensor(),
                    self.k_cache[-1],  # use kv caches of the last layer
                    self.v_cache[-1],
                    batch.get_sequence_lengths(),
                    n_spec_tokens=self.n_spec_tokens,
                )

            drafter_out = self.drafter.speculate(
                input_token_ids,
                self.n_spec_tokens,
                drafter_past_key_values,
                glide_input=glide_input,
            )
            next_token_ids_spec = drafter_out.next_tokens
            drafter_past_key_values = drafter_out.past_key_values
            drafter_spec_length = drafter_out.speculated_length

            for next_token_id_spec in next_token_ids_spec:
                self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0))
                cur_length = batch.seq_lengths[0].item()
                if already_allocated_kv_len < cur_length:
                    self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len)
                    already_allocated_kv_len = cur_length

            # 4. Decoding - Main model verifies `n` tokens in parallel
            if drafter_spec_length < batch.num_tokens_to_verify:
                batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length)
            input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
            logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)

            next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)

            # 5. Compare and process the results
            diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
            n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()

            # revoke appended tokens for each Sequence in the current batch
            batch.revoke_batch_tokens(drafter_spec_length - n_matches)  # revoke drafted tokens

            # append the last correct token generated by the main model
            self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))

            # trim past key values of the drafter model
            drafter_past_key_values = Drafter.trim_kv_cache(
                drafter_past_key_values, drafter_spec_length - n_matches - 1
            )

            # prepare inputs for the next round of speculation
            n = 1 if n_matches < drafter_spec_length else 2
            input_token_ids = batch.get_1D_inputs_spec_dec(n)

            self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
            finished_sequences = self.request_handler.update()
            if len(finished_sequences) > 0:
                break

        # Reset the number of speculated tokens of the batch. This handles the last round of speculation,
        # in which the number of tokens speculated by the drafter may be smaller than the number configured on the engine.
        batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens)

        return finished_sequences
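# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the source file): the accept/reject step
# above compares the verifier's tokens position-wise against the drafter's
# speculated tokens and revokes everything after the first mismatch.
# Token ids below are made up.
# ----------------------------------------------------------------------------
import torch

next_token_ids_spec = torch.tensor([11, 12, 13, 14, 15])   # drafter's 5 guesses
next_tokens = torch.tensor([11, 12, 99, 14, 15, 42])       # verifier output (plus one bonus token)

diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
n_matches = 5 if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
print(n_matches)                       # 2 -> the drafted tokens 13, 14, 15 are revoked
print(next_tokens[n_matches].item())   # 99 -> the verifier's correction is appended instead
# ----------------------------------------------------------------------------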
    def generate(
        self,
        request_ids: Union[List[int], int] = None,
        prompts: Union[List[str], str] = None,
        prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
        return_token_ids: bool = False,
        generation_config: Optional[GenerationConfig] = None,
    ) -> Union[List[str], Tuple[List[str], List[List[int]]]]:
        """
        Execute inference for the given requests.

        Args:
            request_ids (List[int], optional): The request IDs. Defaults to None.
            prompts (Union[List[str], str], optional): Input prompts. Defaults to None.
            prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None.
            return_token_ids (bool, optional): Whether to return output token ids. Defaults to False.
            generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None.

        Returns:
            Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation.
        """

        gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
        prompts = [prompts] if isinstance(prompts, str) else prompts
        request_ids = [request_ids] if isinstance(request_ids, int) else request_ids

        with torch.inference_mode():
            if prompts is not None or prompts_token_ids is not None:
                self.add_request(
                    request_ids=request_ids,
                    prompts=prompts,
                    prompts_token_ids=prompts_token_ids,
                    **gen_config_dict,
                )

            output_seqs_list = []
            total_tokens_list = []

            # If the user provides a generation config, replace the existing one.
            if generation_config is not None:
                self.generation_config = generation_config
                self.generation_config_dict = gen_config_dict

            if self.use_spec_dec:
                assert self.drafter is not None, "Drafter Model is not initialized."
                while self.request_handler.check_unfinished_reqs():
                    output_seqs_list += self.steps_spec_dec()
            else:
                while self.request_handler.check_unfinished_reqs():
                    output_seqs_list += self.step()

            output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))

            for seq in output_seqs_list:
                total_tokens_list.append(seq.input_token_id + seq.output_token_id)

            output_str = self.tokenizer.batch_decode(total_tokens_list, skip_special_tokens=True)

            if return_token_ids:
                output_tokens_list = [seq.output_token_id for seq in output_seqs_list]
                return output_str, output_tokens_list
            else:
                return output_str
    @property
    def has_prompt_template(self) -> bool:
        """Whether a prompt template has been configured in the InferenceConfig."""
        return self.inference_config.prompt_template is not None

    def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]:
        """
        Format the input prompts according to the prompt template given in the InferenceConfig.
        """
        assert (
            self.has_prompt_template
        ), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig."

        if isinstance(prompts, (list, tuple)):
            return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
        elif isinstance(prompts, str):
            return self.inference_config.prompt_template.format(input_text=prompts)
        else:
            raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")
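# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the source file): the formatting above is a
# plain str.format call. The template string below is a hypothetical
# Llama-style chat template.
# ----------------------------------------------------------------------------
prompt_template = "<s>[INST] {input_text} [/INST]"
print(prompt_template.format(input_text="What is tensor parallelism?"))
# <s>[INST] What is tensor parallelism? [/INST]
# ----------------------------------------------------------------------------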
    def add_request(
        self,
        request_ids: Union[List[int], int] = None,
        prompts: Union[List[str], str] = None,
        prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
        **kwargs,
    ) -> None:
        """
        Add requests.

        Args:
            request_ids (List[int], optional): The request IDs. Defaults to None.
            prompts (Union[List[str], str], optional): Input prompts. Defaults to None.
            prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
        """

        # apply the prompt template to the input prompts
        if self.has_prompt_template and prompts is not None:
            prompts = self.format_prompt(prompts)

        block_size = self.inference_config.block_size

        if request_ids is not None and not isinstance(request_ids, list):
            request_ids = [request_ids]

        if prompts is not None and not isinstance(prompts, list):
            prompts = [prompts]

        if prompts_token_ids is None:
            assert prompts, "When prompts_token_ids is None, the input prompt list must be provided."
            prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
                "input_ids"
            ]

        # list of torch Tensor
        if isinstance(prompts_token_ids, list):
            if isinstance(prompts_token_ids[0], torch.Tensor):
                prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
        elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
            prompts_token_ids = prompts_token_ids.tolist()
        else:
            raise TypeError(
                f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, but got {type(prompts_token_ids)}."
            )

        assert (
            len(prompts_token_ids[0]) <= self.inference_config.max_input_len
        ), f"The length of input prompts {len(prompts_token_ids[0])} must not exceed max_input_len {self.inference_config.max_input_len}."

        prompts_num = len(prompts_token_ids)

        for i in range(prompts_num):
            if request_ids:
                assert isinstance(
                    request_ids[0], int
                ), f"The request_id type must be int, but got {type(request_ids[0])}"
                assert len(request_ids) == prompts_num
                request_id = request_ids[i]
            else:
                request_id = next(self.counter)
            if prompts is None:
                prompt = None
            else:
                prompt = prompts[i]

            max_length = kwargs.get("max_length", None)
            max_new_tokens = kwargs.get("max_new_tokens", None)
            if max_length is None and max_new_tokens is None:
                max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len
            elif max_length is not None:
                max_new_tokens = max_length - len(prompts_token_ids[i])

            if not self.inference_config.enable_streamingllm:
                assert (
                    self.inference_config.max_output_len >= max_new_tokens
                ), f"max_new_tokens={max_new_tokens} must not exceed max_output_len={self.inference_config.max_output_len}."

            sequence = Sequence(
                request_id,
                prompt,
                prompts_token_ids[i],
                block_size,
                None,
                self.tokenizer.eos_token_id,
                self.tokenizer.pad_token_id,
                max_output_len=max_new_tokens,
                ignore_eos=self.inference_config.ignore_eos,
            )
            self.request_handler.add_sequence(sequence)
    def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]:
        input_ids = batch.get_1D_inputs()
        sequence_lengths = batch.get_sequence_lengths()

        if batch.is_prompts:
            n_tokens = sequence_lengths.sum().item()
        else:
            n_tokens = batch.current_batch_size
            if batch.use_spec_dec:
                n_tokens = batch.num_tokens_to_verify + 1
                assert n_tokens == input_ids.size(0)
                n_tokens = n_tokens * batch.current_batch_size
        output_tensor = torch.zeros(
            (n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
        )

        batch_token_ids = None
        if (
            self.generation_config.repetition_penalty != 1.0
            or self.generation_config.no_repeat_ngram_size > 0
            or self.generation_config.forced_eos_token_id is not None
        ):
            batch_token_ids = batch.batch_token_ids

        # only when we have the graph for the specific decoding batch size can we use the cuda graph for inference
        use_cuda_graph = False
        if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
            use_cuda_graph = True

        input_meta_data = InputMetaData(
            block_tables=batch.get_block_table_tensor(),
            sequence_lengths=sequence_lengths,
            fd_inter_tensor=batch.fd_inter_tensor,
            batch_size=batch.current_batch_size,
            is_prompts=batch.is_prompts,
            use_cuda_kernel=self.inference_config.use_cuda_kernel,
            use_cuda_graph=use_cuda_graph,
            high_precision=self.high_precision,
            kv_seq_len=sequence_lengths.max().item(),
            head_dim=batch.head_dim,
            dtype=batch.dtype,
            use_spec_dec=batch.use_spec_dec,
            num_tokens_to_verify=batch.num_tokens_to_verify,
            batch_token_ids=batch_token_ids,
        )

        return input_ids, output_tensor, input_meta_data
    def step(self) -> List[str]:
        """
        In each step, do the following:
        1. Run RequestHandler.schedule() and get the batch used for inference.
        2. Get the input, input info, and output placeholder from the BatchBucket.
        3. Run the model to generate the next token.
        4. Update the waiting list and running list in RequestHandler and get finished sequences.
        5. Decode and return finished sequences.

        Returns:
            List[str]: Decoded finished sequences generated by one step.
        """

        batch = self.request_handler.schedule()

        input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)

        if input_meta_data.use_cuda_graph:
            model_executable = self.graph_runners[input_meta_data.batch_size]
        else:
            model_executable = self.model

        # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
        logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
        if self.inference_config.pad_input:
            logits = logits[:, -1, :]

        if self.inference_config.enable_streamingllm:
            updated_block_ids = batch.streamingllm_update_batch(
                self.inference_config.start_token_size, self.inference_config.generated_token_size
            )
            self.request_handler.streamingllm_free_block_tables(updated_block_ids)

        next_tokens = search_tokens(
            self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
        )
        self.request_handler.append_next_tokens(next_tokens)
        finished_sequences = self.request_handler.update()

        return finished_sequences
@@ -8,7 +8,7 @@ from colossalai.inference.batch_bucket import BatchBucket
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager
-from colossalai.inference.struct import RequestStatus, Sequence
+from colossalai.inference.struct import DiffusionSequence, RequestStatus, Sequence
 from colossalai.logging import get_dist_logger

 logger = get_dist_logger(__name__)
@@ -98,7 +98,46 @@ class RunningList:
         self._decoding[seq_id] = self._prefill.pop(seq_id)


-class RequestHandler:
+class NaiveRequestHandler:
+    def __init__(self) -> None:
+        self.running_list: List[DiffusionSequence] = []
+        self.waiting_list: List[DiffusionSequence] = []
+
+    def _has_waiting(self) -> bool:
+        return any(lst for lst in self.waiting_list)
+
+    def _has_running(self) -> bool:
+        return any(lst for lst in self.running_list)
+
+    def check_unfinished_reqs(self):
+        return self._has_waiting() or self._has_running()
+
+    def add_sequence(self, seq: DiffusionSequence):
+        """
+        Add the request to the waiting list.
+        """
+        assert not self._find_sequence(seq.request_id), f"Sequence {seq.request_id} already exists."
+        self.waiting_list.append(seq)
+
+    def _find_sequence(self, request_id: int) -> DiffusionSequence:
+        """
+        Find the request by request_id.
+        """
+        for seq in self.waiting_list + self.running_list:
+            if seq.request_id == request_id:
+                return seq
+        return None
+
+    def schedule(self):
+        ret = None
+        if self._has_waiting():
+            ret = self.waiting_list[0]
+            self.waiting_list = self.waiting_list[1:]
+        return ret
+
+
+class RequestHandler(NaiveRequestHandler):
     """
     RequestHandler is the core for handling existing requests and updating current batch.
     During generation process, we call schedule function each iteration to update current batch.
@@ -176,12 +215,12 @@ class RequestHandler:
             generated_token_size=inference_config.generated_token_size,
         )

+    def _has_running(self) -> bool:
+        return not self.running_bb.is_empty()
+
     def _init_cache(self, model_config):
         self.cache_manager = KVCacheManager(self.inference_config, model_config)

-    def _has_waiting(self) -> bool:
-        return any(lst for lst in self.waiting_list)
-
     def get_kvcache(self):
         return self.cache_manager.get_kv_cache()

@@ -318,7 +357,7 @@ class RequestHandler:
         if seq.output_token_id[-1] == generation_config.eos_token_id or seq.output_len >= max_new_tokens:
             seq.mark_finished()

-    def check_unfinished_seqs(self) -> bool:
+    def check_unfinished_reqs(self) -> bool:
         return self._has_waiting() or not self.running_list.is_empty()

     def total_requests_in_batch_bucket(self) -> int:
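# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the source file): the FIFO behaviour of
# NaiveRequestHandler. _FakeSeq is a stand-in for DiffusionSequence, used only
# so this snippet is self-contained; the import path is the core package shown
# in the diff above.
# ----------------------------------------------------------------------------
from dataclasses import dataclass

from colossalai.inference.core.request_handler import NaiveRequestHandler


@dataclass
class _FakeSeq:  # stand-in for DiffusionSequence in this illustration
    request_id: int
    prompt: str


handler = NaiveRequestHandler()
handler.add_sequence(_FakeSeq(0, "an astronaut riding a horse"))
handler.add_sequence(_FakeSeq(1, "a watercolor fox"))

while handler.check_unfinished_reqs():
    seq = handler.schedule()          # pops the oldest waiting request first
    if seq is None:
        break
    print(seq.request_id, seq.prompt)  # 0 ..., then 1 ...
# ----------------------------------------------------------------------------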
54  colossalai/inference/modeling/models/diffusion.py  Normal file
@@ -0,0 +1,54 @@
import inspect
import types

import torch
from torch import nn


class DiffusionPipe(nn.Module):
    """
    This class converts a `DiffusionPipeline` into an `nn.Module`, preserving most of the original attributes, functions and properties.
    """

    def __init__(self, source_obj) -> None:
        super(DiffusionPipe, self).__init__()

        for k, v in source_obj.__dict__.items():
            if isinstance(v, nn.Module):
                self.add_module(k, v)
            else:
                setattr(self, k, v)

        skip_list = ["_execution_device", "to", "device"]  # these are overridden below

        for name, member in inspect.getmembers(source_obj.__class__):
            if name in skip_list:
                continue
            if not name.startswith("__") and not name.endswith("__"):
                if isinstance(member, property):
                    setattr(self.__class__, name, member)
                elif inspect.isfunction(member) or inspect.ismethod(member):
                    bound_method = types.MethodType(member, self)
                    setattr(self, name, bound_method)
                elif not callable(member) and not isinstance(member, property):
                    setattr(self, name, member)
            elif name == "__call__":
                bound_method = types.MethodType(member, self)
                setattr(self, "_forward", bound_method)

    @property
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from
        Accelerate's module hooks.
        """
        # return self.device
        return torch.device("cuda")

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, *args, **kwargs):
        return self._forward(*args, **kwargs)
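# ----------------------------------------------------------------------------
# Illustrative usage sketch (not part of the source file): wrapping a diffusers
# pipeline so it can be treated as a plain nn.Module. The checkpoint id and
# dtype are placeholders, and this assumes diffusers is installed.
# ----------------------------------------------------------------------------
# import torch
# from diffusers import PixArtAlphaPipeline
# from colossalai.inference.modeling.models.diffusion import DiffusionPipe
#
# pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
# wrapped = DiffusionPipe(pipe).to("cuda")
# print(isinstance(wrapped, torch.nn.Module))  # True -> usable by nn.Module-based tooling (e.g. ShardFormer)
# # The engine later binds pixart_alpha_forward / sd3_forward as wrapped.forward.
# ----------------------------------------------------------------------------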
220  colossalai/inference/modeling/models/pixart_alpha.py  Normal file
@@ -0,0 +1,220 @@
# Code adapted from:
# https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py

from typing import Callable, List, Optional, Union

import PIL.Image
import torch
from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha import (
    ASPECT_RATIO_256_BIN,
    ASPECT_RATIO_512_BIN,
    ASPECT_RATIO_1024_BIN,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps

from colossalai.logging import get_dist_logger

from .diffusion import DiffusionPipe

logger = get_dist_logger(__name__)


@torch.no_grad()
def pixart_alpha_forward(
    self: DiffusionPipe,
    prompt: Union[str, List[str]] = None,
    negative_prompt: str = "",
    num_inference_steps: int = 20,
    timesteps: List[int] = None,
    sigmas: List[float] = None,
    guidance_scale: float = 4.5,
    num_images_per_prompt: Optional[int] = 1,
    height: Optional[int] = None,
    width: Optional[int] = None,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    prompt_attention_mask: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
    callback_steps: int = 1,
    clean_caption: bool = True,
    use_resolution_binning: bool = True,
    max_sequence_length: int = 120,
    **kwargs,
) -> PIL.Image:
    # 1. Check inputs. Raise error if not correct
    height = height or self.transformer.config.sample_size * self.vae_scale_factor
    width = width or self.transformer.config.sample_size * self.vae_scale_factor
    if use_resolution_binning:
        if self.transformer.config.sample_size == 128:
            aspect_ratio_bin = ASPECT_RATIO_1024_BIN
        elif self.transformer.config.sample_size == 64:
            aspect_ratio_bin = ASPECT_RATIO_512_BIN
        elif self.transformer.config.sample_size == 32:
            aspect_ratio_bin = ASPECT_RATIO_256_BIN
        else:
            raise ValueError("Invalid sample size")
        orig_height, orig_width = height, width
        height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)

    self.check_inputs(
        prompt,
        height,
        width,
        negative_prompt,
        callback_steps,
        prompt_embeds,
        negative_prompt_embeds,
        prompt_attention_mask,
        negative_prompt_attention_mask,
    )

    # 2. Default height and width to transformer
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    (
        prompt_embeds,
        prompt_attention_mask,
        negative_prompt_embeds,
        negative_prompt_attention_mask,
    ) = self.encode_prompt(
        prompt,
        do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
        clean_caption=clean_caption,
        max_sequence_length=max_sequence_length,
    )
    if do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, sigmas)

    # 5. Prepare latents.
    latent_channels = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        latent_channels,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 6.1 Prepare micro-conditions.
    added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
    if self.transformer.config.sample_size == 128:
        resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
        aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
        resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
        aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)

        if do_classifier_free_guidance:
            resolution = torch.cat([resolution, resolution], dim=0)
            aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)

        added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}

    # 7. Denoising loop
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            current_timestep = t
            if not torch.is_tensor(current_timestep):
                # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                # This would be a good case for the `match` statement (Python 3.10+)
                is_mps = latent_model_input.device.type == "mps"
                if isinstance(current_timestep, float):
                    dtype = torch.float32 if is_mps else torch.float64
                else:
                    dtype = torch.int32 if is_mps else torch.int64
                current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
            elif len(current_timestep.shape) == 0:
                current_timestep = current_timestep[None].to(latent_model_input.device)
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            current_timestep = current_timestep.expand(latent_model_input.shape[0])

            # predict noise model_output
            noise_pred = self.transformer(
                latent_model_input,
                encoder_hidden_states=prompt_embeds,
                encoder_attention_mask=prompt_attention_mask,
                timestep=current_timestep,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # learned sigma
            if self.transformer.config.out_channels // 2 == latent_channels:
                noise_pred = noise_pred.chunk(2, dim=1)[0]
            else:
                noise_pred = noise_pred

            # compute previous image: x_t -> x_t-1
            if num_inference_steps == 1:
                # For DMD one step sampling: https://arxiv.org/abs/2311.18828
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample
            else:
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

    output_type = "pil"  # TODO(@lry89757) temporarily image, please support more return output
    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        if use_resolution_binning:
            image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
    else:
        image = latents

    if not output_type == "latent":
        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    # self.maybe_free_model_hooks()

    return image
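# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the source file): the classifier-free
# guidance step above, with made-up numbers. The stacked prediction holds the
# [unconditional, text-conditioned] outputs produced inside the loop.
# ----------------------------------------------------------------------------
import torch

guidance_scale = 4.5
noise_pred = torch.tensor([[0.10, 0.20], [0.30, 0.60]])
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(guided)  # tensor([[1.0000, 2.0000]]) -> pushed 4.5x along the text-conditioned direction
# ----------------------------------------------------------------------------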
178  colossalai/inference/modeling/models/stablediffusion3.py  Normal file
@@ -0,0 +1,178 @@
# This code is adapted from huggingface diffusers: https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps

from .diffusion import DiffusionPipe


# TODO(@lry89757) temporarily image, please support more return output
@torch.no_grad()
def sd3_forward(
    self: DiffusionPipe,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    prompt_3: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 28,
    timesteps: List[int] = None,
    guidance_scale: float = 7.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    negative_prompt_3: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        prompt_3,
        height,
        width,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
    )

    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_3=prompt_3,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        device=device,
        clip_skip=self.clip_skip,
        num_images_per_prompt=num_images_per_prompt,
    )

    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    self._num_timesteps = len(timesteps)

    # 5. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Denoising loop
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latent_model_input.shape[0])

            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                pooled_projections=pooled_prompt_embeds,
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||||
|
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||||
|
|
||||||
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
|
latents_dtype = latents.dtype
|
||||||
|
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||||
|
|
||||||
|
if latents.dtype != latents_dtype:
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||||
|
latents = latents.to(latents_dtype)
|
||||||
|
|
||||||
|
if callback_on_step_end is not None:
|
||||||
|
callback_kwargs = {}
|
||||||
|
for k in callback_on_step_end_tensor_inputs:
|
||||||
|
callback_kwargs[k] = locals()[k]
|
||||||
|
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||||
|
|
||||||
|
latents = callback_outputs.pop("latents", latents)
|
||||||
|
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||||
|
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||||
|
negative_pooled_prompt_embeds = callback_outputs.pop(
|
||||||
|
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
|
||||||
|
)
|
||||||
|
|
||||||
|
# call the callback, if provided
|
||||||
|
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||||
|
progress_bar.update()
|
||||||
|
|
||||||
|
if output_type == "latent":
|
||||||
|
image = latents
|
||||||
|
|
||||||
|
else:
|
||||||
|
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||||
|
|
||||||
|
image = self.vae.decode(latents, return_dict=False)[0]
|
||||||
|
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||||
|
|
||||||
|
return image
|
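sd3_forward is written as an unbound function whose first parameter is the DiffusionPipe wrapper, so the policy further down can graft it onto the pipe as its forward method. The snippet below is only a sketch of what that binding amounts to, assuming a pipe object that already exposes the attributes used above; it is not the engine's actual wiring:

import types

def bind_sd3_forward(pipe):
    # Equivalent in spirit to the policy's method replacement:
    # make sd3_forward the bound forward of this pipe instance.
    pipe.forward = types.MethodType(sd3_forward, pipe)
    return pipe

# usage sketch:
# images = bind_sd3_forward(pipe).forward(prompt="a cat", num_inference_steps=28)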
@@ -1,16 +1,22 @@
from .glide_llama import GlideLlamaModelPolicy
from .nopadding_baichuan import NoPaddingBaichuanModelInferPolicy
from .nopadding_llama import NoPaddingLlamaModelInferPolicy
from .pixart_alpha import PixArtAlphaInferPolicy
from .stablediffusion3 import StableDiffusion3InferPolicy

model_policy_map = {
    "nopadding_llama": NoPaddingLlamaModelInferPolicy,
    "nopadding_baichuan": NoPaddingBaichuanModelInferPolicy,
    "glide_llama": GlideLlamaModelPolicy,
    "StableDiffusion3Pipeline": StableDiffusion3InferPolicy,
    "PixArtAlphaPipeline": PixArtAlphaInferPolicy,
}

__all__ = [
    "NoPaddingLlamaModelInferPolicy",
    "NoPaddingBaichuanModelInferPolicy",
    "GlideLlamaModelPolicy",
    "StableDiffusion3InferPolicy",
    "PixArtAlphaInferPolicy",
    "model_policy_map",
]
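Note that the two new keys in model_policy_map are the diffusers pipeline class names, not model-family tags like "nopadding_llama". A hedged sketch of how such a map can be consulted (this lookup helper is illustrative and not part of the diff):

from diffusers import StableDiffusion3Pipeline  # or PixArtAlphaPipeline

def resolve_policy(pipeline):
    # Look up the registered inference policy by the pipeline's class name,
    # e.g. "StableDiffusion3Pipeline" -> StableDiffusion3InferPolicy.
    policy_cls = model_policy_map.get(type(pipeline).__name__)
    if policy_cls is None:
        raise ValueError(f"No inference policy registered for {type(pipeline).__name__}")
    return policy_cls()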
34  colossalai/inference/modeling/policy/pixart_alpha.py  Normal file
@@ -0,0 +1,34 @@
from torch import nn

from colossalai.inference.config import RPC_PARAM
from colossalai.inference.modeling.models.diffusion import DiffusionPipe
from colossalai.inference.modeling.models.pixart_alpha import pixart_alpha_forward
from colossalai.shardformer.policies.base_policy import Policy


class PixArtAlphaInferPolicy(Policy, RPC_PARAM):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        policy = {}
        self.append_or_create_method_replacement(
            description={"forward": pixart_alpha_forward}, policy=policy, target_key=DiffusionPipe
        )
        return policy

    def preprocess(self) -> nn.Module:
        return self.model

    def postprocess(self):
        return self.model

    def config_sanity_check(self):
        pass

    def to_rpc_param(self) -> str:
        return __class__.__name__

    @staticmethod
    def from_rpc_param() -> "PixArtAlphaInferPolicy":
        return PixArtAlphaInferPolicy()
34  colossalai/inference/modeling/policy/stablediffusion3.py  Normal file
@@ -0,0 +1,34 @@
from torch import nn

from colossalai.inference.config import RPC_PARAM
from colossalai.inference.modeling.models.diffusion import DiffusionPipe
from colossalai.inference.modeling.models.stablediffusion3 import sd3_forward
from colossalai.shardformer.policies.base_policy import Policy


class StableDiffusion3InferPolicy(Policy, RPC_PARAM):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        policy = {}
        self.append_or_create_method_replacement(
            description={"forward": sd3_forward}, policy=policy, target_key=DiffusionPipe
        )
        return policy

    def preprocess(self) -> nn.Module:
        return self.model

    def postprocess(self):
        return self.model

    def config_sanity_check(self):
        pass

    def to_rpc_param(self) -> str:
        return __class__.__name__

    @staticmethod
    def from_rpc_param() -> "StableDiffusion3InferPolicy":
        return StableDiffusion3InferPolicy()
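Both diffusion policies satisfy the RPC_PARAM contract with the simplest possible round trip: the policy is identified by its class name and reconstructed with no arguments on the other side. A quick illustration, assuming nothing beyond the code above:

name = StableDiffusion3InferPolicy().to_rpc_param()    # "StableDiffusion3InferPolicy"
policy = StableDiffusion3InferPolicy.from_rpc_param()  # fresh, stateless policy instance
assert type(policy).__name__ == name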
@@ -2,6 +2,7 @@ import enum
from dataclasses import dataclass
from typing import Any, List

from colossalai.inference.config import DiffusionGenerationConfig
from colossalai.logging import get_dist_logger

logger = get_dist_logger(__name__)
@@ -46,6 +47,17 @@ class RequestStatus(enum.Enum):
        return status == RequestStatus.WAITING


@dataclass
class DiffusionSequence:
    """
    Parameters of a diffusion generation request.
    """

    request_id: int
    prompt: str
    generation_config: DiffusionGenerationConfig


@dataclass
class Sequence:
    """Store information of input sequence.
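DiffusionSequence is the diffusion counterpart of Sequence: instead of token-level state it only carries the prompt and its generation config. An illustrative construction, with placeholder values:

seq = DiffusionSequence(
    request_id=0,
    prompt="A cat holding a sign that says hello world",
    generation_config=DiffusionGenerationConfig(),
)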
@@ -5,10 +5,12 @@ Utils for model inference
import math
import os
import re
from enum import Enum
from pathlib import Path
from typing import Optional, Tuple, Union

import torch
from diffusers import DiffusionPipeline
from torch import nn

from colossalai.logging import get_dist_logger
@@ -159,3 +161,38 @@ def can_use_flash_attn2(dtype: torch.dtype) -> bool:
    except ImportError:
        logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
        return False


class ModelType(Enum):
    DIFFUSION_MODEL = "Diffusion Model"
    LLM = "Large Language Model (LLM)"
    UNKNOWN = "Unknown Model Type"


def get_model_type(model_or_path: Union[nn.Module, str, DiffusionPipeline]) -> ModelType:
    """Classify the input as a diffusion pipeline, an LLM, or an unknown model type."""
    if isinstance(model_or_path, DiffusionPipeline):
        return ModelType.DIFFUSION_MODEL
    elif isinstance(model_or_path, nn.Module):
        return ModelType.LLM
    elif isinstance(model_or_path, str):
        # Probe the path/repo: a loadable transformers config means an LLM checkpoint.
        try:
            from transformers import AutoConfig

            AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
            return ModelType.LLM
        except Exception:
            # Not an LLM checkpoint; fall through and try diffusers.
            pass

        # A loadable diffusers config means a diffusion pipeline checkpoint.
        try:
            DiffusionPipeline.load_config(model_or_path)
            return ModelType.DIFFUSION_MODEL
        except Exception:
            # Neither config could be loaded.
            return ModelType.UNKNOWN
    else:
        return ModelType.UNKNOWN
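get_model_type lets the engine branch between the LLM path and the diffusion path before any heavy loading happens. A small usage sketch (the string-path probe needs the checkpoint files to be reachable, so it is left commented out):

import torch.nn as nn

assert get_model_type(nn.Linear(2, 2)) is ModelType.LLM  # any torch module is treated as an LLM
# get_model_type("stabilityai/stable-diffusion-3-medium-diffusers") is ModelType.DIFFUSION_MODEL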
75  examples/inference/stable_diffusion/sd3_generation.py  Normal file
@@ -0,0 +1,75 @@
import argparse

from diffusers import PixArtAlphaPipeline, StableDiffusion3Pipeline
from torch import bfloat16, float16, float32

import colossalai
from colossalai.cluster import DistCoordinator
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.modeling.policy.pixart_alpha import PixArtAlphaInferPolicy
from colossalai.inference.modeling.policy.stablediffusion3 import StableDiffusion3InferPolicy

# For Stable Diffusion 3, we'll use the following configuration
MODEL_CLS = [StableDiffusion3Pipeline, PixArtAlphaPipeline][0]
POLICY_CLS = [StableDiffusion3InferPolicy, PixArtAlphaInferPolicy][0]

TORCH_DTYPE_MAP = {
    "fp16": float16,
    "fp32": float32,
    "bf16": bfloat16,
}


def infer(args):
    # ==============================
    # Launch colossalai, setup distributed environment
    # ==============================
    colossalai.launch_from_torch()
    coordinator = DistCoordinator()

    # ==============================
    # Load model
    # ==============================
    model_path_or_name = args.model
    model = MODEL_CLS.from_pretrained(model_path_or_name, torch_dtype=TORCH_DTYPE_MAP.get(args.dtype, None))

    # ==============================
    # Initialize InferenceEngine
    # ==============================
    coordinator.print_on_master(f"Initializing Inference Engine...")
    inference_config = InferenceConfig(
        dtype=args.dtype,
        max_batch_size=args.max_batch_size,
        tp_size=args.tp_size,
        use_cuda_kernel=args.use_cuda_kernel,
    )
    engine = InferenceEngine(model, inference_config=inference_config, model_policy=POLICY_CLS(), verbose=True)

    # ==============================
    # Generation
    # ==============================
    coordinator.print_on_master(f"Generating...")
    out = engine.generate(prompts=[args.prompt], generation_config=DiffusionGenerationConfig())[0]
    out.save("cat.jpg")
    coordinator.print_on_master(out)


# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m MODEL_PATH
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1


if __name__ == "__main__":
    # ==============================
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, help="Path to the model or model name")
    parser.add_argument("-t", "--tp_size", type=int, default=1, help="Tensor Parallelism size")
    parser.add_argument("-p", "--prompt", type=str, default="A cat holding a sign that says hello world", help="Prompt")
    parser.add_argument("-b", "--max_batch_size", type=int, default=1, help="Max batch size")
    parser.add_argument("-d", "--dtype", type=str, default="fp16", help="Data type", choices=["fp16", "fp32", "bf16"])
    parser.add_argument("--use_cuda_kernel", action="store_true", help="Use CUDA kernel, use Triton by default")
    args = parser.parse_args()

    infer(args)
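The same script drives PixArt-Alpha by flipping the two index selectors at the top of the file; for example (the PixArt checkpoint name below is an assumption, not taken from this diff):

MODEL_CLS = [StableDiffusion3Pipeline, PixArtAlphaPipeline][1]
POLICY_CLS = [StableDiffusion3InferPolicy, PixArtAlphaInferPolicy][1]
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS"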
@@ -23,3 +23,4 @@ rpyc==6.0.0
fastapi
uvicorn==0.29.0
galore_torch
diffusers==0.29.0