mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 21:09:18 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,20 +1,21 @@
|
||||
import math
|
||||
from abc import ABC
|
||||
from typing import Any, Iterable, List
|
||||
from typing import List
|
||||
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
|
||||
class Chain:
|
||||
|
||||
def __init__(self,
|
||||
ftime: List[float],
|
||||
btime: List[float],
|
||||
x: List[int],
|
||||
xbar: List[int],
|
||||
ftmp: List[int],
|
||||
btmp: List[int],
|
||||
check_consistency: bool = True):
|
||||
def __init__(
|
||||
self,
|
||||
ftime: List[float],
|
||||
btime: List[float],
|
||||
x: List[int],
|
||||
xbar: List[int],
|
||||
ftmp: List[int],
|
||||
btmp: List[int],
|
||||
check_consistency: bool = True,
|
||||
):
|
||||
"""The chain is a basic linearized structure for solving the dynamic programming problem for activation checkpoint.
|
||||
See paper https://hal.inria.fr/hal-02352969 for details.
|
||||
|
||||
@@ -37,9 +38,14 @@ class Chain:
|
||||
raise AttributeError("In Chain, input lists do not have consistent lengths")
|
||||
|
||||
def check_lengths(self):
|
||||
return ((len(self.ftime) == len(self)) and (len(self.btime) == len(self) + 1) and (len(self.x) == len(self) + 1)
|
||||
and (len(self.ftmp) == len(self)) and (len(self.btmp) == len(self) + 1)
|
||||
and (len(self.xbar) == len(self) + 1))
|
||||
return (
|
||||
(len(self.ftime) == len(self))
|
||||
and (len(self.btime) == len(self) + 1)
|
||||
and (len(self.x) == len(self) + 1)
|
||||
and (len(self.ftmp) == len(self))
|
||||
and (len(self.btmp) == len(self) + 1)
|
||||
and (len(self.xbar) == len(self) + 1)
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
chain_list = []
|
||||
@@ -100,7 +106,6 @@ class ForwardCheck(Forward):
|
||||
|
||||
|
||||
class Forwards(Operation):
|
||||
|
||||
def __init__(self, start, end):
|
||||
self.index = (start, end)
|
||||
|
||||
@@ -109,9 +114,9 @@ class Forwards(Operation):
|
||||
|
||||
def cost(self, chain: Chain):
|
||||
if chain is not None:
|
||||
return sum(chain.ftime[self.index[0]:self.index[1] + 1])
|
||||
return sum(chain.ftime[self.index[0] : self.index[1] + 1])
|
||||
else:
|
||||
return (self.index[1] - self.index[0] + 1)
|
||||
return self.index[1] - self.index[0] + 1
|
||||
|
||||
|
||||
def isForward(op):
|
||||
@@ -132,7 +137,6 @@ class Backward(Operation):
|
||||
|
||||
|
||||
class Loss(Operation):
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@@ -166,7 +170,6 @@ class DiscardMemory(MemoryAccess):
|
||||
|
||||
|
||||
class Sequence(list):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
Reference in New Issue
Block a user