@meta_patched_function.register(torch.roll)
def torch_roll(input, shifts, dims=None):
    """Meta-device stand-in for ``torch.roll``.

    Only the shape matters to the tracer, and rolling never changes it, so
    ``shifts`` and ``dims`` are accepted for signature compatibility but unused.

    Returns:
        An uninitialised meta tensor with the same shape as ``input``.
    """
    return torch.empty(input.shape, device='meta')


def _adaptive_pool_output_shape(input_shape, output_size, num_spatial_dims):
    """Compute the output shape of an adaptive pooling layer.

    Args:
        input_shape: shape of the tensor fed to the pooling layer.
        output_size: the layer's ``output_size``; either a single value that is
            broadcast to every pooled dimension, or a tuple/list holding one
            value per pooled dimension. A ``None`` entry keeps the input size
            of the corresponding dimension.
        num_spatial_dims: number of trailing dimensions that are pooled.

    Returns:
        The output shape as a plain tuple.

    Raises:
        ValueError: if a tuple/list ``output_size`` does not have exactly
            ``num_spatial_dims`` entries.
    """
    if isinstance(output_size, (tuple, list)):
        target_sizes = tuple(output_size)
    else:
        target_sizes = (output_size,) * num_spatial_dims

    if len(target_sizes) != num_spatial_dims:
        raise ValueError(f"expected output_size with {num_spatial_dims} element(s), got {output_size!r}")

    leading_dims = tuple(input_shape[:-num_spatial_dims])
    spatial_dims = tuple(input_shape[-num_spatial_dims:])
    # ``None`` in output_size means "same size as the input" for that dimension
    pooled_dims = tuple(
        in_dim if out_dim is None else out_dim for in_dim, out_dim in zip(spatial_dims, target_sizes))
    return leading_dims + pooled_dims


def _adaptive_pool_meta(module, input, num_spatial_dims):
    """Build the meta output of an adaptive pooling ``module`` applied to ``input``."""
    result_shape = _adaptive_pool_output_shape(input.shape, module.output_size, num_spatial_dims)
    output = torch.empty(result_shape, device='meta')
    # Only the max-pool variants define ``return_indices``; when set, the real
    # module returns an (output, indices) pair instead of a single tensor.
    if getattr(module, 'return_indices', False):
        return output, torch.empty(result_shape, dtype=torch.long, device='meta')
    return output


@meta_patched_module.register(torch.nn.AdaptiveAvgPool1d)
@meta_patched_module.register(torch.nn.AdaptiveMaxPool1d)
def torch_nn_adapative_pooling_1d(self, input):
    """Meta patch for 1D adaptive average/max pooling: the last dimension is pooled."""
    return _adaptive_pool_meta(self, input, 1)


@meta_patched_module.register(torch.nn.AdaptiveAvgPool2d)
@meta_patched_module.register(torch.nn.AdaptiveMaxPool2d)
def torch_nn_adapative_pooling_2d(self, input):
    """Meta patch for 2D adaptive average/max pooling: the last two dimensions are pooled."""
    return _adaptive_pool_meta(self, input, 2)


@meta_patched_module.register(torch.nn.AdaptiveAvgPool3d)
@meta_patched_module.register(torch.nn.AdaptiveMaxPool3d)
def torch_nn_adapative_pooling_3d(self, input):
    """Meta patch for 3D adaptive average/max pooling: the last three dimensions are pooled."""
    return _adaptive_pool_meta(self, input, 3)
@meta_patched_module.register(torch.nn.ReLU6) def torch_nn_func_relu(self, input): diff --git a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py index b69900d0b..7df0b2e6c 100644 --- a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py +++ b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py @@ -63,14 +63,8 @@ def test_timm_models_with_control_flow(): torch.backends.cudnn.deterministic = True MODEL_LIST_WITH_CONTROL_FLOW = [ - tm.convnext.convnext_base, - tm.vgg.vgg11, - tm.dpn.dpn68, - tm.densenet.densenet121, - tm.rexnet.rexnet_100, - - # not traceable - # tm.swin_transformer.swin_base_patch4_window7_224 + tm.convnext.convnext_base, tm.vgg.vgg11, tm.dpn.dpn68, tm.densenet.densenet121, tm.rexnet.rexnet_100, + tm.swin_transformer.swin_base_patch4_window7_224 ] tracer = ColoTracer()