mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-24 11:08:50 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
import torch
|
||||
from rpc_test_utils import RpcTestModel, parse_args, rpc_run
|
||||
from torch import autograd, nn
|
||||
from torch.optim import SGD, Adam, Optimizer, RMSprop
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from colossalai.legacy.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
|
||||
from colossalai.legacy.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
|
||||
from colossalai.testing import assert_close
|
||||
|
||||
# global variable for model created
|
||||
@@ -36,12 +36,14 @@ def run_master(args):
|
||||
|
||||
input_sample = torch.randn((sample_num, feat_num), device=device)
|
||||
|
||||
engine = OneFOneBPipelineEngine(partition_fn=partition,
|
||||
stage_num=stage_num,
|
||||
num_microbatches=num_microbatches,
|
||||
device=device,
|
||||
chunk=chunk,
|
||||
checkpoint=use_checkpoint)
|
||||
engine = OneFOneBPipelineEngine(
|
||||
partition_fn=partition,
|
||||
stage_num=stage_num,
|
||||
num_microbatches=num_microbatches,
|
||||
device=device,
|
||||
chunk=chunk,
|
||||
checkpoint=use_checkpoint,
|
||||
)
|
||||
|
||||
engine.initialize_optimizer(optimizer_class, lr=lr)
|
||||
|
||||
@@ -59,7 +61,8 @@ def run_master(args):
|
||||
|
||||
# compute forward result and backward grad of parameters just in rank_0
|
||||
test_model = nn.Sequential(
|
||||
*[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)]).to(device)
|
||||
*[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)]
|
||||
).to(device)
|
||||
optimizer: Optimizer = optimizer_class(test_model.parameters(), lr=lr)
|
||||
input_sample = input_sample.requires_grad_()
|
||||
out_val = test_model(input_sample).sum()
|
||||
|
Reference in New Issue
Block a user