[inference] add reference and fix some bugs (#4937)

* add reference and fix some bugs

* update gptq init

---------

Co-authored-by: Xu Kai <xukai16@foxmail.com>
This commit is contained in:
Xu Kai
2023-10-20 13:39:34 +08:00
committed by GitHub
parent b8e770c832
commit 785802e809
7 changed files with 24 additions and 10 deletions

View File

@@ -8,7 +8,6 @@ from transformers import LlamaTokenizer
import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.inference.tensor_parallel.modeling._utils import init_to_get_rotary
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
@@ -50,8 +49,6 @@ def run_llama_test(args):
quantized_model_dir, device=torch.cuda.current_device(), inject_fused_attention=False
)
init_to_get_rotary(model.model.model, base=10000)
model_config = model.config
shard_config = ShardConfig(
enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True