diff --git a/tests/test_autochunk/test_autochunk_codegen.py b/tests/test_autochunk/test_autochunk_codegen.py
index a061e0ce1..02fa07e2c 100644
--- a/tests/test_autochunk/test_autochunk_codegen.py
+++ b/tests/test_autochunk/test_autochunk_codegen.py
@@ -40,20 +40,16 @@ def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
     non_fx_out = model(node, pair)
     fx_out = gm(node, pair)
 
-    assert torch.allclose(
-        non_fx_out[0], fx_out[0], atol=1e-4
-    ), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
-        torch.abs(non_fx_out[0] - fx_out[0])
-    )
-    assert torch.allclose(
-        non_fx_out[1], fx_out[1], atol=1e-4
-    ), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
-        torch.abs(non_fx_out[1] - fx_out[1])
-    )
+    assert torch.allclose(non_fx_out[0], fx_out[0],
+                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
+                              torch.abs(non_fx_out[0] - fx_out[0]))
+    assert torch.allclose(non_fx_out[1], fx_out[1],
+                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
+                              torch.abs(non_fx_out[1] - fx_out[1]))
 
 
 def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory):
-    # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly
+    # launch colossalai
     colossalai.launch(
         config={},
         rank=rank,
@@ -76,18 +72,14 @@ def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory):
             "pair": pair.to(torch.device("meta")),
         },
     )
-    gm_prop = torch.fx.symbolic_trace(model)  # must use symbolic_trace
+    gm_prop = torch.fx.symbolic_trace(model)    # must use symbolic_trace
     interp = MetaInfoProp(gm_prop)
-    interp.propagate(
-        MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
-    )
+    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
 
     # now run it twice to get meta info in graph module, not necessary
     gm = torch.fx.GraphModule(model, graph)
     interp = MetaInfoProp(gm)
-    interp.propagate(
-        MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
-    )
+    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
 
     codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory)
     graph.set_codegen(codegen)
diff --git a/tests/test_autochunk/test_autochunk_search.py b/tests/test_autochunk/test_autochunk_search.py
index 537bf4f41..371fce64f 100644
--- a/tests/test_autochunk/test_autochunk_search.py
+++ b/tests/test_autochunk/test_autochunk_search.py
@@ -23,7 +23,8 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
 
     if msa_len == 32 and pair_len == 64:
         if max_memory is None:
-            target_regions = [(142, 154), (366, 373), (233, 283), (301, 351), (127, 134), (204, 228), (167, 191), (161, 166), (198, 203), (6, 69)]
+            target_regions = [(142, 154), (366, 373), (233, 283), (301, 351), (127, 134), (204, 228), (167, 191),
+                              (161, 166), (198, 203), (6, 69)]
         elif max_memory == 20:
             target_regions = [(142, 154), (369, 373), (233, 269), (301, 351)]
         elif max_memory == 25:
@@ -36,24 +37,19 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
         raise NotImplementedError()
 
     assert len(found_regions) == len(
-        target_regions
-    ), "len of found regions %s doesn't equal len of target regions %s" % (
-        str(found_regions),
-        str(target_regions),
-    )
+        target_regions), "len of found regions %s doesn't equal len of target regions %s" % (
+            str(found_regions),
+            str(target_regions),
+        )
     for region in target_regions:
-        assert (
-            region in found_regions
-        ), "region:%s not in found regions for msa:%d, pair:%d, maxmem:%d" % (
+        assert (region in found_regions), "region:%s not in found regions for msa:%d, pair:%d, maxmem:%d" % (
             str(region),
             msa_len,
             pair_len,
             max_memory,
         )
     for region in found_regions:
-        assert (
-            region in target_regions
-        ), "region:%s should not be found for msa:%d, pair:%d, maxmem:%d" % (
+        assert (region in target_regions), "region:%s should not be found for msa:%d, pair:%d, maxmem:%d" % (
             str(region),
             msa_len,
             pair_len,
@@ -62,7 +58,7 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
 
 
 def _test_autochunk_search(rank, msa_len, pair_len, max_memory):
-    # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly
+    # launch colossalai
    colossalai.launch(
         config={},
         rank=rank,
@@ -77,11 +73,9 @@ def _test_autochunk_search(rank, msa_len, pair_len, max_memory):
     node = torch.randn(1, msa_len, pair_len, 256).cuda()
     pair = torch.randn(1, pair_len, pair_len, 128).cuda()
 
-    gm_prop = torch.fx.symbolic_trace(model)  # must use symbolic_trace
+    gm_prop = torch.fx.symbolic_trace(model)    # must use symbolic_trace
     interp = MetaInfoProp(gm_prop)
-    interp.propagate(
-        MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
-    )
+    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
 
     codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory)
     chunk_infos = codegen.chunk_infos
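For context, the sketch below strings together the same calls the tests above exercise, outside the test harness. It is a minimal sketch, not the library's documented API: the helper name `build_chunked_module` is hypothetical, the import paths are assumptions inferred from the test files of this era, and it reuses the symbolically traced graph for codegen, whereas `_test_autochunk_codegen` above attaches the codegen to a separately traced `graph` built with meta args.

```python
# Minimal sketch of the autochunk flow exercised by the tests above.
# Import paths are assumptions inferred from the test files, not verified API.
import torch
import torch.fx

from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen    # assumed path
from colossalai.fx.graph_module import ColoGraphModule                 # assumed path
from colossalai.fx.passes.meta_info_prop import MetaInfoProp           # assumed path
from colossalai.fx.profiler import MetaTensor                          # assumed path


def build_chunked_module(model: torch.nn.Module,
                         node: torch.Tensor,
                         pair: torch.Tensor,
                         max_memory=None) -> ColoGraphModule:
    # Trace with torch.fx.symbolic_trace (the tests mark this as required) and
    # propagate meta info so AutoChunkCodeGen can estimate activation memory.
    gm_prop = torch.fx.symbolic_trace(model)
    MetaInfoProp(gm_prop).propagate(MetaTensor(node, fake_device="cuda:0"),
                                    MetaTensor(pair, fake_device="cuda:0"))

    # Search for chunk regions under the memory budget; chunk_infos holds the
    # (start, end) regions that test_autochunk_search asserts against.
    codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory)

    # Attach the codegen to a graph, then build and recompile the module.
    # Here the traced graph is reused; the test uses a separately traced one.
    graph = gm_prop.graph
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph)
    gm.recompile()
    return gm
```

The `_test_fwd` helper in the first file then checks the chunked module's outputs against the original model with `torch.allclose(..., atol=1e-4)`.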