[fx] patched torch.max and data movement operator (#1391)

* [fx] patched torch.max and data movement operator

* polish code
This commit is contained in:
Frank Lee
2022-08-01 15:31:50 +08:00
committed by GitHub
parent db89600cf2
commit 7d6293927f
2 changed files with 52 additions and 0 deletions

View File

@@ -61,3 +61,22 @@ def test_repeat_interleave():
patch_fn=repeat_interleave,
expect_exception=True,
output_shape=materialized_output.shape)
def test_torch_max():
data = torch.rand(4, 3)
out = torch.max(data)
patched_out = patched_function.torch_max(data)
assert out.shape == patched_out.shape
data = torch.rand(4, 3, 2)
out, idx = torch.max(data, dim=1)
patched_out, patched_idx = patched_function.torch_max(data, dim=1)
assert out.shape == patched_out.shape
assert idx.shape == patched_idx.shape
data = torch.rand(4, 3, 2)
out, idx = torch.max(data, dim=1, keepdim=True)
patched_out, patched_idx = patched_function.torch_max(data, dim=1, keepdim=True)
assert out.shape == patched_out.shape
assert idx.shape == patched_idx.shape