mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 03:20:52 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -25,13 +25,14 @@ class StatefulTensor(object):
|
||||
|
||||
https://arxiv.org/abs/2108.05818
|
||||
"""
|
||||
|
||||
# Global Stateful Tensor Manager
|
||||
GST_MGR = GeminiMemoryManager(TensorState)
|
||||
|
||||
def __init__(self, maybe_tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
|
||||
self._state = state
|
||||
self._payload = None
|
||||
self._payload_size = 0 # byte size of current payload
|
||||
self._payload_size = 0 # byte size of current payload
|
||||
|
||||
StatefulTensor.GST_MGR.register_new_instance()
|
||||
|
||||
@@ -47,7 +48,7 @@ class StatefulTensor(object):
|
||||
|
||||
def data_ptr(self):
|
||||
if self._payload is None:
|
||||
return 0 # if a tensor has no storage, 0 should be returned
|
||||
return 0 # if a tensor has no storage, 0 should be returned
|
||||
return self._payload.data_ptr()
|
||||
|
||||
def set_null(self) -> None:
|
||||
@@ -80,7 +81,7 @@ class StatefulTensor(object):
|
||||
assert self.state is not TensorState.FREE, "Can't move free stateful tensor"
|
||||
|
||||
if not isinstance(device, torch.device):
|
||||
to_device = torch.device('cuda', device)
|
||||
to_device = torch.device("cuda", device)
|
||||
else:
|
||||
to_device = device
|
||||
|
||||
@@ -97,7 +98,6 @@ class StatefulTensor(object):
|
||||
self._payload.view(-1).copy_(tensor.view(-1))
|
||||
|
||||
def payload_reset(self, tensor) -> None:
|
||||
|
||||
assert tensor is not None, "Can't reset None for stateful tensors, please use set_null() instead"
|
||||
|
||||
if self.payload is not None:
|
||||
@@ -168,8 +168,7 @@ class StatefulTensor(object):
|
||||
self._payload_size = 0
|
||||
|
||||
def __trans_state_update(self, from_state: TensorState, to_state: TensorState):
|
||||
"""Update global manager when changing the state of a tensor
|
||||
"""
|
||||
"""Update global manager when changing the state of a tensor"""
|
||||
manager = StatefulTensor.GST_MGR
|
||||
size = self.payload_size
|
||||
device_type = self.device.type
|
||||
@@ -189,8 +188,7 @@ class StatefulTensor(object):
|
||||
manager.total_mem[device_type] -= size
|
||||
|
||||
def __trans_device_update(self, from_type: str, to_type: str):
|
||||
"""Update global manager when changing the device of a tensor
|
||||
"""
|
||||
"""Update global manager when changing the device of a tensor"""
|
||||
manager = StatefulTensor.GST_MGR
|
||||
size = self.payload_size
|
||||
state = self.state
|
||||
|
Reference in New Issue
Block a user