mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-25 15:01:43 +00:00
[test] bert test in non-distributed way (#2074)
This commit is contained in:
parent
223332ff7e
commit
616ed91ecd
@ -68,16 +68,17 @@ def get_training_components():
|
|||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
is_distrbuted = torch.distributed.is_initialized()
|
||||||
trainloader = get_bert_data_loader(n_class=vocab_size,
|
trainloader = get_bert_data_loader(n_class=vocab_size,
|
||||||
batch_size=2,
|
batch_size=2,
|
||||||
total_samples=10000,
|
total_samples=10000,
|
||||||
sequence_length=sequence_length,
|
sequence_length=sequence_length,
|
||||||
is_distrbuted=True)
|
is_distrbuted=is_distrbuted)
|
||||||
testloader = get_bert_data_loader(n_class=vocab_size,
|
testloader = get_bert_data_loader(n_class=vocab_size,
|
||||||
batch_size=2,
|
batch_size=2,
|
||||||
total_samples=10000,
|
total_samples=10000,
|
||||||
sequence_length=sequence_length,
|
sequence_length=sequence_length,
|
||||||
is_distrbuted=True)
|
is_distrbuted=is_distrbuted)
|
||||||
|
|
||||||
criterion = None
|
criterion = None
|
||||||
return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion
|
return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion
|
||||||
|
@ -21,14 +21,15 @@ def run_fwd_bwd(model, data, label, criterion, enable_autocast=False, dtype=torc
|
|||||||
model.backward(loss)
|
model.backward(loss)
|
||||||
|
|
||||||
|
|
||||||
def run_param_wrapper_testing():
|
def test_runtime_mem_tracer():
|
||||||
test_models = ['simple_net', 'repeated_computed_layers', 'nested_model']
|
test_models = ['gpt2', 'bert', 'simple_net', 'repeated_computed_layers', 'nested_model']
|
||||||
|
|
||||||
for model_name in test_models:
|
for model_name in test_models:
|
||||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||||
model_builder, train_dataloader, _, _, criterion = get_components_func()
|
model_builder, train_dataloader, _, _, criterion = get_components_func()
|
||||||
|
|
||||||
with ColoInitContext(device=torch.device('cpu')):
|
with ColoInitContext(device=torch.device('cpu')):
|
||||||
model = model_builder(checkpoint=False)
|
model = model_builder(checkpoint=True)
|
||||||
|
|
||||||
model_bk = deepcopy(model)
|
model_bk = deepcopy(model)
|
||||||
runtime_mem_tracer = RuntimeMemTracer(model)
|
runtime_mem_tracer = RuntimeMemTracer(model)
|
||||||
@ -52,4 +53,4 @@ def run_param_wrapper_testing():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
run_param_wrapper_testing()
|
test_runtime_mem_tracer()
|
||||||
|
Loading…
Reference in New Issue
Block a user