mirror of https://github.com/hpcaitech/ColossalAI.git
polish code

parent d271f2596b
commit 5663616921
@@ -18,8 +18,7 @@ from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 
 from ._zero3_utils import (cast_tensor_to_fp32, chunk_and_pad, get_gradient_predivide_factor)
-from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16)
+# from ._zero3_utils import cast_float_arguments, cast_tensor_to_fp16
 
 
 class ShardedModelV2(nn.Module):
@@ -80,8 +79,7 @@ class ShardedModelV2(nn.Module):
         self._require_backward_grad_sync: bool = True
 
     def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
-        # TODO args can be Long!
-        args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
+        # args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
         outputs = self.module(*args, **kwargs)
         return outputs
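Note on the disabled call: the line commented out above cast the user-supplied forward arguments to fp16 before dispatching to the wrapped module, and the removed "TODO args can be Long!" comment points at the pitfall that integer tensors (e.g. token ids) must not be converted. The sketch below is only an illustration of what such helpers could look like; the real cast_float_arguments and cast_tensor_to_fp16 live in _zero3_utils and may differ in signature and behavior.

import torch
from typing import Any, Callable, Dict, Tuple

def cast_tensor_to_fp16(tensor: torch.Tensor) -> torch.Tensor:
    # Convert only floating-point tensors to half precision; integer (Long)
    # tensors such as token ids keep their dtype untouched.
    if tensor.is_floating_point():
        return tensor.half()
    return tensor

def cast_float_arguments(fn: Callable, *args: Any, **kwargs: Any) -> Tuple[Tuple, Dict]:
    # Apply `fn` to every tensor argument, positional or keyword, and pass
    # every non-tensor argument through unchanged.
    new_args = tuple(fn(a) if isinstance(a, torch.Tensor) else a for a in args)
    new_kwargs = {k: (fn(v) if isinstance(v, torch.Tensor) else v) for k, v in kwargs.items()}
    return new_args, new_kwargs

# Hypothetical usage, mirroring the commented-out line in forward():
# args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)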