Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-12 20:54:35 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -1,4 +1,4 @@
 from .sharded_param import ShardedParamV2
 from .sharded_tensor import ShardedTensor
 
-__all__ = ['ShardedTensor', 'ShardedParamV2']
+__all__ = ["ShardedTensor", "ShardedParamV2"]
@@ -19,7 +19,6 @@ def get_empty_tensor(device: torch.device, dtype: torch.dtype):
 
 
 class ShardedParamV2(object):
-
     def __init__(self, param: torch.nn.Parameter, set_data_none: bool = False) -> None:
         self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data)
         self.saved_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
@@ -36,8 +35,7 @@ class ShardedParamV2(object):
         self.set_data_none()
 
     def get_payload_tensors(self) -> List[StatefulTensor]:
-        """returns stateful tensors kept by this class.
-        """
+        """returns stateful tensors kept by this class."""
         return [self._sharded_data_tensor]
 
     def set_data_none(self):
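For context, a minimal usage sketch of the ShardedParamV2 API visible in the hunks above. The import path is an assumption inferred from the legacy package layout in this diff; only the constructor, get_payload_tensors(), and set_data_none() appear in the changed lines.

import torch

# Assumed import path, inferred from this diff's package layout.
from colossalai.legacy.zero.sharded_param import ShardedParamV2

param = torch.nn.Parameter(torch.randn(4, 4))
sp = ShardedParamV2(param, set_data_none=False)

# Per the diff, get_payload_tensors() returns the stateful tensors kept by
# the wrapper: a list containing the sharded tensor built from param.data.
payload = sp.get_payload_tensors()

# set_data_none() appears above as a method; presumably it releases
# param.data so only the sharded copy keeps the storage.
sp.set_data_none()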
@@ -4,7 +4,6 @@ from colossalai.legacy.zero.gemini.stateful_tensor import StatefulTensor, TensorState
 
 
 class ShardedTensor(StatefulTensor):
-
     def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD) -> None:
         r"""
        A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance.
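A corresponding sketch for ShardedTensor, assuming the import paths suggested by the hunk header above and the package's __all__ in the first hunk:

import torch

# Import paths are assumptions based on this diff, not confirmed API docs.
from colossalai.legacy.zero.gemini.stateful_tensor import TensorState
from colossalai.legacy.zero.sharded_param import ShardedTensor

# Construct from an existing torch.Tensor; TensorState.HOLD is the default
# state in the __init__ signature shown above.
st = ShardedTensor(torch.zeros(8), state=TensorState.HOLD)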