Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-22 13:41:43 +00:00)
[fx] Add common node in model linearize (#1542)

* [fx] Add common node into linearize
* [fx] Add common node to solver

parent 964123ae0f
commit 46c6cc79a9
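In short, this commit threads a new cnode ("common node") argument from solver_rotor through linearize: placeholder inputs that every stage shares (an attention mask, say) and nodes computed only from them are no longer counted as inter-region dependencies, so the graph linearizes into finer regions. It also adds small helpers to the rotor solver (_compute_size, _compute_output_size, _get_inplace), accepts list/tuple/dict inputs when sizing the chain, and discounts activations that are overwritten by inplace ops.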
@@ -114,12 +114,57 @@ def _discretize(mem_unit, values):
     return [math.ceil(value / mem_unit) for value in values]
 
 
-def _construct_chain(node_list: List[List[Node]], data: torch.Tensor, mem_unit: int) -> Chain:
+def _compute_size(obj: torch.Tensor) -> int:
+    return obj.numel() * obj.element_size()
+
+
+def _compute_output_size(node: List[Node]) -> int:
+    """Compute the output size of a node
+
+    Args:
+        node (List[Node]): node, list of torch.fx.Node
+
+    Returns:
+        int: output size
+    """
+
+    return node[-1].meta['tensor_meta'].numel * torch.tensor([],
+                                                             dtype=node[-1].meta['tensor_meta'].dtype).element_size()
+
+
+def _get_inplace(node: Node) -> bool:
+    """Get the inplace argument from torch.fx.Node
+
+    Args:
+        node (Node): torch.fx.Node
+
+    Returns:
+        bool: indicates whether this op is inplace
+    """
+
+    is_inplace = False
+    if node.op == "call_function":
+        is_inplace = node.kwargs.get("inplace", False)
+    elif node.op == "call_module":
+        is_inplace = getattr(node.graph.owning_module.get_submodule(node.target), "inplace", False)
+
+    return is_inplace
+
+
+def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
     fwd_time = []
     bwd_time = []
-    xbar_sizes = [data.numel() * data.element_size()]
-    x_sizes = [data.numel() * data.element_size()]
+    if isinstance(data, torch.Tensor):
+        xbar_sizes = [_compute_size(data)]
+        x_sizes = [_compute_size(data)]
+    elif isinstance(data, list) or isinstance(data, tuple):
+        xbar_sizes = [_compute_size(obj) for obj in data]
+        x_sizes = [_compute_size(obj) for obj in data]
+    elif isinstance(data, dict):
+        xbar_sizes = [_compute_size(obj) for obj in data.values()]
+        x_sizes = [_compute_size(obj) for obj in data.values()]
 
     # currently we can't get the temp memory needed in fwd and bwd
     tmp_fwd = [0] * len(node_list)
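As a quick illustration of the two cases _get_inplace distinguishes, here is a minimal sketch using plain torch.fx; the Demo module is invented for this note and is not part of the commit:

import torch.nn as nn
import torch.nn.functional as F
from torch.fx import symbolic_trace


class Demo(nn.Module):

    def __init__(self):
        super().__init__()
        # call_module case: the submodule itself carries an `inplace` attribute
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.act(x)
        # call_function case: `inplace` travels in the node's kwargs
        return F.relu(x, inplace=True)


gm = symbolic_trace(Demo())
for node in gm.graph.nodes:
    if node.op == "call_module":
        print(node.name, getattr(gm.get_submodule(node.target), "inplace", False))
    elif node.op == "call_function":
        print(node.name, node.kwargs.get("inplace", False))
# expected: both the ReLU module node and the F.relu function node report True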
@@ -129,16 +174,27 @@ def _construct_chain(node_list: List[List[Node]], data: torch.Tensor, mem_unit:
         fwd_time.append(0)
         bwd_time.append(0)
         xbar_sizes.append(0)
-        x_sizes.append(node[-1].meta['tensor_meta'].numel *
-                       torch.tensor([], dtype=node[-1].meta['tensor_meta'].dtype).element_size())
+        x_sizes.append(_compute_output_size(node))
+
+        _check_inplace_flag = 1
         for n in node:
             fwd_time[-1] += max(n.__flops__, 1)
 
             # currently we haven't patched the backward flops count
             bwd_time[-1] += max(n.__flops__ * 2, 2)
 
             xbar_sizes[-1] += n.__activation__
 
+            # we need to clear the xbar of the previous node, as one op in
+            # the current node uses the previous node's output but applies
+            # an inplace operation on it
+            # NOTE: this should be done only once, as the previous node
+            # will only have one output
+            if _check_inplace_flag:
+                for par in n._input_nodes:
+                    if par not in node and _get_inplace(n):
+                        xbar_sizes[-2] -= x_sizes[-2]
+                        _check_inplace_flag = 0
+
         xbar_sizes[-1] = max(xbar_sizes[-1], x_sizes[-1])
 
         bwd_time.append(0)
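The flagged branch above fixes double counting for inplace consumers: if an op in the current chain node overwrites the previous node's output in place (e.g. a ReLU with inplace=True right after a linear layer), that output buffer is not actually kept alive, so its size x_sizes[-2] is subtracted once from the previous node's xbar, the "everything kept" activation measure the rotor solver optimizes against.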
@@ -186,20 +242,25 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]) ->
             ckpt_region.append(idx)
 
 
-def solver_rotor(gm: ColoGraphModule, data: torch.Tensor, mem_limit: int, mem_slots: int = 500) -> ColoGraphModule:
+def solver_rotor(gm: ColoGraphModule,
+                 data,
+                 mem_limit: int,
+                 mem_slots: int = 500,
+                 cnode: List[str] = None) -> ColoGraphModule:
     """solver that automatically finds activation checkpoints in rotor's manner
 
     Args:
         gm (ColoGraphModule): ColoGraphModule generated by tracing model.
         data (torch.Tensor): input data.
         mem_limit (int): memory budget in Byte.
-        mem_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500.
+        mem_slots (int, optional): number of slots for discretizing memory budget. Defaults to 500.
+        cnode (List[str], optional): common node list for linearize. Defaults to None.
 
     Returns:
         ColoGraphModule: annotated ColoGraphModule with __sequence__ attribute
     """
 
-    node_list = linearize(gm)
+    node_list = linearize(gm, cnode)
     mem_unit = mem_limit // mem_slots
     MetaInfoProp(gm).run(data)
     chain: Chain = _construct_chain(node_list, data, mem_unit)
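For a feel of what mem_slots buys, a small standalone sketch; _discretize is the helper from the hunk context above, inlined here so the snippet runs on its own:

import math


def _discretize(mem_unit, values):
    # round every byte count up to a whole number of memory slots
    return [math.ceil(value / mem_unit) for value in values]


mem_limit, mem_slots = 1024**3, 500      # 1 GiB budget split into 500 slots
mem_unit = mem_limit // mem_slots        # ~2.05 MiB per slot, as in solver_rotor
print(_discretize(mem_unit, [4 * 1024**2, 10 * 1024**2]))    # -> [2, 5]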
@@ -2,11 +2,12 @@ from typing import List
 from torch.fx import GraphModule, Node
 
 
-def linearize(gm: GraphModule) -> List[List[Node]]:
+def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
     """Linearizing the graph
 
     Args:
         gm (GraphModule): GraphModule derived by tracing
+        cnode (List[str], optional): common node list, should be a subset of the input. Defaults to None.
 
     Returns:
         List[List[Node]]: List of list, each inside list of Node presents
@@ -22,23 +23,39 @@ def linearize(gm: GraphModule) -> List[List[Node]]:
 
         return not sum([v for _, v in deps.items()])
 
+    # make sure that each item in cnode is valid
+    if cnode:
+        for name in cnode:
+            try:
+                assert next(node for node in gm.graph.nodes if node.name == name).op == "placeholder", \
+                    f"common node {name} is not an input of the model"
+            except StopIteration:
+                raise ValueError(f"common node name {name} not in graph")
+
+    else:
+        cnode = []
+
     deps = {}
     linearized_nodes = []
     region = []
 
     for n in gm.graph.nodes:
-        for n_par in n._input_nodes:
-            deps[n_par] -= 1
-        region.append(n)
-
-        # if the node could free all dependencies in graph
-        # we could begin a new node
-        if _is_sink():
-            linearized_nodes.append(region)
-            region = []
-
-        deps[n] = len(n.users)
+        if n.op != "placeholder" and n.op != "output":
+            for n_par in n._input_nodes:
+                if n_par.op != "placeholder" and n_par.name not in cnode:
+                    deps[n_par] -= 1
+            region.append(n)
+
+            # if the node could free all dependencies in graph
+            # we could begin a new node
+            if _is_sink():
+                linearized_nodes.append(region)
+                region = []
+
+            # propagate common node attr if possible
+            if len(n._input_nodes) == len([node for node in n._input_nodes if node.name in cnode]):
+                cnode.append(n.name)
+            else:
+                deps[n] = len([user for user in n.users if user.op != "output"])
 
-    # Remove input
-    linearized_nodes = linearized_nodes[1:-1]
     return linearized_nodes
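To see the common-node propagation end to end, a hedged sketch (the linearize import path is assumed from this commit's file layout; the toy model is invented). A node derived only from a common input is itself marked common, so it stops pinning the dependency count and the regions become finer:

import torch.nn as nn
from torch.fx import symbolic_trace

from colossalai.fx.passes.algorithms.linearize import linearize    # path assumed


class Toy(nn.Module):

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(4, 4)
        self.l2 = nn.Linear(4, 4)

    def forward(self, x, mask):
        m = mask * 2    # computed only from `mask`, a common-node candidate
        return self.l2(self.l1(x) + m) + m


gm = symbolic_trace(Toy())
# without cnode, `m`'s pending users prevent region breaks until its last use
print([[n.name for n in r] for r in linearize(gm)])
# with mask declared common, `m` is propagated into cnode and regions split up
print([[n.name for n in r] for r in linearize(gm, cnode=["mask"])])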