From e481489aa6ef5131a507172e5d60274a6c87afa5 Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Fri, 18 Nov 2022 14:19:40 +0800
Subject: [PATCH] [Gemini] MemtracerWrapper unittests (#1981)

---
 colossalai/gemini/ophooks/mem_trace_hook.py |  5 +++
 tests/test_gemini/test_mem_tracer.py        | 42 +++++++++++++++++++++
 2 files changed, 47 insertions(+)
 create mode 100644 tests/test_gemini/test_mem_tracer.py

diff --git a/colossalai/gemini/ophooks/mem_trace_hook.py b/colossalai/gemini/ophooks/mem_trace_hook.py
index ed68d4597..49982b175 100644
--- a/colossalai/gemini/ophooks/mem_trace_hook.py
+++ b/colossalai/gemini/ophooks/mem_trace_hook.py
@@ -36,6 +36,11 @@ class MemTracerOpHook(BaseOpHook):
                     p.grad = p.grad.to(dev)
                     comm_volume += p.grad.numel() * p.grad.element_size()
 
+        for buf in module.buffers():
+            if buf.device.type != dev:
+                buf.data = buf.data.to(dev)
+                comm_volume += buf.data.numel() * buf.data.element_size()
+
         if dev == 'cuda':
             self._cur_model_data_vol = comm_volume
 
diff --git a/tests/test_gemini/test_mem_tracer.py b/tests/test_gemini/test_mem_tracer.py
new file mode 100644
index 000000000..c7700d9d7
--- /dev/null
+++ b/tests/test_gemini/test_mem_tracer.py
@@ -0,0 +1,42 @@
+import torch
+import torch.nn as nn
+
+import colossalai
+from colossalai.gemini.memory_tracer import MemtracerWrapper
+from tests.components_to_test.registry import non_distributed_component_funcs
+
+
+def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
+    with torch.cuda.amp.autocast(enabled=enable_autocast):
+        if criterion:
+            y = model(data)
+            loss = criterion(y, label)
+        else:
+            loss = model(data, label)
+        loss = loss.float()
+    model.backward(loss)
+
+
+def test_tracer():
+    # reset the manager, in case that there exists memory information left
+    test_models = ['repeated_computed_layers', 'resnet18', 'no_leaf_module']
+    for model_name in test_models:
+        get_components_func = non_distributed_component_funcs.get_callable(model_name)
+        model_builder, train_dataloader, _, _, criterion = get_components_func()
+
+        # init model on cpu
+        model = MemtracerWrapper(model_builder())
+
+        for i, (data, label) in enumerate(train_dataloader):
+            if i > 1:
+                break
+            data = data.cuda()
+            label = label.cuda()
+
+            run_fwd_bwd(model, data, label, criterion, False)
+
+    # model._ophook_list[0].print_non_model_data()
+
+
+if __name__ == '__main__':
+    test_tracer()