mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-10-21 14:49:24 +00:00)
add bert for unit test; sharded model is not yet able to pass the bert case
@@ -17,8 +17,9 @@ from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 
-from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad,
-                           get_gradient_predivide_factor)
+from ._zero3_utils import (cast_tensor_to_fp32, chunk_and_pad, get_gradient_predivide_factor)
+
+# from ._zero3_utils import cast_float_arguments, cast_tensor_to_fp16
 
 
 class ShardedModelV2(nn.Module):
@@ -79,7 +80,8 @@ class ShardedModelV2(nn.Module):
         self._require_backward_grad_sync: bool = True
 
     def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
-        args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
+        # TODO args can be Long!
+        # args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
         outputs = self.module(*args, **kwargs)
         return outputs
 