[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -25,7 +25,7 @@ rpc_is_initialized = _is_current_rpc_agent_set
def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
model.eval()
tracer = ColoTracer()
meta_args = {k: v.to('meta') for k, v in data_kwargs.items()}
meta_args = {k: v.to("meta") for k, v in data_kwargs.items()}
graph = tracer.trace(root=model, meta_args=meta_args)
gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
annotated_model = balanced_split_pass(gm, stage_num)
@@ -33,7 +33,7 @@ def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
topo = get_fx_topology(top_module)
for submodule in split_submodules:
if isinstance(submodule, torch.fx.GraphModule):
setattr(submodule, '_topo', topo)
setattr(submodule, "_topo", topo)
return split_submodules[pp_rank + 1]
@@ -47,11 +47,11 @@ def run_master(model_cls, world_size, forward_only):
torch.manual_seed(100)
epoch = 3
device = 'cuda'
device = "cuda"
stage_num = world_size
chunk = 1
num_microbatches = 8
use_checkpoint = 'store_true'
use_checkpoint = "store_true"
if model_cls == MLP:
@@ -92,29 +92,26 @@ def run_master(model_cls, world_size, forward_only):
checkpoint=use_checkpoint,
)
if not forward_only:
engine.initialize_optimizer(getattr(torch.optim, 'SGD'), lr=1e-3)
engine.initialize_optimizer(getattr(torch.optim, "SGD"), lr=1e-3)
for _ in range(epoch):
input_x = torch.randn((batch_size, dim), device=device)
input_y = torch.randn((batch_size, dim), device=device)
logits = engine.forward_backward({'x': input_x, 'y': input_y}, labels=labels, forward_only=forward_only)
logits = engine.forward_backward({"x": input_x, "y": input_y}, labels=labels, forward_only=forward_only)
def run_worker(rank, world_size, port, model_cls, forward_only, master_func):
master_addr = 'localhost'
master_addr = "localhost"
master_port = 29020
os.environ['MASTER_ADDR'] = master_addr
os.environ['MASTER_PORT'] = str(master_port)
os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = str(master_port)
disable_existing_loggers()
launch(dict(), rank, world_size, master_addr, master_port, 'nccl', verbose=False)
ppg.set_global_info(rank=rank,
world_size=world_size,
dp_degree=1,
tp_degree=1,
num_worker_threads=128,
device='cuda')
launch(dict(), rank, world_size, master_addr, master_port, "nccl", verbose=False)
ppg.set_global_info(
rank=rank, world_size=world_size, dp_degree=1, tp_degree=1, num_worker_threads=128, device="cuda"
)
# in rpc mode, only rank 0 is needed to be coded
if rank == 0:
@@ -125,8 +122,8 @@ def run_worker(rank, world_size, port, model_cls, forward_only, master_func):
@pytest.mark.skip("skip due to CI torch version 1.11")
@parameterize('model_cls', [MLP, DAG_MLP])
@parameterize('forward_only', [True, False])
@parameterize("model_cls", [MLP, DAG_MLP])
@parameterize("forward_only", [True, False])
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_pp_middleware_fwd(model_cls, forward_only):