Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-03 18:19:58 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -1,6 +1,5 @@
 from dataclasses import dataclass, field
 from enum import Enum
 from functools import partial
 from typing import Dict, List

 from torch.fx import Graph, Node
@@ -69,8 +68,8 @@ class GraphInfo:


 def is_phase(n: Node, phase: Phase) -> bool:
-    assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!'
-    return n.meta['phase'] == phase
+    assert "phase" in n.meta, f"Node meta of {n} has no key `phase`!"
+    return n.meta["phase"] == phase


 @compatibility(is_backward_compatible=False)
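For context on the helper reformatted above: `is_phase` simply reads a `phase` tag that an earlier annotation pass is assumed to have written into each `torch.fx` node's `meta` dict. A minimal sketch of that pattern (the `Phase` enum values and the tagging loop below are illustrative, not the repository's own annotation code):

from enum import Enum, auto

import torch
from torch.fx import Node, symbolic_trace


class Phase(Enum):
    # assumed to mirror the Phase enum used by the profiler
    FORWARD = auto()
    BACKWARD = auto()
    PLACEHOLDER = auto()


def is_phase(n: Node, phase: Phase) -> bool:
    assert "phase" in n.meta, f"Node meta of {n} has no key `phase`!"
    return n.meta["phase"] == phase


gm = symbolic_trace(torch.nn.Linear(4, 4))
for node in gm.graph.nodes:
    # placeholders are the graph inputs; everything else is treated as forward here
    node.meta["phase"] = Phase.PLACEHOLDER if node.op == "placeholder" else Phase.FORWARD

forward_nodes = [n for n in gm.graph.nodes if is_phase(n, Phase.FORWARD)]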
@@ -103,9 +102,9 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
         peak_mem = 0
         for k, v in deps.items():
             if v > 0 and is_phase(k, Phase.BACKWARD) and not all(map(is_inplace, k.users)) and not is_inplace(k):
-                peak_mem += activation_size(k.meta['saved_tensor'])
-            if v <= float('-inf') and is_phase(k, Phase.FORWARD):
-                peak_mem -= activation_size(k.meta['saved_tensor'])
+                peak_mem += activation_size(k.meta["saved_tensor"])
+            if v <= float("-inf") and is_phase(k, Phase.FORWARD):
+                peak_mem -= activation_size(k.meta["saved_tensor"])
         return peak_mem

     # deps is used to track all the memory dependencies of the graph.
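The `deps` dictionary mentioned in the trailing comment is a reference counter over saved activations: each entry starts at the number of backward users, is decremented as those users run, and is set to float("-inf") once the tensor can be freed so it stops counting toward the running peak. A rough sketch of that bookkeeping with made-up node names and sizes (not the repository's code):

# hypothetical saved activations -> number of backward users still to run
deps = {"act_a": 2, "act_b": 1}
# hypothetical sizes in bytes, standing in for activation_size(node.meta["saved_tensor"])
sizes = {"act_a": 4 * 1024, "act_b": 8 * 1024}


def peak_memory(deps, sizes):
    # only activations whose remaining-user count is still positive are alive
    return sum(sizes[k] for k, v in deps.items() if v > 0)


assert peak_memory(deps, sizes) == 12 * 1024

# a backward node consumes act_b: its count hits zero, so it is marked as freed
deps["act_b"] -= 1
if deps["act_b"] <= 0:
    deps["act_b"] = float("-inf")

assert peak_memory(deps, sizes) == 4 * 1024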
@@ -123,19 +122,19 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
         # Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
         # the node, `fwd_mem_tmp` can be freed.
         if is_phase(n, Phase.PLACEHOLDER):
-            graph_info.fwd_in += n.meta['saved_tensor']
+            graph_info.fwd_in += n.meta["saved_tensor"]
         if is_phase(n, Phase.FORWARD):
-            graph_info.fwd_tmp += n.meta['saved_tensor']
+            graph_info.fwd_tmp += n.meta["saved_tensor"]
         elif is_phase(n, Phase.BACKWARD):
             if len(n.users):
                 graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
             else:
                 # TODO: some of the bwd_mem_out might be model parameters.
                 # basically a backward node without user is a `grad_out` node
-                graph_info.bwd_mem_out += activation_size(n.meta['saved_tensor'])
+                graph_info.bwd_mem_out += activation_size(n.meta["saved_tensor"])
         for input_n in n.all_input_nodes:
             if input_n in deps:
                 deps[input_n] -= 1
                 if deps[input_n] <= 0:
-                    deps[input_n] = float('-inf')
+                    deps[input_n] = float("-inf")
     return graph_info
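The `GraphInfo` returned at the end collects the per-phase memory estimates. Judging from the assignments in this hunk, its relevant fields look roughly like the sketch below (a simplified, assumed shape; the real dataclass in the repository may carry additional fields):

from dataclasses import dataclass, field
from typing import List


@dataclass
class GraphInfo:
    fwd_in: List = field(default_factory=list)   # tensors saved by placeholder (input) nodes
    fwd_tmp: List = field(default_factory=list)  # activations saved during the forward pass
    bwd_mem_tmp: int = 0                         # peak transient memory of the backward pass, in bytes
    bwd_mem_out: int = 0                         # memory of gradient outputs (`grad_out` nodes), in bytes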