[format] applied code formatting on changed files in pull request 4820 (#4886)

Co-authored-by: github-actions <github-actions@github.com>
Authored by github-actions[bot] on 2023-10-18 11:46:37 +08:00, committed by GitHub
parent c7aa319ba0
commit 486d06a2d5
13 changed files with 297 additions and 258 deletions


@@ -1,9 +1,6 @@
-from copy import deepcopy
-
 import pytest
 import torch
 import torch.distributed as dist
-import torch.nn as nn
 import transformers

 import colossalai
@@ -20,27 +17,29 @@ def data_gen():
 inputs = data_gen()
 for k, v in inputs.items():
-    if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
+    if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
         new_shape = [1] * v.dim()
         new_shape[0] = 16
-        inputs[k] = v.to('cuda').repeat(*new_shape)
+        inputs[k] = v.to("cuda").repeat(*new_shape)


 def pipeline_inference_test(pp_size, new_length, micro_batch_size):
     model = transformers.GPT2LMHeadModel(transformers.GPT2Config(n_layer=8))
-    engine = PPInferEngine(pp_size=pp_size,
-                           model=model,
-                           model_policy=GPT2LMHeadModelPipelinePolicy(),
-                           new_length=new_length,
-                           micro_batch_size=micro_batch_size)
+    engine = PPInferEngine(
+        pp_size=pp_size,
+        model=model,
+        model_policy=GPT2LMHeadModelPipelinePolicy(),
+        new_length=new_length,
+        micro_batch_size=micro_batch_size,
+    )
     output = engine.inference([inputs])
     if dist.get_rank() == 0:
         assert len(output[0]) == new_length, f"{len(output)}, {new_length}"


-@parameterize('pp_size', [4])
-@parameterize('new_length', [4, 8, 16])
-@parameterize('micro_batch_size', [1, 4])
+@parameterize("pp_size", [4])
+@parameterize("new_length", [4, 8, 16])
+@parameterize("micro_batch_size", [1, 4])
 @clear_cache_before_run()
 def run_pipeline_inference_test(pp_size, new_length, micro_batch_size):
     pipeline_inference_test(pp_size, new_length, micro_batch_size)
@@ -48,7 +47,7 @@ def run_pipeline_inference_test(pp_size, new_length, micro_batch_size):
 def check_pipeline_inference(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     run_pipeline_inference_test()
@@ -59,5 +58,5 @@ def test_pipeline_inference():
     spawn(check_pipeline_inference, nprocs=4)


-if __name__ == '__main__':
+if __name__ == "__main__":
     test_pipeline_inference()
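
The changes above match black's default style: single quotes normalized to double quotes, and multi-argument calls broken one argument per line with a trailing comma. A minimal sketch of reproducing the quote normalization, assuming the formatter behind this commit is black with default settings (the source string below is illustrative, not the commit's file contents):

import black  # assumption: black is the formatter applied by this commit

# A one-line sample in the pre-format style seen in this diff.
src = "inputs[k] = v.to('cuda').repeat(*new_shape)\n"

# black.format_str() formats a source string under the given Mode;
# the default Mode normalizes string quotes to double quotes.
formatted = black.format_str(src, mode=black.Mode())
print(formatted)  # inputs[k] = v.to("cuda").repeat(*new_shape)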