mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-17 23:18:36 +00:00
[fx] patched torch.max and data movement operator (#1391)
* [fx] patched torch.max and data movement operator * polish code
This commit is contained in:
@@ -61,3 +61,22 @@ def test_repeat_interleave():
|
||||
patch_fn=repeat_interleave,
|
||||
expect_exception=True,
|
||||
output_shape=materialized_output.shape)
|
||||
|
||||
|
||||
def test_torch_max():
|
||||
data = torch.rand(4, 3)
|
||||
out = torch.max(data)
|
||||
patched_out = patched_function.torch_max(data)
|
||||
assert out.shape == patched_out.shape
|
||||
|
||||
data = torch.rand(4, 3, 2)
|
||||
out, idx = torch.max(data, dim=1)
|
||||
patched_out, patched_idx = patched_function.torch_max(data, dim=1)
|
||||
assert out.shape == patched_out.shape
|
||||
assert idx.shape == patched_idx.shape
|
||||
|
||||
data = torch.rand(4, 3, 2)
|
||||
out, idx = torch.max(data, dim=1, keepdim=True)
|
||||
patched_out, patched_idx = patched_function.torch_max(data, dim=1, keepdim=True)
|
||||
assert out.shape == patched_out.shape
|
||||
assert idx.shape == patched_idx.shape
|
||||
|
Reference in New Issue
Block a user