[fx] add torchaudio test (#1369)

* [fx] add torchaudio test

* [fx] add torchaudio test

* [fx] add torchaudio test

* [fx] add torchaudio test

* [fx] add torchaudio test

* [fx] add torchaudio test

* [fx] add torchaudio test

* [fx] add torchaudio test and test patches

* Delete ~

* [fx] add patches and patches test

* [fx] add patches and patches test

* [fx] fix patches

* [fx] fix rnn patches

* [fx] fix rnn patches

* [fx] fix rnn patches

* [fx] fix rnn patches

* [fx] merge upstream

* [fx] fix import errors
Super Daniel
2022-07-27 11:03:14 +08:00
committed by GitHub
parent fb6f085907
commit be229217ce
18 changed files with 609 additions and 16 deletions


@@ -108,6 +108,27 @@ def torch_cat(tensors, dim=None, axis=None, *, out=None):
    return torch.empty(final_shape, device="meta")


@meta_patched_function.register(torch.repeat_interleave)
def torch_repeat_interleave(input, repeats, dim=None, output_size=None):
    assert isinstance(repeats, int) or isinstance(repeats, torch.Tensor), \
        "Argument 'repeats' should be of type 'torch.Tensor' or 'int'"

    shape = list(input.shape) if dim is not None else [input.numel()]
    dim = dim if dim is not None else 0
    dim = input.dim() + dim if dim < 0 else dim

    if isinstance(repeats, int):
        shape[dim] = shape[dim] * repeats
    elif isinstance(repeats, torch.Tensor):
        shape[dim] = repeats.sum()

    return torch.empty(shape, device="meta")


@meta_patched_function.register(torch.Tensor.repeat_interleave)
def torch_tensor_repeat_interleave(self, repeats, dim=None, *, output_size=None):
    return torch_repeat_interleave(self, repeats, dim, output_size)


@meta_patched_function.register(torch.roll)
def torch_roll(input, shifts, dims=None):
    return torch.empty(input.shape, device='meta')
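
A quick sanity sketch (not part of this commit): the shape arithmetic in torch_repeat_interleave mirrors eager-mode torch.repeat_interleave and can be checked on a small tensor.

    import torch

    x = torch.randn(2, 3)

    # int repeats along an explicit dim: that dim's size is multiplied by repeats
    assert torch.repeat_interleave(x, 2, dim=1).shape == (2, 6)

    # tensor repeats: the dim's new size is repeats.sum()
    repeats = torch.tensor([1, 3])
    assert torch.repeat_interleave(x, repeats, dim=0).shape == (int(repeats.sum()), 3)

    # dim=None flattens first, so the result has input.numel() * repeats elements
    assert torch.repeat_interleave(x, 2).shape == (x.numel() * 2,)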


@@ -3,4 +3,5 @@ from .convolution import *
from .embedding import *
from .linear import *
from .normalization import *
from .pooling import *
from .rnn import *


@@ -7,5 +7,6 @@ from ..registry import meta_patched_module
@meta_patched_module.register(torch.nn.GELU)
@meta_patched_module.register(torch.nn.Tanh)
@meta_patched_module.register(torch.nn.ReLU6)
@meta_patched_module.register(torch.nn.PReLU)
def torch_nn_non_linear_act(self, input):
    return torch.empty(input.shape, device='meta')
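
A side note (not part of this commit): PReLU, like the other activations registered here, is shape-preserving, which is why the patch only needs input.shape.

    import torch

    act = torch.nn.PReLU()
    x = torch.randn(2, 5, 5)
    assert act(x).shape == x.shape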


@@ -55,3 +55,60 @@ def torch_nn_conv3d(self, input):
        w_out,
    )
    return torch.empty(result_shape, device='meta')


@meta_patched_module.register(torch.nn.ConvTranspose1d)
def torch_nn_convtranspose1d(self, input):
    # the output shape is calculated using the formula stated
    # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html
    l_in = input.shape[-1]
    c_out = self.out_channels
    l_out = math.floor((l_in - 1) * self.stride[0] - 2 * self.padding[0] +
                       self.dilation[0] * (self.kernel_size[0] - 1) +
                       self.output_padding[0] + 1)
    result_shape = input.shape[:-2] + (
        c_out,
        l_out,
    )
    return torch.empty(result_shape, device='meta')


@meta_patched_module.register(torch.nn.ConvTranspose2d)
def torch_nn_convtranspose2d(self, input):
    # the output shape is calculated using the formula stated
    # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
    h_in, w_in = input.shape[-2:]
    c_out = self.out_channels
    h_out = math.floor((h_in - 1) * self.stride[0] - 2 * self.padding[0] +
                       self.dilation[0] * (self.kernel_size[0] - 1) +
                       self.output_padding[0] + 1)
    w_out = math.floor((w_in - 1) * self.stride[1] - 2 * self.padding[1] +
                       self.dilation[1] * (self.kernel_size[1] - 1) +
                       self.output_padding[1] + 1)
    result_shape = input.shape[:-3] + (
        c_out,
        h_out,
        w_out,
    )
    return torch.empty(result_shape, device='meta')


@meta_patched_module.register(torch.nn.ConvTranspose3d)
def torch_nn_convtranspose3d(self, input):
    # the output shape is calculated using the formula stated
    # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html
    d_in, h_in, w_in = input.shape[-3:]
    c_out = self.out_channels
    d_out = math.floor((d_in - 1) * self.stride[0] - 2 * self.padding[0] +
                       self.dilation[0] * (self.kernel_size[0] - 1) +
                       self.output_padding[0] + 1)
    h_out = math.floor((h_in - 1) * self.stride[1] - 2 * self.padding[1] +
                       self.dilation[1] * (self.kernel_size[1] - 1) +
                       self.output_padding[1] + 1)
    w_out = math.floor((w_in - 1) * self.stride[2] - 2 * self.padding[2] +
                       self.dilation[2] * (self.kernel_size[2] - 1) +
                       self.output_padding[2] + 1)
    result_shape = input.shape[:-4] + (
        c_out,
        d_out,
        h_out,
        w_out,
    )
    return torch.empty(result_shape, device='meta')
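
The output sizes above follow the formulas in the PyTorch ConvTranspose docs; a rough cross-check against an eager ConvTranspose2d (illustrative only, hyperparameters made up):

    import math
    import torch

    conv = torch.nn.ConvTranspose2d(3, 8, kernel_size=3, stride=2, padding=1, output_padding=1)
    x = torch.randn(4, 3, 16, 16)

    h_in, w_in = x.shape[-2:]
    h_out = math.floor((h_in - 1) * conv.stride[0] - 2 * conv.padding[0] +
                       conv.dilation[0] * (conv.kernel_size[0] - 1) + conv.output_padding[0] + 1)
    w_out = math.floor((w_in - 1) * conv.stride[1] - 2 * conv.padding[1] +
                       conv.dilation[1] * (conv.kernel_size[1] - 1) + conv.output_padding[1] + 1)

    assert conv(x).shape == (4, conv.out_channels, h_out, w_out)    # (4, 8, 32, 32)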


@@ -6,4 +6,4 @@ from ..registry import meta_patched_module
def torch_nn_linear(self, input):
    last_dim = input.shape[-1]
    assert last_dim == self.in_features, f'Expected hidden size {self.in_features} but got {last_dim} for the torch.nn.Linear patch'
    return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")


@@ -0,0 +1,14 @@
import torch
from ..registry import meta_patched_module
from typing import Optional


@meta_patched_module.register(torch.nn.GRU)
@meta_patched_module.register(torch.nn.RNN)
def torch_nn_rnn(self, input, hx):
    assert input.shape[
        -1] == self.input_size, f'Expected input to have input size {self.input_size} but got {input.shape[-1]} for the torch.nn.RNN patch'
    assert hx.shape[
        -1] == self.hidden_size, f'Expected hx to have hidden size {self.hidden_size} but got {hx.shape[-1]} for the torch.nn.RNN patch'
    d = 2 if self.bidirectional else 1
    return torch.empty(input.shape[:-1] + (self.hidden_size * d,), device="meta"), hx
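
A brief eager-mode comparison (not part of this commit; assumes a batch_first input) of the output shape this patch predicts:

    import torch

    rnn = torch.nn.RNN(input_size=10, hidden_size=20, num_layers=2, bidirectional=True, batch_first=True)
    x = torch.randn(4, 7, 10)         # (batch, seq, input_size)
    hx = torch.zeros(2 * 2, 4, 20)    # (num_layers * num_directions, batch, hidden_size)

    out, h_n = rnn(x, hx)
    d = 2 if rnn.bidirectional else 1

    assert out.shape == x.shape[:-1] + (rnn.hidden_size * d,)    # (4, 7, 40)
    assert h_n.shape == hx.shape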


@@ -27,7 +27,7 @@ def save_checkpoint(dire: str,
    # save the dist context about the tensors in a new dict, while still maintaining the original dict.
    for k, v in model_state.items():
        if isinstance(v, ColoTensor):
            gather_tensor(v)    # gather shared tensors to rank0
    # don't recover tensors in rank0, since the dict is only a copy of model
    if rank == 0:


@@ -34,7 +34,7 @@ def gather_tensor(colo_tensor: ColoTensor) -> None:
    dist.barrier()

    if dist.get_rank() == 0:
        setattr(colo_tensor, 'save_ready', True)    # set saving signature


def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
@@ -54,9 +54,8 @@ def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
    if dist.get_rank() == 0:
        colo_tensor.set_dist_spec(dist_spec)
    else:
        rep_tensor = ColoTensor(
            entire_data, ColoTensorSpec(pg=colo_tensor.get_process_group(), compute_attr=colo_tensor.compute_spec))
        rep_tensor.set_dist_spec(dist_spec)
        with torch.no_grad():
            colo_tensor.data.copy_(rep_tensor.data)