[fx] add torchaudio test (#1369)

* [fx]add torchaudio test

* [fx]add torchaudio test

* [fx] add torchaudio test

* [fx] add torchaudio test

* [fx] add torchaudio test

* [fx] add torchaudio test

* [fx] add torchaudio test

* [fx] add torchaudio test and test patches

* Delete ~

* [fx] add patches and patches test

* [fx] add patches and patches test

* [fx] fix patches

* [fx] fix rnn patches

* [fx] fix rnn patches

* [fx] fix rnn patches

* [fx] fix rnn patches

* [fx] merge upstream

* [fx] fix import errors
This commit is contained in:
Super Daniel
2022-07-27 11:03:14 +08:00
committed by GitHub
parent fb6f085907
commit be229217ce
18 changed files with 609 additions and 16 deletions

View File

@@ -108,6 +108,27 @@ def torch_cat(tensors, dim=None, axis=None, *, out=None):
return torch.empty(final_shape, device="meta")
@meta_patched_function.register(torch.repeat_interleave)
def torch_repeat_interleave(input, repeats, dim=None, output_size=None):
assert isinstance(repeats, int) or isinstance(repeats, torch.Tensor), \
"Argument 'repeats' should be of type 'torch.Tensor' or 'int'"
shape = list(input.shape) if dim is not None else [input.numel()]
dim = dim if dim is not None else 0
dim = input.dim() + dim if dim < 0 else dim
if isinstance(repeats, int):
shape[dim] = shape[dim] * repeats
elif isinstance(repeats, torch.Tensor):
shape[dim] = repeats.sum()
return torch.empty(shape, device="meta")
@meta_patched_function.register(torch.Tensor.repeat_interleave)
def torch_tensor_repeat_interleave(self, repeats, dim=None, *, output_size=None):
return torch_repeat_interleave(self, repeats, dim, output_size)
@meta_patched_function.register(torch.roll)
def torch_roll(input, shifts, dims=None):
return torch.empty(input.shape, device='meta')