[Feat] Tensor Model Parallel Support for Inference (#5563)

* add naive tensor parallel support

* [fix] precision and model loading; refactor the framework

* add TP unit test

* add docstrings

* fix do_sample
Runyu Lu
2024-04-18 16:56:46 +08:00
committed by GitHub
parent be396ad6cc
commit e37ee2fb65
8 changed files with 640 additions and 150 deletions


@@ -3,24 +3,27 @@ import random
 import numpy as np
 import pytest
 import torch
+import torch.distributed as dist
+from torch.multiprocessing import Manager
 from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM
 
 import colossalai
 from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
 from colossalai.inference.core.engine import InferenceEngine
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.inference.modeling.models.glide_llama import GlideLlamaConfig, GlideLlamaForCausalLM
+from colossalai.inference.modeling.policy import NoPaddingLlamaModelInferPolicy
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 
 
 def setup_seed(seed):
     torch.manual_seed(seed)
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     np.random.seed(seed)
     random.seed(seed)
 
 
-def check_inference_engine(use_engine=False, prompt_template=None):
+def check_inference_engine(use_engine=False, prompt_template=None, do_sample=True, policy=None):
     setup_seed(20)
     tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
     model = LlamaForCausalLM(
@@ -36,13 +39,19 @@ def check_inference_engine(use_engine=False, prompt_template=None):
     ]
 
     output_len = 38
-    do_sample = True
+    do_sample = do_sample
     top_p = 0.5
     top_k = 50
 
     if use_engine:
-        inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template, dtype="fp32")
-        inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
+        inference_config = InferenceConfig(
+            max_output_len=output_len,
+            prompt_template=prompt_template,
+            dtype="fp32",
+            use_cuda_kernel=True,
+            tp_size=dist.get_world_size(),
+        )
+        inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True, model_policy=policy)
         assert inference_engine.generation_config.max_new_tokens == output_len
         inference_engine.add_request(prompts=inputs)
         assert inference_engine.request_handler._has_waiting()
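The hunk above carries the user-facing surface of the feature: `tp_size=dist.get_world_size()` asks `InferenceConfig` to shard the model across every rank in the process group, and the new `model_policy` argument hands `InferenceEngine` a policy object (here `NoPaddingLlamaModelInferPolicy`) describing how to split Llama weights for tensor parallelism. A minimal sketch of the same path outside the test, assuming an already-initialized process group and that `generate` accepts a `generation_config`; the tiny model config and the prompt are illustrative, not taken from this diff:

```python
# Sketch only: assumes this runs on every rank of an initialized process group
# (e.g. via colossalai.launch, as run_dist does later in this file).
import torch.distributed as dist
from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.modeling.policy import NoPaddingLlamaModelInferPolicy

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
# Tiny random model for illustration; head count must divide evenly by tp_size.
model = LlamaForCausalLM(
    LlamaConfig(hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=1)
).cuda()

inference_config = InferenceConfig(
    max_output_len=38,
    dtype="fp32",
    use_cuda_kernel=True,
    tp_size=dist.get_world_size(),  # shard attention/MLP weights across all ranks
)
engine = InferenceEngine(
    model,
    tokenizer,
    inference_config,
    verbose=True,
    model_policy=NoPaddingLlamaModelInferPolicy(),  # how to split Llama for TP
)
engine.add_request(prompts=["Introduce some landmarks in Beijing"])  # illustrative prompt
outputs = engine.generate(generation_config=GenerationConfig(do_sample=False))
```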
@@ -69,20 +78,14 @@ def check_inference_engine(use_engine=False, prompt_template=None):
     return outputs
 
 
-@parameterize("prompt_template", [None, "llama"])
-def check_output_consistency(prompt_template):
-    cai_outputs = check_inference_engine(use_engine=True, prompt_template=prompt_template)
-    transformer_outputs = check_inference_engine(use_engine=False, prompt_template=prompt_template)
+def run_engine(world_size, **kwargs):
+    manager = Manager()
+    result_list = manager.list([-1] * world_size)  # Create a shared list
 
-    for s1, s2 in zip(cai_outputs, transformer_outputs):
-        assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}"
-
-    # clear singleton flash decoding tensors
-    FDIntermTensors._instances = {}
+    spawn(run_dist, world_size, func_to_run=check_inference_engine, ret=result_list, **kwargs)
+    return result_list[0]
 
 
 @parameterize("num_layers", [1])
 @parameterize("max_length", [100])
 def check_spec_dec(num_layers, max_length):
     torch.manual_seed(123)
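The new `run_engine` helper returns results from spawned workers through a `torch.multiprocessing.Manager` list: the proxy object can be handed to child processes, each rank writes into its own slot, and the parent reads slot 0 after `spawn` joins the workers. A standalone sketch of just that pattern; `worker` and `collect_results` are illustrative names, not part of this PR:

```python
# A Manager-backed list is visible to every spawned worker, so each rank can
# write its return value into a slot the parent process reads afterwards.
import torch.multiprocessing as mp
from torch.multiprocessing import Manager

def worker(rank, ret):
    ret[rank] = rank * 10  # stand-in for func_to_run(**kwargs)

def collect_results(world_size):
    manager = Manager()
    ret = manager.list([-1] * world_size)  # shared across processes
    mp.spawn(worker, args=(ret,), nprocs=world_size, join=True)  # blocks until all ranks exit
    return list(ret)

if __name__ == "__main__":
    print(collect_results(2))  # [0, 10]
```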
@@ -152,16 +155,47 @@ def check_spec_dec(num_layers, max_length):
     assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_length
 
 
-def run_dist(rank, world_size, port):
+def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
     colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
-    check_output_consistency()
-    check_spec_dec()
+
+    if ret:
+        ret[rank] = func_to_run(**kwargs)
+    else:
+        func_to_run(**kwargs)
+
+
+@parameterize("prompt_template", [None, "llama"])
+@parameterize("do_sample", [False])
+def test_tp_engine(prompt_template, do_sample):
+    kwargs1 = {
+        "use_engine": True,
+        "prompt_template": prompt_template,
+        "do_sample": do_sample,
+        "policy": NoPaddingLlamaModelInferPolicy(),
+    }
+    kwargs2 = {"use_engine": False, "prompt_template": prompt_template, "do_sample": do_sample, "policy": None}
+
+    colossal_tp_1_output = run_engine(1, **kwargs1)
+    colossal_tp_2_output = run_engine(2, **kwargs1)
+    transformer_tp_1_output = run_engine(1, **kwargs2)
+
+    for s1, s2, s3 in zip(colossal_tp_1_output, colossal_tp_2_output, transformer_tp_1_output):
+        assert s1 == s3, f"\nColossalAI TP=1 Output: {s1}\nTransformers Output: {s3}"
+        assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}"
+
+
+@parameterize("num_layers", [1])
+@parameterize("max_length", [100])
+def test_spec_dec(num_layers, max_length):
+    spawn(run_dist, 1, func_to_run=check_spec_dec, num_layers=num_layers, max_length=max_length)
 
 
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_inference_engine():
-    spawn(run_dist, 1)
+    test_tp_engine()
+    test_spec_dec()
 
 
 if __name__ == "__main__":
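Note that `test_tp_engine` pins `do_sample=False`: greedy decoding is deterministic, so the TP=1, TP=2, and Transformers outputs can be compared as exact strings, whereas sampled decoding could legitimately differ between runs. For illustration, the two decoding modes with the same `top_p`/`top_k` values the test uses:

```python
from transformers import GenerationConfig

# Greedy decoding: deterministic, so outputs can be asserted as exact string
# matches across TP=1, TP=2, and the HuggingFace baseline.
greedy = GenerationConfig(do_sample=False)

# Sampling with the test's top_p/top_k: stochastic, so exact-match assertions
# would be flaky even between two otherwise identical runs.
sampled = GenerationConfig(do_sample=True, top_p=0.5, top_k=50)
```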