[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -25,13 +25,14 @@ class StatefulTensor(object):
https://arxiv.org/abs/2108.05818
"""
# Global Stateful Tensor Manager
GST_MGR = GeminiMemoryManager(TensorState)
def __init__(self, maybe_tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
self._state = state
self._payload = None
self._payload_size = 0 # byte size of current payload
self._payload_size = 0 # byte size of current payload
StatefulTensor.GST_MGR.register_new_instance()
@@ -47,7 +48,7 @@ class StatefulTensor(object):
def data_ptr(self):
if self._payload is None:
return 0 # if a tensor has no storage, 0 should be returned
return 0 # if a tensor has no storage, 0 should be returned
return self._payload.data_ptr()
def set_null(self) -> None:
@@ -80,7 +81,7 @@ class StatefulTensor(object):
assert self.state is not TensorState.FREE, "Can't move free stateful tensor"
if not isinstance(device, torch.device):
to_device = torch.device('cuda', device)
to_device = torch.device("cuda", device)
else:
to_device = device
@@ -97,7 +98,6 @@ class StatefulTensor(object):
self._payload.view(-1).copy_(tensor.view(-1))
def payload_reset(self, tensor) -> None:
assert tensor is not None, "Can't reset None for stateful tensors, please use set_null() instead"
if self.payload is not None:
@@ -168,8 +168,7 @@ class StatefulTensor(object):
self._payload_size = 0
def __trans_state_update(self, from_state: TensorState, to_state: TensorState):
"""Update global manager when changing the state of a tensor
"""
"""Update global manager when changing the state of a tensor"""
manager = StatefulTensor.GST_MGR
size = self.payload_size
device_type = self.device.type
@@ -189,8 +188,7 @@ class StatefulTensor(object):
manager.total_mem[device_type] -= size
def __trans_device_update(self, from_type: str, to_type: str):
"""Update global manager when changing the device of a tensor
"""
"""Update global manager when changing the device of a tensor"""
manager = StatefulTensor.GST_MGR
size = self.payload_size
state = self.state