Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-02 09:38:05 +00:00
[fx] add torchaudio test (#1369)
* [fx]add torchaudio test
* [fx]add torchaudio test
* [fx] add torchaudio test
* [fx] add torchaudio test
* [fx] add torchaudio test
* [fx] add torchaudio test
* [fx] add torchaudio test
* [fx] add torchaudio test and test patches
* Delete ~
* [fx] add patches and patches test
* [fx] add patches and patches test
* [fx] fix patches
* [fx] fix rnn patches
* [fx] fix rnn patches
* [fx] fix rnn patches
* [fx] fix rnn patches
* [fx] merge upstream
* [fx] fix import errors
@@ -108,6 +108,27 @@ def torch_cat(tensors, dim=None, axis=None, *, out=None):
    return torch.empty(final_shape, device="meta")


@meta_patched_function.register(torch.repeat_interleave)
def torch_repeat_interleave(input, repeats, dim=None, output_size=None):
    assert isinstance(repeats, int) or isinstance(repeats, torch.Tensor), \
        "Argument 'repeats' should be of type 'torch.Tensor' or 'int'"

    shape = list(input.shape) if dim is not None else [input.numel()]
    dim = dim if dim is not None else 0
    dim = input.dim() + dim if dim < 0 else dim

    if isinstance(repeats, int):
        shape[dim] = shape[dim] * repeats
    elif isinstance(repeats, torch.Tensor):
        shape[dim] = repeats.sum()
    return torch.empty(shape, device="meta")


@meta_patched_function.register(torch.Tensor.repeat_interleave)
def torch_tensor_repeat_interleave(self, repeats, dim=None, *, output_size=None):
    return torch_repeat_interleave(self, repeats, dim, output_size)


@meta_patched_function.register(torch.roll)
def torch_roll(input, shifts, dims=None):
    return torch.empty(input.shape, device='meta')
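For context, a minimal sketch (not part of this commit) of how the new repeat_interleave patch infers output shapes on meta tensors, assuming torch_repeat_interleave defined above is in scope:

    import torch

    x = torch.empty(2, 3, device='meta')
    out = torch_repeat_interleave(x, repeats=4, dim=1)
    print(out.shape)  # torch.Size([2, 12]); shape inferred without allocating real memory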
@@ -3,4 +3,5 @@ from .convolution import *
from .embedding import *
from .linear import *
from .normalization import *
from .pooling import *
from .rnn import *
@@ -7,5 +7,6 @@ from ..registry import meta_patched_module
@meta_patched_module.register(torch.nn.GELU)
@meta_patched_module.register(torch.nn.Tanh)
@meta_patched_module.register(torch.nn.ReLU6)
@meta_patched_module.register(torch.nn.PReLU)
def torch_nn_non_linear_act(self, input):
    return torch.empty(input.shape, device='meta')
@@ -55,3 +55,60 @@ def torch_nn_conv3d(self, input):
        w_out,
    )
    return torch.empty(result_shape, device='meta')


@meta_patched_module.register(torch.nn.ConvTranspose1d)
def torch_nn_convtranspose1d(self, input):
    # the output shape is calculated using the formula stated
    # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html
    l_in = input.shape[-1]
    c_out = self.out_channels
    l_out = math.floor((l_in - 1) * self.stride[0] - 2 * self.padding[0] +
                       self.dilation[0] * (self.kernel_size[0] - 1) +
                       self.output_padding[0] + 1)
    result_shape = input.shape[:-2] + (
        c_out,
        l_out,
    )
    return torch.empty(result_shape, device='meta')


@meta_patched_module.register(torch.nn.ConvTranspose2d)
def torch_nn_convtranspose2d(self, input):
    # the output shape is calculated using the formula stated
    # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
    h_in, w_in = input.shape[-2:]
    c_out = self.out_channels
    h_out = math.floor((h_in - 1) * self.stride[0] - 2 * self.padding[0] +
                       self.dilation[0] * (self.kernel_size[0] - 1) +
                       self.output_padding[0] + 1)
    w_out = math.floor((w_in - 1) * self.stride[1] - 2 * self.padding[1] +
                       self.dilation[1] * (self.kernel_size[1] - 1) +
                       self.output_padding[1] + 1)
    result_shape = input.shape[:-3] + (
        c_out,
        h_out,
        w_out,
    )
    return torch.empty(result_shape, device='meta')


@meta_patched_module.register(torch.nn.ConvTranspose3d)
def torch_nn_convtranspose3d(self, input):
    # the output shape is calculated using the formula stated
    # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html
    d_in, h_in, w_in = input.shape[-3:]
    c_out = self.out_channels
    d_out = math.floor((d_in - 1) * self.stride[0] - 2 * self.padding[0] +
                       self.dilation[0] * (self.kernel_size[0] - 1) +
                       self.output_padding[0] + 1)
    h_out = math.floor((h_in - 1) * self.stride[1] - 2 * self.padding[1] +
                       self.dilation[1] * (self.kernel_size[1] - 1) +
                       self.output_padding[1] + 1)
    w_out = math.floor((w_in - 1) * self.stride[2] - 2 * self.padding[2] +
                       self.dilation[2] * (self.kernel_size[2] - 1) +
                       self.output_padding[2] + 1)
    result_shape = input.shape[:-4] + (
        c_out,
        d_out,
        h_out,
        w_out,
    )
    return torch.empty(result_shape, device='meta')
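A quick sanity check (illustrative only, not part of the commit): the transposed-convolution formula used above should agree with eager PyTorch for a concrete ConvTranspose1d configuration.

    import math
    import torch

    conv = torch.nn.ConvTranspose1d(4, 8, kernel_size=3, stride=2, padding=1, output_padding=1)
    x = torch.randn(1, 4, 10)
    # L_out = (L_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1
    l_out = math.floor((10 - 1) * 2 - 2 * 1 + 1 * (3 - 1) + 1 + 1)  # = 20
    assert conv(x).shape == (1, 8, l_out)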
@@ -6,4 +6,4 @@ from ..registry import meta_patched_module
def torch_nn_linear(self, input):
    last_dim = input.shape[-1]
    assert last_dim == self.in_features, f'Expected hidden size {self.in_features} but got {last_dim} for the torch.nn.Linear patch'
    return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")
colossalai/fx/tracer/meta_patch/patched_module/rnn.py (new file, 14 lines)
@@ -0,0 +1,14 @@
import torch
from ..registry import meta_patched_module
from typing import Optional


@meta_patched_module.register(torch.nn.GRU)
@meta_patched_module.register(torch.nn.RNN)
def torch_nn_rnn(self, input, hx):
    assert input.shape[
        -1] == self.input_size, f'Expected input to have input size {self.input_size} but got {input.shape[-1]} for the torch.nn.RNN patch'
    assert hx.shape[
        -1] == self.hidden_size, f'Expected hx to have hidden size {self.hidden_size} but got {hx.shape[-1]} for the torch.nn.RNN patch'
    d = 2 if self.bidirectional else 1
    return torch.empty(input.shape[:-1] + (self.hidden_size * d,), device="meta"), hx
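A rough usage sketch (illustrative only) of the shape this patch produces for a bidirectional GRU, assuming batch-first input of shape (batch, seq, input_size) and hx of shape (num_layers * 2, batch, hidden_size):

    import torch

    gru = torch.nn.GRU(input_size=16, hidden_size=32, bidirectional=True, batch_first=True)
    inp = torch.empty(4, 10, 16, device='meta')
    hx = torch.empty(2, 4, 32, device='meta')
    out, hx_out = torch_nn_rnn(gru, inp, hx)
    print(out.shape)  # torch.Size([4, 10, 64]); last dim is hidden_size * 2 for the bidirectional case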
@@ -27,7 +27,7 @@ def save_checkpoint(dire: str,
    # save the dist context about the tensors in a new dict, while still maintaining the original dict.
    for k, v in model_state.items():
        if isinstance(v, ColoTensor):
            gather_tensor(v)    # gather shared tensors to rank0
            # don't recover tensors in rank0, since the dict is only a copy of model

    if rank == 0:
@@ -34,7 +34,7 @@ def gather_tensor(colo_tensor: ColoTensor) -> None:
    dist.barrier()

    if dist.get_rank() == 0:
        setattr(colo_tensor, 'save_ready', True)    # set saving signature


def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
@@ -54,9 +54,8 @@ def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
    if dist.get_rank() == 0:
        colo_tensor.set_dist_spec(dist_spec)
    else:
        rep_tensor = ColoTensor(
            entire_data, ColoTensorSpec(pg=colo_tensor.get_process_group(), compute_attr=colo_tensor.compute_spec))
        rep_tensor.set_dist_spec(dist_spec)
        with torch.no_grad():
            colo_tensor.data.copy_(rep_tensor.data)
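For orientation, a generic sketch of the gather-to-rank0 pattern these checkpoint helpers rely on, written with plain torch.distributed rather than ColossalAI's API (gather_to_rank0 is a hypothetical helper; assumes an initialized process group and sharding along dim 0):

    import torch
    import torch.distributed as dist

    def gather_to_rank0(shard: torch.Tensor) -> torch.Tensor:
        world_size = dist.get_world_size()
        if dist.get_rank() == 0:
            # rank 0 collects every rank's shard and assembles the full tensor
            shards = [torch.empty_like(shard) for _ in range(world_size)]
            dist.gather(shard, gather_list=shards, dst=0)
            return torch.cat(shards, dim=0)
        # non-zero ranks only contribute their shard and keep the local view
        dist.gather(shard, dst=0)
        return shard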