[hotfix] fx get comm size bugs (#1233)

* init a checkpoint dir

* [checkpoint]support resume for cosinewarmuplr

* [checkpoint]add unit test

* fix some bugs but still not OK

* fix bugs

* make it faster

* [checkpoint]support generalized scheduler

* polish

* [tensor] torch function return colotensor

* polish

* fix bugs

* remove debug info

* polish

* polish

* [tensor] test_model pass unittests

* polish

* [hotfix] fx get comm size bug

Co-authored-by: ZhaoYi1222 <zhaoyi9499@gmail.com>
This commit is contained in:
Jiarui Fang
2022-07-08 10:54:41 +08:00
committed by GitHub
parent 42ab36b762
commit 0e199d71e8
2 changed files with 6 additions and 8 deletions

View File

@@ -74,8 +74,9 @@ def run_1d_hybrid_tp(model_name):
continue
# print(name)
# num_class = type_vocab_size = 2 | (8, 2)
if 'classifier' in name and 'weight' in name:
init_1d_row_linear(p, pg)
# TODO(jiaruifang): there is a bug when the following 2 commented-out lines are enabled
# if 'classifier' in name and 'weight' in name:
# init_1d_row_linear(p, pg)
# num_class = vocab_size = 30524 | (30524, 8)
if 'word_embeddings' in name and 'weight' in name:
init_1d_row_embedding(p, pg)
@@ -152,7 +153,6 @@ def run_1d_hybrid_tp(model_name):
# Test the overridden parameters() and named_parameters() member functions
@pytest.mark.skip
def test_model_parameters():
colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
@@ -186,9 +186,8 @@ def test_model_parameters():
assert param_cnt == 2
@pytest.mark.skip
# @pytest.mark.skip
def test_colo_optimizer():
colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
get_components_func = non_distributed_component_funcs.get_callable('simple_net')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
set_seed(1)
@@ -323,7 +322,6 @@ def run_model_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.skip("under development")
@rerun_if_address_is_in_use()
def test_model(world_size):
run_func = partial(run_model_dist, world_size=world_size, port=free_port())