[inference] update examples and engine (#5073)

* update examples and engine

* fix choices

* update example
Xu Kai
2023-11-20 19:44:52 +08:00
committed by GitHub
parent 0c7d8bebd5
commit fb103cfd6e
12 changed files with 107 additions and 273 deletions

View File

@@ -1,4 +1,4 @@
-from .engine import CaiInferEngine
+from .engine import InferenceEngine
 from .engine.policies import BloomModelInferPolicy, ChatGLM2InferPolicy, LlamaModelInferPolicy
-__all__ = ["CaiInferEngine", "LlamaModelInferPolicy", "BloomModelInferPolicy", "ChatGLM2InferPolicy"]
+__all__ = ["InferenceEngine", "LlamaModelInferPolicy", "BloomModelInferPolicy", "ChatGLM2InferPolicy"]

View File

@@ -1,3 +1,3 @@
-from .engine import CaiInferEngine
+from .engine import InferenceEngine
-__all__ = ["CaiInferEngine"]
+__all__ = ["InferenceEngine"]

View File

@@ -3,7 +3,6 @@ from typing import Union
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from transformers.tokenization_utils_base import BatchEncoding
 from transformers.utils import logging
 from colossalai.cluster import ProcessGroupMesh
@@ -27,9 +26,9 @@ _supported_models = [
 ]
-class CaiInferEngine:
+class InferenceEngine:
     """
-    CaiInferEngine is a class that handles the pipeline parallel inference.
+    InferenceEngine is a class that handles the pipeline parallel inference.
     Args:
         tp_size (int): the size of tensor parallelism.
@@ -42,27 +41,6 @@ class CaiInferEngine:
         max_input_len (int): the maximum input length.
         max_output_len (int): the maximum output length.
-    Example:
-    ```python
-    from colossalai.inference import InferEngine
-    from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
-    import colossalai
-    from transformers import LlamaForCausalLM, LlamaTokenizer
-    colossalai.launch_from_torch(config={})
-    model = LlamaForCausalLM.from_pretrained("your_path_to_model")
-    tokenizer = LlamaTokenizer.from_pretrained("/home/lczyh/share/models/llama-7b-hf")
-    # assume the model is infered with 2 pipeline stages
-    inferengine = CaiInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy())
-    input = ["Introduce a landmark in China ","Introduce a landmark in China "]
-    data = tokenizer(input, return_tensors='pt')
-    output = inferengine.inference([data.to('cuda').data])
-    ```
     """
     def __init__(
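The docstring example removed above still used the old names (`InferEngine` / `CaiInferEngine`) and a hard-coded local model path. A minimal sketch of the equivalent flow against the renamed class, assuming the import paths from the removed example and that the constructor still takes `pp_size`, `model`, and `model_policy` (model path and prompts are placeholders):

```python
import colossalai
from transformers import LlamaForCausalLM, LlamaTokenizer

# Import paths follow the removed example; the exact module layout is an assumption.
from colossalai.inference import InferenceEngine
from colossalai.inference.pipeline.policies import LlamaModelInferPolicy

colossalai.launch_from_torch(config={})

model = LlamaForCausalLM.from_pretrained("path/to/llama")      # placeholder path
tokenizer = LlamaTokenizer.from_pretrained("path/to/llama")    # placeholder path

# As in the removed example, assume the model runs with 2 pipeline stages.
engine = InferenceEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy())

prompts = ["Introduce a landmark in China", "Introduce a landmark in China"]
data = tokenizer(prompts, return_tensors="pt")

# generate() now accepts a list or dict, so pass the tokenizer output's dict form.
output = engine.generate(data.to("cuda").data)
```

Passing the plain dict rather than a `BatchEncoding` matches the new `Union[list, dict]` annotation, since the engine no longer unwraps `BatchEncoding` on its own.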
@@ -148,7 +126,7 @@ class CaiInferEngine:
         if quant == "gptq":
             self.gptq_manager.post_init_gptq_buffer(self.model)
-    def generate(self, input_list: Union[BatchEncoding, dict]):
+    def generate(self, input_list: Union[list, dict]):
         """
         Args:
             input_list (list): a list of input data, each element is a `BatchEncoding` or `dict`.
@@ -157,11 +135,7 @@ class CaiInferEngine:
             out (list): a list of output data, each element is a list of token.
             timestamp (float): the time cost of the inference, only return when verbose is `True`.
         """
-        assert isinstance(
-            input_list, (BatchEncoding, dict)
-        ), f"Only accept BatchEncoding or dict as input, but get {input_list.__class__.__name__}."
-        if isinstance(input_list, BatchEncoding):
-            input_list = input_list.data
         out, timestamp = self.schedule.generate_step(self.model, iter([input_list]))
         if self.verbose:
             return out, timestamp
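With the `BatchEncoding` unwrapping gone, callers are responsible for handing `generate()` a plain dict (or list) and for the two-value return in verbose mode. A hedged sketch of that calling convention, reusing the `engine`, `tokenizer`, and `prompts` names assumed above:

```python
# Move tensors to the device and keep only the plain dict, since generate()
# no longer converts a BatchEncoding itself (an assumption based on this diff).
batch = {k: v.to("cuda") for k, v in tokenizer(prompts, return_tensors="pt").items()}

result = engine.generate(batch)
if isinstance(result, tuple):
    out, timestamp = result   # verbose engines also return the time cost
else:
    out = result
```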