mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 19:40:28 +00:00
[fx] add torchaudio test (#1369)
* [fx]add torchaudio test * [fx]add torchaudio test * [fx] add torchaudio test * [fx] add torchaudio test * [fx] add torchaudio test * [fx] add torchaudio test * [fx] add torchaudio test * [fx] add torchaudio test and test patches * Delete ~ * [fx] add patches and patches test * [fx] add patches and patches test * [fx] fix patches * [fx] fix rnn patches * [fx] fix rnn patches * [fx] fix rnn patches * [fx] fix rnn patches * [fx] merge upstream * [fx] fix import errors
This commit is contained in:
63
tests/test_fx/test_tracer/test_patched_op.py
Normal file
63
tests/test_fx/test_tracer/test_patched_op.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import torch
|
||||
from colossalai.fx.tracer.meta_patch import patched_function
|
||||
from functools import partial
|
||||
|
||||
|
||||
def _run(data, patch_fn):
|
||||
try:
|
||||
output = patch_fn(data)
|
||||
return output
|
||||
except Exception as e:
|
||||
return e
|
||||
|
||||
|
||||
def _assert_output_shape(data, patch_fn, expect_exception, output_shape):
|
||||
output = _run(data, patch_fn)
|
||||
|
||||
if expect_exception:
|
||||
assert isinstance(output, AssertionError)
|
||||
else:
|
||||
assert not isinstance(output, Exception)
|
||||
assert output.is_meta
|
||||
assert output.shape == output_shape
|
||||
|
||||
|
||||
def test_repeat_interleave():
|
||||
patch_fn = patched_function.torch_repeat_interleave
|
||||
|
||||
# examples from https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html
|
||||
data = torch.tensor([1, 2, 3])
|
||||
materialized_output = torch.repeat_interleave(data, repeats=2)
|
||||
repeat_interleave = partial(patch_fn, repeats=2)
|
||||
meta_data = data.to('meta')
|
||||
_assert_output_shape(data=meta_data,
|
||||
patch_fn=repeat_interleave,
|
||||
expect_exception=False,
|
||||
output_shape=materialized_output.shape)
|
||||
|
||||
data = torch.tensor([[1, 2], [3, 4]])
|
||||
materialized_output = torch.repeat_interleave(data, repeats=3, dim=1)
|
||||
repeat_interleave = partial(patch_fn, repeats=3, dim=1)
|
||||
meta_data = data.to('meta')
|
||||
_assert_output_shape(data=meta_data,
|
||||
patch_fn=repeat_interleave,
|
||||
expect_exception=False,
|
||||
output_shape=materialized_output.shape)
|
||||
|
||||
data = torch.tensor([[1, 2], [3, 4]])
|
||||
materialized_output = torch.repeat_interleave(data, repeats=torch.tensor([1, 2]), dim=-1)
|
||||
repeat_interleave = partial(patch_fn, repeats=torch.tensor([1, 2]), dim=-1)
|
||||
meta_data = data.to('meta')
|
||||
_assert_output_shape(data=meta_data,
|
||||
patch_fn=repeat_interleave,
|
||||
expect_exception=False,
|
||||
output_shape=materialized_output.shape)
|
||||
|
||||
data = torch.tensor([[1, 2], [3, 4]])
|
||||
materialized_output = torch.repeat_interleave(data, repeats=torch.tensor([1, 2]), dim=0)
|
||||
repeat_interleave = partial(patch_fn, repeats=[1, 2], dim=0)
|
||||
meta_data = data.to('meta')
|
||||
_assert_output_shape(data=meta_data,
|
||||
patch_fn=repeat_interleave,
|
||||
expect_exception=True,
|
||||
output_shape=materialized_output.shape)
|
Reference in New Issue
Block a user