add bert unit test; sharded model is not yet able to pass the bert case

jiaruifang
2022-03-09 10:39:02 +08:00
committed by Frank Lee
parent 3d5d64bd10
commit 7977422aeb
6 changed files with 104 additions and 14 deletions


@@ -17,8 +17,9 @@ from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
-from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad,
-                           get_gradient_predivide_factor)
+from ._zero3_utils import (cast_tensor_to_fp32, chunk_and_pad, get_gradient_predivide_factor)
+# from ._zero3_utils import cast_float_arguments, cast_tensor_to_fp16
 
 class ShardedModelV2(nn.Module):
@@ -79,7 +80,8 @@ class ShardedModelV2(nn.Module):
         self._require_backward_grad_sync: bool = True
 
     def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
-        args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
+        # TODO: args can be Long!
+        # args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
         outputs = self.module(*args, **kwargs)
         return outputs
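
The commented-out cast is the reason the BERT case fails: BERT forwards integer tensors (e.g. input_ids of dtype torch.long), and applying cast_tensor_to_fp16 to every argument blindly corrupts them. A minimal sketch of a dtype-aware alternative follows; the helper names here (cast_fp32_tensor_to_fp16 and the re-implemented cast_float_arguments) are illustrative, not ColossalAI's actual API:

import torch
from typing import Any, Callable, Dict, Tuple

def cast_fp32_tensor_to_fp16(tensor: torch.Tensor) -> torch.Tensor:
    # Only cast floating-point tensors; torch.long token ids pass through unchanged.
    if tensor.is_floating_point():
        return tensor.half()
    return tensor

def cast_float_arguments(fn: Callable[[torch.Tensor], torch.Tensor],
                         *args: Any, **kwargs: Any) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    # Apply fn to every tensor in args/kwargs; non-tensor values pass through as-is.
    new_args = tuple(fn(a) if isinstance(a, torch.Tensor) else a for a in args)
    new_kwargs = {k: (fn(v) if isinstance(v, torch.Tensor) else v) for k, v in kwargs.items()}
    return new_args, new_kwargs

With a dtype check like this, forward could keep the cast enabled rather than disabling it, since BERT's Long inputs would no longer be converted to fp16.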