mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
[format] applied code formatting on changed files in pull request 4820 (#4886)
Co-authored-by: github-actions <github-actions@github.com>
This commit is contained in:
committed by
GitHub
parent
c7aa319ba0
commit
486d06a2d5
@@ -1,9 +1,6 @@
|
||||
from copy import deepcopy
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import transformers
|
||||
|
||||
import colossalai
|
||||
@@ -20,27 +17,29 @@ def data_gen():
|
||||
|
||||
inputs = data_gen()
|
||||
for k, v in inputs.items():
|
||||
if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
|
||||
if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
|
||||
new_shape = [1] * v.dim()
|
||||
new_shape[0] = 16
|
||||
inputs[k] = v.to('cuda').repeat(*new_shape)
|
||||
inputs[k] = v.to("cuda").repeat(*new_shape)
|
||||
|
||||
|
||||
def pipeline_inference_test(pp_size, new_length, micro_batch_size):
|
||||
model = transformers.GPT2LMHeadModel(transformers.GPT2Config(n_layer=8))
|
||||
engine = PPInferEngine(pp_size=pp_size,
|
||||
model=model,
|
||||
model_policy=GPT2LMHeadModelPipelinePolicy(),
|
||||
new_length=new_length,
|
||||
micro_batch_size=micro_batch_size)
|
||||
engine = PPInferEngine(
|
||||
pp_size=pp_size,
|
||||
model=model,
|
||||
model_policy=GPT2LMHeadModelPipelinePolicy(),
|
||||
new_length=new_length,
|
||||
micro_batch_size=micro_batch_size,
|
||||
)
|
||||
output = engine.inference([inputs])
|
||||
if dist.get_rank() == 0:
|
||||
assert len(output[0]) == new_length, f"{len(output)}, {new_length}"
|
||||
|
||||
|
||||
@parameterize('pp_size', [4])
|
||||
@parameterize('new_length', [4, 8, 16])
|
||||
@parameterize('micro_batch_size', [1, 4])
|
||||
@parameterize("pp_size", [4])
|
||||
@parameterize("new_length", [4, 8, 16])
|
||||
@parameterize("micro_batch_size", [1, 4])
|
||||
@clear_cache_before_run()
|
||||
def run_pipeline_inference_test(pp_size, new_length, micro_batch_size):
|
||||
pipeline_inference_test(pp_size, new_length, micro_batch_size)
|
||||
@@ -48,7 +47,7 @@ def run_pipeline_inference_test(pp_size, new_length, micro_batch_size):
|
||||
|
||||
|
||||
def check_pipeline_inference(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
run_pipeline_inference_test()
|
||||
|
||||
|
||||
@@ -59,5 +58,5 @@ def test_pipeline_inference():
|
||||
spawn(check_pipeline_inference, nprocs=4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_pipeline_inference()
|
||||
|
Reference in New Issue
Block a user