[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -1,20 +1,21 @@
import math
from abc import ABC
from typing import Any, Iterable, List
from typing import List
from torch.utils._pytree import tree_map
class Chain:
def __init__(self,
ftime: List[float],
btime: List[float],
x: List[int],
xbar: List[int],
ftmp: List[int],
btmp: List[int],
check_consistency: bool = True):
def __init__(
self,
ftime: List[float],
btime: List[float],
x: List[int],
xbar: List[int],
ftmp: List[int],
btmp: List[int],
check_consistency: bool = True,
):
"""The chain is a basic linearized structure for solving the dynamic programming problem for activation checkpoint.
See paper https://hal.inria.fr/hal-02352969 for details.
@@ -37,9 +38,14 @@ class Chain:
raise AttributeError("In Chain, input lists do not have consistent lengths")
def check_lengths(self):
return ((len(self.ftime) == len(self)) and (len(self.btime) == len(self) + 1) and (len(self.x) == len(self) + 1)
and (len(self.ftmp) == len(self)) and (len(self.btmp) == len(self) + 1)
and (len(self.xbar) == len(self) + 1))
return (
(len(self.ftime) == len(self))
and (len(self.btime) == len(self) + 1)
and (len(self.x) == len(self) + 1)
and (len(self.ftmp) == len(self))
and (len(self.btmp) == len(self) + 1)
and (len(self.xbar) == len(self) + 1)
)
def __repr__(self):
chain_list = []
@@ -100,7 +106,6 @@ class ForwardCheck(Forward):
class Forwards(Operation):
def __init__(self, start, end):
self.index = (start, end)
@@ -109,9 +114,9 @@ class Forwards(Operation):
def cost(self, chain: Chain):
if chain is not None:
return sum(chain.ftime[self.index[0]:self.index[1] + 1])
return sum(chain.ftime[self.index[0] : self.index[1] + 1])
else:
return (self.index[1] - self.index[0] + 1)
return self.index[1] - self.index[0] + 1
def isForward(op):
@@ -132,7 +137,6 @@ class Backward(Operation):
class Loss(Operation):
def __init__(self):
pass
@@ -166,7 +170,6 @@ class DiscardMemory(MemoryAccess):
class Sequence(list):
def __init__(self):
super().__init__()