Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-06-02 20:35:29 +00:00
[NFC] fix typo with colossalai/auto_parallel/tensor_shard (#3742)
* fix typo applications/ and colossalai/ date 5.11
* fix typo colossalai/
This commit is contained in:
parent: 7386c6669d
commit: 1baeb39c72
@@ -75,7 +75,7 @@ class NodeHandler(ABC):
         prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in prev_strategy_vector
     ]
 
-    # create data structrure to store costs
+    # create data structure to store costs
     if node not in resharding_costs:
         resharding_costs[node] = []
 
@@ -24,7 +24,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
     To keep the math consistency, there are two way to do BatchNorm if the input
     shards on batch dimension:
     1. We gather the input partitions through batch dimension, then do the normal BatchNorm.
-    2. We do the SyncBatchNorm on the each input partition seperately, the SyncBN op will help
+    2. We do the SyncBatchNorm on the each input partition separately, the SyncBN op will help
     us to keep the computing correctness.
     In this generator, both methods will be considered.
     """
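Aside: the second option in that docstring is essentially what PyTorch's built-in SyncBatchNorm provides. A minimal plain-PyTorch sketch (not ColossalAI's generator; it only changes behaviour once a torch.distributed process group is initialized):

    import torch.nn as nn

    # Each rank is assumed to hold one shard of the batch.
    # Option 1 would all-gather the shards and run a normal BatchNorm;
    # option 2 keeps the shards local and lets SyncBatchNorm exchange the
    # per-shard statistics internally during forward/backward.
    model = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3), nn.BatchNorm2d(16))

    # Option 2: swap every BatchNorm layer for its synchronized counterpart.
    sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)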
@@ -212,7 +212,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
 
     # set communication action
     # For SyncBN case, we don't need to do communication for weight and bias.
-    # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation
+    # TODO: the communication happens internally at SyncBN operation. We need to replace the BN operation
     # to SyncBN operation instead of inserting a communication node.
     output_comm_action = self.get_communication_action(
         sharding_spec=sharding_spec_mapping["output"],
@@ -250,7 +250,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
 
     # set communication action
     # For SyncBN case, we don't need to do communication for gradients of weight and bias.
-    # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation
+    # TODO: the communication happens internally at SyncBN operation. We need to replace the BN operation
     # to SyncBN operation instead of inserting a communication node.
     output_comm_action = self.get_communication_action(
         sharding_spec=sharding_spec_mapping["output"],
@@ -298,7 +298,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
 
     # set communication action
     # For SyncBN case, we don't need to do communication for gradients of weight and bias.
-    # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation
+    # TODO: the communication happens internally at SyncBN operation. We need to replace the BN operation
     # to SyncBN operation instead of inserting a communication node.
     output_comm_action = self.get_communication_action(
         sharding_spec=sharding_spec_mapping["output"],
@@ -51,7 +51,7 @@ class BinaryElementwiseStrategyGenerator(StrategyGenerator):
 
     # compute fwd memory cost in bytes
     # as the elementwise ops are not memory-intensive
-    # we approximate the fwd memroy cost to be the output
+    # we approximate the fwd memory cost to be the output
     # and the backward memory cost to be grad of input and other
     input_bytes = self._compute_size_in_bytes(strategy, 'input')
     other_bytes = self._compute_size_in_bytes(strategy, 'other')
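In other words, the forward memory cost is approximated by the bytes of the output and the backward cost by the bytes of the gradients of input and other. A standalone back-of-the-envelope version of that bookkeeping (not the generator's actual _compute_size_in_bytes):

    import torch

    def tensor_bytes(t: torch.Tensor) -> int:
        # bytes occupied by a tensor: number of elements times element size
        return t.numel() * t.element_size()

    # toy operands for an elementwise op such as out = input + other
    inp = torch.empty(64, 1024, dtype=torch.float32)
    other = torch.empty(64, 1024, dtype=torch.float32)
    out = torch.empty(64, 1024, dtype=torch.float32)

    fwd_mem = tensor_bytes(out)                        # approximated by the output
    bwd_mem = tensor_bytes(inp) + tensor_bytes(other)  # grads of input and other
    print(fwd_mem, bwd_mem)                            # 262144 524288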
@@ -225,7 +225,7 @@ class StrategyGenerator(ABC):
     if isinstance(meta_data, torch.Tensor):
         element_bytes = _compute_size_in_bytes_helper(sharding_spec, meta_data)
     else:
-        # if meta_data is not a tensor, we count the memroy as 0
+        # if meta_data is not a tensor, we count the memory as 0
         element_bytes = 0
     total_bytes += element_bytes
 
@@ -233,7 +233,7 @@ class StrategyGenerator(ABC):
     if isinstance(op_data.data, torch.Tensor):
         total_bytes = _compute_size_in_bytes_helper(strategy.sharding_specs[op_data], op_data.data)
     else:
-        # if op_data.data is not a tensor, we count the memroy as 0
+        # if op_data.data is not a tensor, we count the memory as 0
         total_bytes = 0
 
     return total_bytes
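The helper called in both hunks above is not shown here. A hedged sketch of what such a size computation typically boils down to; the shard-count division is an assumption about the sharding model, only the numel-times-element-size part is standard PyTorch:

    import torch

    def size_in_bytes(data, num_shards: int = 1) -> int:
        """Rough byte count for one shard of `data`; non-tensors count as 0."""
        if not isinstance(data, torch.Tensor):
            # mirror the comment above: anything that is not a tensor costs nothing
            return 0
        total = data.numel() * data.element_size()
        # if the tensor is split evenly across num_shards devices,
        # each device only holds a fraction of the bytes
        return total // num_shards

    print(size_in_bytes(torch.empty(8, 8, dtype=torch.float16), num_shards=4))  # 32
    print(size_in_bytes("not a tensor"))  # 0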
@@ -9,7 +9,7 @@ class CostGraph:
     1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in
     CostGraph, and it stored every combinations of strategies for a src-dst node pair in an 1D list.
     2. To reduce the searching space, we merge computationally-trivial operators, such as
-    element-wise operators, transpose, and reduction, into their following nodes. The merging infomation will
+    element-wise operators, transpose, and reduction, into their following nodes. The merging information will
     be given by the StrategiesVector depending on the type of target node and following nodes.
 
     Argument:
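The "linearize the quadratic resharding costs" remark means that, for each src-dst node pair, the strategy-by-strategy cost matrix is flattened into a single 1D list before it reaches the solver. A small illustration with made-up numbers (not the CostGraph implementation):

    # src node has 2 strategies, dst node has 3: the resharding cost is a 2x3
    # matrix, but edge_cost stores it as one flat list of decision variables.
    cost_matrix = [
        [0.0, 1.2, 3.4],   # costs when src uses strategy 0
        [2.1, 0.0, 1.8],   # costs when src uses strategy 1
    ]
    num_dst_strategies = 3

    edge_cost = [c for row in cost_matrix for c in row]

    # the cost of (src strategy i, dst strategy j) lives at index i * num_dst + j
    i, j = 1, 2
    print(edge_cost[i * num_dst_strategies + j])   # 1.8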
@@ -90,7 +90,7 @@ class CostGraph:
     if self.simplify and strategies_vector.check_merge():
         for followed_node in strategies_vector.predecessor_nodes:
             # we only merge node pairs which src node has a tensor element inside.
-            # This is necessay because the node without a tensor element inside will not
+            # This is necessary because the node without a tensor element inside will not
             # be assigned any strategy.
             if _check_tensor_in_node(followed_node._meta_data):
                 self.merge_pair.append((followed_node, dst_node))
@@ -83,7 +83,7 @@ class GraphAnalyser:
 
     def liveness_analysis(self) -> List[LiveStage]:
         """
-        Analyse the graph to obtain the variable liveness information. This function returns
+        Analyses the graph to obtain the variable liveness information. This function returns
         an ordered dictionary where the key is the compute stage ID and the value is a LivenessStage object.
         """
         compute_nodes = self.graph.nodes
@@ -91,7 +91,7 @@ class GraphAnalyser:
 
     # checked: record all variables created since the first stage
     # all: record the live variables only exist until the current stage.
-    # this can be different from the `checked list`` as some varialbes may be destroyed prior to this stage.
+    # this can be different from the `checked list`` as some variables may be destroyed prior to this stage.
     # unique: record the unique live variables only exist until the current stage.
     # this is different from `all list` as some variables are duplicated.
     checked_variables = LiveVariableVector()
@@ -103,7 +103,7 @@ class GraphAnalyser:
     # find new living variables #
     #############################
     # detect whether the current op is an in-place op
-    # if it is an in-place op, we would deem it as a duplciate var
+    # if it is an in-place op, we would deem it as a duplicate var
     is_inplace = False
     if node.op == 'call_function':
         # check if this is an inplace op such as torch.nn.functional.relu(x, inplace=True)
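For the in-place check mentioned in that last comment, one common way to detect it on a torch.fx call_function node is to inspect the inplace keyword argument. A self-contained sketch that mirrors the relu example (an assumption about the detection logic, not necessarily what GraphAnalyser does):

    import torch
    import torch.fx as fx

    class Toy(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.relu(x, inplace=True)

    graph = fx.symbolic_trace(Toy()).graph
    for node in graph.nodes:
        is_inplace = False
        if node.op == 'call_function':
            # an explicit inplace=True keyword marks the op as in-place
            is_inplace = node.kwargs.get('inplace', False)
        print(node.op, node.target, is_inplace)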
@@ -44,7 +44,7 @@ class Solver:
     graph: The computing graph to be optimized.
     strategies_constructor: It will provide all the possible strategies for each node in the computing graph.
     cost_graph: A graph data structure to simplify the edge cost graph.
-    graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints.
+    graph_analyser: graph_analyser will analyses the graph to obtain the variable liveness information, which will be used to generate memory constraints.
     memory_budget: Memory constraint for the solution.
     solution_numbers: If solution_numbers is larger than one, solver will us a serious of solutions based on different memory budget.
     memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate new memory budget.
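Only the last two documented arguments interact: a plausible reading (an assumption, not taken from the Solver source) is that each extra solution multiplies the previous budget by the coefficient, giving a series of increasingly relaxed memory budgets:

    # Assumed interpretation of solution_numbers / memory_increasing_coefficient:
    # generate one budget per solution, each a constant factor above the last.
    memory_budget = 8 * 1024**3           # 8 GiB starting budget
    solution_numbers = 4
    memory_increasing_coefficient = 1.3

    budgets = [memory_budget * memory_increasing_coefficient**i for i in range(solution_numbers)]
    print([f'{b / 1024**3:.2f} GiB' for b in budgets])
    # ['8.00 GiB', '10.40 GiB', '13.52 GiB', '17.58 GiB']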
@@ -33,7 +33,7 @@ def run_on_environment_flag(name: str):
     assert isinstance(name, str)
     flag = os.environ.get(name.upper(), '0')
 
-    reason = f'Environment varialbe {name} is {flag}'
+    reason = f'Environment variable {name} is {flag}'
     if flag == '1':
         return pytest.mark.skipif(False, reason=reason)
     else:
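Pieced together, this helper is a small pytest gate on an environment variable. A sketch of the full function plus a usage example; the else branch and the AUTO_PARALLEL flag name are assumptions, only the lines visible in the hunk come from the source:

    import os

    import pytest

    def run_on_environment_flag(name: str):
        """Skip the decorated test unless the environment variable NAME is set to '1'."""
        assert isinstance(name, str)
        flag = os.environ.get(name.upper(), '0')

        reason = f'Environment variable {name} is {flag}'
        if flag == '1':
            return pytest.mark.skipif(False, reason=reason)
        else:
            # assumed continuation: skip whenever the flag is anything other than '1'
            return pytest.mark.skipif(True, reason=reason)

    # hypothetical usage: this test only runs when AUTO_PARALLEL=1 is exported
    @run_on_environment_flag('AUTO_PARALLEL')
    def test_auto_parallel_solver():
        ...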