ColossalAI/tests/test_infer/test_pipeline_infer.py
Bin Jia 08a9f76b2f
[Pipeline Inference] Sync pipeline inference branch to main (#4820)
* [pipeline inference] pipeline inference (#4492)

* add pp stage manager as circle stage

* fix a bug when create process group

* add ppinfer basic framework

* add micro batch manager and support kvcache-pp gpt2 fwd

* add generate schedule

* use mb size to control mb number

* support generate with kv cache

* add output, remove unused code

* add test

* reuse shardformer to build model

* refactor some code and use the same attribute name of hf

* fix review and add test for generation

* remove unused file

* fix CI

* add cache clear

* fix code error

* fix typo

* [Pipeline inference] Modify to tieweight (#4599)

* modify the way of saving newtokens

* modify to tieweight

* modify test

* remove unused file

* solve review

* add docstring

* [Pipeline inference] support llama pipeline inference (#4647)

* support llama pipeline inference

* remove tie weight operation

* [pipeline inference] Fix the blocking of communication when ppsize is 2 (#4708)

* add benchmark verbose

* fix export tokens

* fix benchmark verbose

* add P2POp style to do p2p communication

* modify schedule as p2p type when ppsize is 2

* remove unused code and add docstring

* [Pipeline inference] Refactor code, add docstring, fix bug (#4790)

* add benchmark script

* update argparse

* fix fp16 load

* refactor code style

* add docstring

* polish code

* fix test bug

* [Pipeline inference] Add pipeline inference docs (#4817)

* add readme doc

* add a ico

* Add performance

* update table of contents

* refactor code (#4873)
2023-10-11 11:40:06 +08:00


import pytest
import torch
import torch.distributed as dist
import transformers

import colossalai
from colossalai.inference.pipeline.engine import PPInferEngine
from colossalai.inference.pipeline.policy.gpt2_ppinfer import GPT2LMHeadModelPipelinePolicy
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn


def data_gen():
    # A single 8-token GPT-2 prompt with a matching all-ones attention mask.
    input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
    return dict(input_ids=input_ids, attention_mask=attention_mask)


# Build the shared test batch at import time: move each tensor to the GPU and
# repeat it along dim 0 so the batch size becomes 16.
inputs = data_gen()
for k, v in inputs.items():
    if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
        new_shape = [1] * v.dim()
        new_shape[0] = 16
        inputs[k] = v.to('cuda').repeat(*new_shape)


def pipeline_inference_test(pp_size, new_length, micro_batch_size):
    # Randomly initialised 8-layer GPT-2; no pretrained weights are loaded.
    model = transformers.GPT2LMHeadModel(transformers.GPT2Config(n_layer=8))
    engine = PPInferEngine(pp_size=pp_size,
                           model=model,
                           model_policy=GPT2LMHeadModelPipelinePolicy(),
                           new_length=new_length,
                           micro_batch_size=micro_batch_size)
    output = engine.inference([inputs])
    # Only rank 0 checks the result: the first output entry must hold new_length items.
    if dist.get_rank() == 0:
        assert len(output[0]) == new_length, f"expected {new_length} new tokens, got {len(output[0])}"


@parameterize('pp_size', [4])
@parameterize('new_length', [4, 8, 16])
@parameterize('micro_batch_size', [1, 4])
@clear_cache_before_run()
def run_pipeline_inference_test(pp_size, new_length, micro_batch_size):
    pipeline_inference_test(pp_size, new_length, micro_batch_size)
    torch.cuda.empty_cache()


def check_pipeline_inference(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    # Called without arguments: the @parameterize decorators supply every combination.
    run_pipeline_inference_test()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_pipeline_inference():
    # One process per pipeline stage (pp_size is 4), so the test needs 4 GPUs.
    spawn(check_pipeline_inference, nprocs=4)


if __name__ == '__main__':
    test_pipeline_inference()