Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-03 10:06:44 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
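The "[misc] run pre-commit" step corresponds to pre-commit run --all-files, which applies every configured hook to the whole tree; the quote changes in the diff below are black's default normalization of single-quoted strings to double quotes. A minimal sketch of reproducing that normalization (assuming the black package is installed; the sample line is taken from the diff):

import black

# black.format_str applies the formatter to a source string; Mode() uses
# black's defaults, which rewrite single-quoted strings to double quotes.
src = "save_moe_model(model, 'temp_path.pth')\n"
print(black.format_str(src, mode=black.Mode()))
# -> save_moe_model(model, "temp_path.pth")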
@@ -17,11 +17,11 @@ from tests.test_zero.test_legacy.common import CONFIG
 def exam_moe_checkpoint():
     with ColoInitContext(device=get_current_device()):
         model = MoeModel(checkpoint=True)
-    save_moe_model(model, 'temp_path.pth')
+    save_moe_model(model, "temp_path.pth")

     with ColoInitContext(device=get_current_device()):
         other_model = MoeModel(checkpoint=True)
-    load_moe_model(other_model, 'temp_path.pth')
+    load_moe_model(other_model, "temp_path.pth")

     state_0 = model.state_dict()
     state_1 = other_model.state_dict()
@@ -30,11 +30,11 @@ def exam_moe_checkpoint():
         assert torch.equal(u.data, v.data)

     if dist.get_rank() == 0:
-        os.remove('temp_path.pth')
+        os.remove("temp_path.pth")


 def _run_dist(rank, world_size, port):
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     MOE_CONTEXT.setup(seed=42)
     exam_moe_checkpoint()

@@ -46,5 +46,5 @@ def test_moe_checkpoint(world_size):
     spawn(_run_dist)


-if __name__ == '__main__':
+if __name__ == "__main__":
     test_moe_checkpoint(world_size=4)
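For context on the launch pattern above: spawn(_run_dist) starts one process per rank and calls _run_dist with a rank, world size, and port, which the test forwards to colossalai.launch. A hedged sketch of the same pattern built on torch.multiprocessing directly (the worker body and the fixed port are illustrative assumptions, not colossalai.testing's implementation):

import torch.multiprocessing as mp


def _worker(rank, world_size, port):
    # The real test would call colossalai.launch(...) here to join the
    # NCCL process group before running exam_moe_checkpoint().
    print(f"rank {rank} of {world_size} would connect on port {port}")


if __name__ == "__main__":
    world_size = 4  # mirrors test_moe_checkpoint(world_size=4) above
    port = 29500    # assumed free port
    # mp.spawn passes the process index as the first positional argument.
    mp.spawn(_worker, args=(world_size, port), nprocs=world_size)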