Mirror of https://github.com/hpcaitech/ColossalAI.git
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
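The quote normalization and call reflowing throughout this diff are characteristic of black-style autoformatting. As a minimal sketch (assuming the updated pre-commit hooks run black; the hook list itself is not shown in this diff), the same kind of rewrite can be reproduced with black's Python API:

# Sketch only: reproduces the kind of rewrite seen in the hunks below,
# assuming black is the formatter behind the pre-commit update.
import black

src = "CONFIG = dict(fp16=dict(mode=None),\n              clip_grad_norm=1.0)\n"
print(black.format_str(src, mode=black.Mode()))
# -> CONFIG = dict(fp16=dict(mode=None), clip_grad_norm=1.0)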
@@ -6,25 +6,26 @@ from colossalai.legacy.core import global_context as gpc
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from tests.components_to_test.registry import non_distributed_component_funcs
 
-CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)),
-              fp16=dict(mode=None),
-              clip_grad_norm=1.0)
+CONFIG = dict(
+    parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)), fp16=dict(mode=None), clip_grad_norm=1.0
+)
 
 
-@parameterize('model_name', ['repeated_computed_layers', 'resnet18', 'repeated_computed_layers'])
-@parameterize('amp_mode', [AMP_TYPE.APEX, AMP_TYPE.TORCH, AMP_TYPE.NAIVE, None])
+@parameterize("model_name", ["repeated_computed_layers", "resnet18", "repeated_computed_layers"])
+@parameterize("amp_mode", [AMP_TYPE.APEX, AMP_TYPE.TORCH, AMP_TYPE.NAIVE, None])
 def run_train(model_name, amp_mode):
     # FIXME: test bert
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
-    gpc.config.fp16['mode'] = amp_mode
+    gpc.config.fp16["mode"] = amp_mode
     model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
 
     model = model_builder(checkpoint=False)
-    engine, train_dataloader, *args = colossalai.legacy.initialize(model=model,
-                                                                   optimizer=optimizer_class(model.parameters(),
-                                                                                             lr=1e-3),
-                                                                   criterion=criterion,
-                                                                   train_dataloader=train_dataloader)
+    engine, train_dataloader, *args = colossalai.legacy.initialize(
+        model=model,
+        optimizer=optimizer_class(model.parameters(), lr=1e-3),
+        criterion=criterion,
+        train_dataloader=train_dataloader,
+    )
 
     try:
         engine.train()
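The stacked @parameterize decorators in the hunk above drive run_train once per (model_name, amp_mode) combination. A self-contained sketch of that expansion (a simplified stand-in for colossalai.testing.parameterize, not its actual implementation):

import functools

def parameterize(name, values):
    # Simplified: re-invokes the wrapped function once per value,
    # so stacked decorators yield the full cross product.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(**kwargs):
            for value in values:
                func(**{**kwargs, name: value})
        return wrapper
    return decorator

@parameterize("model_name", ["resnet18"])
@parameterize("amp_mode", ["apex", "torch", None])
def demo(model_name, amp_mode):
    print(model_name, amp_mode)

demo()  # 1 model x 3 amp modes -> 3 calls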
@@ -49,12 +50,9 @@ def run_train(model_name, amp_mode):
 
 def run_engine(rank, world_size, port):
     # init dist env
-    colossalai.legacy.launch(config=CONFIG,
-                             rank=rank,
-                             world_size=world_size,
-                             host='localhost',
-                             port=port,
-                             backend='nccl')
+    colossalai.legacy.launch(
+        config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl"
+    )
     run_train()
 
 
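run_engine's first step, colossalai.legacy.launch, brings up the distributed environment from the explicit rank/world_size/host/port it receives. A minimal sketch, assuming the call reduces to torch.distributed initialization (the real launch also loads the passed config into the global context; launch_sketch is a hypothetical name):

import torch.distributed as dist

def launch_sketch(rank, world_size, host="localhost", port=29500, backend="gloo"):
    # Rendezvous over TCP: every rank connects to the same host:port
    # and registers itself into a process group of size world_size.
    dist.init_process_group(
        backend=backend,
        init_method=f"tcp://{host}:{port}",
        rank=rank,
        world_size=world_size,
    )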
@@ -64,5 +62,5 @@ def test_engine():
     spawn(run_engine, 2)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_engine()
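test_engine ultimately calls spawn(run_engine, 2), which forks two worker processes and hands each one its rank. A minimal sketch, assuming spawn wraps torch.multiprocessing (spawn_sketch and the fixed port are assumptions; the real helper picks a free port):

import torch.multiprocessing as mp

def spawn_sketch(func, nprocs, port=29500):
    # mp.spawn passes the worker index (rank) as the first argument,
    # matching the run_engine(rank, world_size, port) signature.
    mp.spawn(func, args=(nprocs, port), nprocs=nprocs)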