[inference] update examples and engine (#5073)
* update examples and engine
* fix choices
* update example
@@ -1,4 +1,4 @@
-from .engine import CaiInferEngine
+from .engine import InferenceEngine
 from .engine.policies import BloomModelInferPolicy, ChatGLM2InferPolicy, LlamaModelInferPolicy
 
-__all__ = ["CaiInferEngine", "LlamaModelInferPolicy", "BloomModelInferPolicy", "ChatGLM2InferPolicy"]
+__all__ = ["InferenceEngine", "LlamaModelInferPolicy", "BloomModelInferPolicy", "ChatGLM2InferPolicy"]
@@ -1,3 +1,3 @@
-from .engine import CaiInferEngine
+from .engine import InferenceEngine
 
-__all__ = ["CaiInferEngine"]
+__all__ = ["InferenceEngine"]
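The export rename means downstream code now imports the engine under its new name. A minimal sketch of the updated import, assuming the package paths implied by the `__init__.py` hunks above:

```python
# Sketch only: the exact module path is assumed from the __init__.py hunks above.
from colossalai.inference import InferenceEngine, LlamaModelInferPolicy  # formerly CaiInferEngine
```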
@@ -3,7 +3,6 @@ from typing import Union
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from transformers.tokenization_utils_base import BatchEncoding
 from transformers.utils import logging
 
 from colossalai.cluster import ProcessGroupMesh
@@ -27,9 +26,9 @@ _supported_models = [
 ]
 
 
-class CaiInferEngine:
+class InferenceEngine:
     """
-    CaiInferEngine is a class that handles the pipeline parallel inference.
+    InferenceEngine is a class that handles the pipeline parallel inference.
 
     Args:
         tp_size (int): the size of tensor parallelism.
@@ -42,27 +41,6 @@ class CaiInferEngine:
         max_input_len (int): the maximum input length.
         max_output_len (int): the maximum output length.
 
-    Example:
-
-    ```python
-    from colossalai.inference import InferEngine
-    from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
-    import colossalai
-    from transformers import LlamaForCausalLM, LlamaTokenizer
-
-    colossalai.launch_from_torch(config={})
-
-    model = LlamaForCausalLM.from_pretrained("your_path_to_model")
-    tokenizer = LlamaTokenizer.from_pretrained("/home/lczyh/share/models/llama-7b-hf")
-    # assume the model is infered with 2 pipeline stages
-    inferengine = CaiInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy())
-
-    input = ["Introduce a landmark in China ","Introduce a landmark in China "]
-    data = tokenizer(input, return_tensors='pt')
-    output = inferengine.inference([data.to('cuda').data])
-
-    ```
-
     """
 
     def __init__(
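The docstring example removed above predates the rename (it still imports `InferEngine` and calls `inference`). For context, a hedged sketch of the same flow against the renamed class: constructor arguments are kept from the removed example, the tokenizer and model paths are placeholders, and the call goes through the `generate` method whose signature is changed in the next hunk.

```python
import colossalai
from transformers import LlamaForCausalLM, LlamaTokenizer

from colossalai.inference import InferenceEngine  # path assumed from the __init__.py hunks above
from colossalai.inference.pipeline.policies import LlamaModelInferPolicy  # path as in the removed example

colossalai.launch_from_torch(config={})

model = LlamaForCausalLM.from_pretrained("your_path_to_model")
tokenizer = LlamaTokenizer.from_pretrained("your_path_to_tokenizer")

# The removed example ran the model with 2 pipeline stages.
engine = InferenceEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy())

prompts = ["Introduce a landmark in China ", "Introduce a landmark in China "]
data = tokenizer(prompts, return_tensors="pt")
# generate() now takes a list or dict (see the hunk below), so pass the underlying tensor dict.
output = engine.generate(data.to("cuda").data)
```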
@@ -148,7 +126,7 @@ class CaiInferEngine:
         if quant == "gptq":
             self.gptq_manager.post_init_gptq_buffer(self.model)
 
-    def generate(self, input_list: Union[BatchEncoding, dict]):
+    def generate(self, input_list: Union[list, dict]):
         """
         Args:
             input_list (list): a list of input data, each element is a `BatchEncoding` or `dict`.
@@ -157,11 +135,7 @@ class CaiInferEngine:
             out (list): a list of output data, each element is a list of token.
             timestamp (float): the time cost of the inference, only return when verbose is `True`.
         """
-        assert isinstance(
-            input_list, (BatchEncoding, dict)
-        ), f"Only accept BatchEncoding or dict as input, but get {input_list.__class__.__name__}."
-        if isinstance(input_list, BatchEncoding):
-            input_list = input_list.data
         out, timestamp = self.schedule.generate_step(self.model, iter([input_list]))
         if self.verbose:
            return out, timestamp
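With the `BatchEncoding` branch removed, `generate` no longer unwraps tokenizer output itself, so presumably the caller does. A small sketch under that assumption, reusing `engine` and `tokenizer` from the example above:

```python
encoding = tokenizer("Introduce a landmark in China ", return_tensors="pt")

# Previously (removed behaviour): generate() accepted the BatchEncoding and unwrapped it internally.
# Assumed new usage: pass the plain dict of tensors that backs the BatchEncoding.
output = engine.generate(encoding.to("cuda").data)
```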