mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-20 17:10:03 +00:00
[fx] add torchaudio test (#1369)
* [fx]add torchaudio test * [fx]add torchaudio test * [fx] add torchaudio test * [fx] add torchaudio test * [fx] add torchaudio test * [fx] add torchaudio test * [fx] add torchaudio test * [fx] add torchaudio test and test patches * Delete ~ * [fx] add patches and patches test * [fx] add patches and patches test * [fx] fix patches * [fx] fix rnn patches * [fx] fix rnn patches * [fx] fix rnn patches * [fx] fix rnn patches * [fx] merge upstream * [fx] fix import errors
This commit is contained in:
@@ -108,6 +108,27 @@ def torch_cat(tensors, dim=None, axis=None, *, out=None):
|
||||
return torch.empty(final_shape, device="meta")
|
||||
|
||||
|
||||
@meta_patched_function.register(torch.repeat_interleave)
|
||||
def torch_repeat_interleave(input, repeats, dim=None, output_size=None):
|
||||
assert isinstance(repeats, int) or isinstance(repeats, torch.Tensor), \
|
||||
"Argument 'repeats' should be of type 'torch.Tensor' or 'int'"
|
||||
|
||||
shape = list(input.shape) if dim is not None else [input.numel()]
|
||||
dim = dim if dim is not None else 0
|
||||
dim = input.dim() + dim if dim < 0 else dim
|
||||
|
||||
if isinstance(repeats, int):
|
||||
shape[dim] = shape[dim] * repeats
|
||||
elif isinstance(repeats, torch.Tensor):
|
||||
shape[dim] = repeats.sum()
|
||||
return torch.empty(shape, device="meta")
|
||||
|
||||
|
||||
@meta_patched_function.register(torch.Tensor.repeat_interleave)
|
||||
def torch_tensor_repeat_interleave(self, repeats, dim=None, *, output_size=None):
|
||||
return torch_repeat_interleave(self, repeats, dim, output_size)
|
||||
|
||||
|
||||
@meta_patched_function.register(torch.roll)
|
||||
def torch_roll(input, shifts, dims=None):
|
||||
return torch.empty(input.shape, device='meta')
|
||||
|
Reference in New Issue
Block a user