mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-05-03 05:58:09 +00:00
* [pipeline inference] pipeline inference (#4492) * add pp stage manager as circle stage * fix a bug when create process group * add ppinfer basic framework * add micro batch manager and support kvcache-pp gpt2 fwd * add generate schedule * use mb size to control mb number * support generate with kv cache * add output, remove unused code * add test * reuse shardformer to build model * refactor some code and use the same attribute name of hf * fix review and add test for generation * remove unused file * fix CI * add cache clear * fix code error * fix typo * [Pipeline inference] Modify to tieweight (#4599) * add pp stage manager as circle stage * fix a bug when create process group * add ppinfer basic framework * add micro batch manager and support kvcache-pp gpt2 fwd * add generate schedule * use mb size to control mb number * support generate with kv cache * add output, remove unused code * add test * reuse shardformer to build model * refactor some code and use the same attribute name of hf * fix review and add test for generation * remove unused file * modify the way of saving newtokens * modify to tieweight * modify test * remove unused file * solve review * add docstring * [Pipeline inference] support llama pipeline inference (#4647) * support llama pipeline inference * remove tie weight operation * [pipeline inference] Fix the blocking of communication when ppsize is 2 (#4708) * add benchmark verbose * fix export tokens * fix benchmark verbose * add P2POp style to do p2p communication * modify schedule as p2p type when ppsize is 2 * remove unused code and add docstring * [Pipeline inference] Refactor code, add docsting, fix bug (#4790) * add benchmark script * update argparse * fix fp16 load * refactor code style * add docstring * polish code * fix test bug * [Pipeline inference] Add pipeline inference docs (#4817) * add readme doc * add a ico * Add performance * update table of contents * refactor code (#4873)
64 lines
2.1 KiB
Python
64 lines
2.1 KiB
Python
from copy import deepcopy
|
|
|
|
import pytest
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.nn as nn
|
|
import transformers
|
|
|
|
import colossalai
|
|
from colossalai.inference.pipeline.engine import PPInferEngine
|
|
from colossalai.inference.pipeline.policy.gpt2_ppinfer import GPT2LMHeadModelPipelinePolicy
|
|
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
|
|
|
|
|
def data_gen():
|
|
input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
|
|
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
|
|
return dict(input_ids=input_ids, attention_mask=attention_mask)
|
|
|
|
|
|
inputs = data_gen()
|
|
for k, v in inputs.items():
|
|
if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
|
|
new_shape = [1] * v.dim()
|
|
new_shape[0] = 16
|
|
inputs[k] = v.to('cuda').repeat(*new_shape)
|
|
|
|
|
|
def pipeline_inference_test(pp_size, new_length, micro_batch_size):
|
|
model = transformers.GPT2LMHeadModel(transformers.GPT2Config(n_layer=8))
|
|
engine = PPInferEngine(pp_size=pp_size,
|
|
model=model,
|
|
model_policy=GPT2LMHeadModelPipelinePolicy(),
|
|
new_length=new_length,
|
|
micro_batch_size=micro_batch_size)
|
|
output = engine.inference([inputs])
|
|
if dist.get_rank() == 0:
|
|
assert len(output[0]) == new_length, f"{len(output)}, {new_length}"
|
|
|
|
|
|
@parameterize('pp_size', [4])
|
|
@parameterize('new_length', [4, 8, 16])
|
|
@parameterize('micro_batch_size', [1, 4])
|
|
@clear_cache_before_run()
|
|
def run_pipeline_inference_test(pp_size, new_length, micro_batch_size):
|
|
pipeline_inference_test(pp_size, new_length, micro_batch_size)
|
|
torch.cuda.empty_cache()
|
|
|
|
|
|
def check_pipeline_inference(rank, world_size, port):
|
|
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
|
run_pipeline_inference_test()
|
|
|
|
|
|
@pytest.mark.dist
|
|
@rerun_if_address_is_in_use()
|
|
@clear_cache_before_run()
|
|
def test_pipeline_inference():
|
|
spawn(check_pipeline_inference, nprocs=4)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
test_pipeline_inference()
|