Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-04-28 11:45:23 +00:00)
[autoparallel] integrate 2-stage solver (#3476)

* integrate 2-stage solver
* [autoparallel] integrate 2-stage solver
* polish

Parent: 190a6ea9c2
Commit: 5458da5c3c
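For orientation, the two stages this commit integrates are exercised end to end by the new tests added below: stage one solves the intra-op sharding problem and rewrites the model into its sharded form; stage two runs the rotor activation-checkpoint solver over the sharded graph, using the per-node meta info propagated in between. A condensed sketch of that flow, using the calls exactly as they appear in the tests (it assumes a traced `ColoGraphModule` named `gm`, its `graph`, an initialized `DeviceMesh`, and meta-device example inputs are already prepared):

```python
from colossalai.auto_parallel.checkpoint.ckpt_solver_rotor import CheckpointSolverRotor
from colossalai.auto_parallel.passes.comm_metainfo_pass import comm_metainfo_pass
from colossalai.auto_parallel.passes.meta_info_prop import MetaInfoProp
from colossalai.auto_parallel.tensor_shard.initialize import (
    ModuleWrapper,
    build_strategy_constructor,
    solve_solution,
    transform_to_sharded_model,
)


def two_stage_solve(gm, graph, device_mesh, meta_input_sample):
    # Stage 1: pick a sharding strategy for every node, then rewrite the model.
    strategies_constructor = build_strategy_constructor(graph, device_mesh, 'standard', 'replicated', 'standard')
    solution = solve_solution(gm, strategies_constructor, memory_budget=-1)
    gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_input_sample, solution, device_mesh,
                                                         strategies_constructor)

    # Propagate communication and compute/memory meta info so stage 2 can cost the graph.
    comm_metainfo_pass(gm, *sharding_spec_dicts)
    MetaInfoProp(gm).run()

    # Stage 2: rotor activation-checkpoint solver on the sharded graph.
    gm = ModuleWrapper(gm, *sharding_spec_dicts)
    ckpt_solver = CheckpointSolverRotor(gm.module.graph, 8 * 1024**3)
    gm.module.graph = ckpt_solver.solve()
    gm.module.recompile()
    return gm
```

The `memory_budget=-1` passed to the sharding solver appears to disable its memory constraint, while the checkpoint solver is handed an explicit 8 GiB budget in the tests.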
@@ -94,6 +94,8 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0):
     current_region = None

     for idx, node in enumerate(node_list):
+        if 'info' not in node.meta:
+            continue
         if len(node.meta['info'].activation_checkpoint) > ckpt_level:
             act_ckpt_label = node.meta['info'].activation_checkpoint[ckpt_level]

@@ -5,12 +5,12 @@ from typing import Any, List
 import torch
 from torch.fx import Graph, Node

+from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen
 from colossalai.auto_parallel.passes.runtime_apply_pass import (
     runtime_apply,
     runtime_apply_for_iterable_object,
     runtime_comm_spec_apply,
 )
-from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen

 __all___ = ['CheckpointSolverBase']

@@ -369,7 +369,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
                 in_ckpt = False
                 for node_idx in ckpt_region:
                     for n in node_list[node_idx]:
-                        n.meta['activation_checkpoint'] = [ckpt_idx]
+                        n.meta['info'].activation_checkpoint = [ckpt_idx]

                 ckpt_idx += 1
                 ckpt_region = []
@@ -377,7 +377,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
             elif isinstance(op, ForwardCheck):
                 for node_idx in ckpt_region:
                     for n in node_list[node_idx]:
-                        n.meta['activation_checkpoint'] = [ckpt_idx]
+                        n.meta['info'].activation_checkpoint = [ckpt_idx]

                 ckpt_idx += 1
                 ckpt_region = [idx]
@@ -397,7 +397,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
             elif isinstance(op, ForwardEnable):
                 for node_idx in ckpt_region:
                     for n in node_list[node_idx]:
-                        n.meta['activation_checkpoint'].append(ckpt_idx)
+                        n.meta['info'].activation_checkpoint.append(ckpt_idx)

                 ckpt_idx += 1
                 ckpt_region = []
@@ -405,7 +405,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
             elif isinstance(op, ForwardCheck):
                 for node_idx in ckpt_region:
                     for n in node_list[node_idx]:
-                        n.meta['activation_checkpoint'].append(ckpt_idx)
+                        n.meta['info'].activation_checkpoint.append(ckpt_idx)

                 ckpt_idx += 1
                 ckpt_region = [op.index]
@@ -413,7 +413,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
             elif isinstance(op, Backward):
                 for node_idx in ckpt_region:
                     for n in node_list[node_idx]:
-                        n.meta['activation_checkpoint'].append(ckpt_idx)
+                        n.meta['info'].activation_checkpoint.append(ckpt_idx)

                 in_recompute = False

@@ -433,7 +433,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
         ckpt_regions = _find_nested_ckpt_regions(op_list)
         for (start_idx, end_idx) in ckpt_regions:
             nested_length = max(
-                len(op_list[idx].meta['activation_checkpoint']) for idx in range(start_idx, end_idx + 1))
+                len(op_list[idx].meta['info'].activation_checkpoint) for idx in range(start_idx, end_idx + 1))
             for idx in range(start_idx, end_idx + 1):
-                op_list[idx].meta['activation_checkpoint'] += [None] * (nested_length -
-                                                                        len(op_list[idx].meta['activation_checkpoint']))
+                op_list[idx].meta['info'].activation_checkpoint += [None] * (
+                    nested_length - len(op_list[idx].meta['info'].activation_checkpoint))
@@ -68,7 +68,7 @@ class MetaInfoProp:
         graph_info = GraphInfo()
         out = _normalize_tuple(getattr(node, '_meta_data', None))
         graph_info.fwd_out = list(out) if out[0] is not None else []
-        node.meta = {**asdict(graph_info)}
+        node.meta.update(asdict(graph_info))

     @compatibility(is_backward_compatible=False)
     def get_attr_handler(self, node: Node) -> None:
@@ -76,7 +76,7 @@ class MetaInfoProp:
         Handle the get_attr node.
         """
         graph_info = GraphInfo()
-        node.meta = {**asdict(graph_info)}
+        node.meta.update(asdict(graph_info))

     @compatibility(is_backward_compatible=False)
     def output_handler(self, node: Node) -> None:
@@ -89,15 +89,36 @@ class MetaInfoProp:
             if par.meta:
                 output_tensors += par.meta["fwd_out"]
         graph_info.fwd_in = output_tensors
-        node.meta = {**asdict(graph_info)}
+        node.meta.update(asdict(graph_info))

     @compatibility(is_backward_compatible=False)
     def node_handler(self, node: Node) -> None:
         """
         Handle other kind of nodes
         """
-        assert hasattr(node, 'best_strategy_info'), f"Cannot find best_strategy_info in node {node}, {node.op}"
         graph_info = GraphInfo()
+        if not hasattr(node, 'best_strategy_info'):
+            # attach them to graph_info
+            graph_info.fwd_in = []
+            graph_info.fwd_tmp = []
+            graph_info.fwd_out = []
+
+            # fetch other memory informations
+            graph_info.fwd_mem_tmp = 10
+            graph_info.fwd_mem_out = 10
+            graph_info.bwd_mem_tmp = 10
+            graph_info.bwd_mem_out = 10
+
+            # fetch flop information
+            # here we use fwd_time and bwd_time to deal with the case that
+            # communication cost is a float
+            graph_info.fwd_time = 10
+            graph_info.bwd_time = 10
+            node.meta.update(asdict(graph_info))
+            # print(node.name, [isinstance(arg, torch.Tensor) for arg in node.args], isinstance(node._meta_data, torch.Tensor))
+            return
+        assert hasattr(node, 'best_strategy_info'), f"Cannot find best_strategy_info in node {node}, {node.op}"

         meta_info = node.best_strategy_info
         meta_info: ShardMetaInfo
@@ -124,6 +145,8 @@ class MetaInfoProp:
         for par in node._input_nodes:
             # set data_ptr for the input_tensor of current node from the output_tensor of its parent node
             for tensor in par.meta.get("fwd_out", []):
+                if not isinstance(tensor, torch.Tensor):
+                    continue
                 tensor: torch.Tensor
                 target_input_tensor = next(
                     (x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None)
@@ -161,5 +184,4 @@ class MetaInfoProp:
         compute_cost = meta_info.compute_cost
         graph_info.fwd_time = compute_cost.fwd
         graph_info.bwd_time = compute_cost.bwd
-
-        node.meta = {**asdict(graph_info)}
+        node.meta.update(asdict(graph_info))
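A note on the recurring `node.meta = {**asdict(graph_info)}` → `node.meta.update(asdict(graph_info))` change in this file: the solver hunks above now read `node.meta['info']`, and rebuilding the dict from scratch would discard any keys that earlier passes have attached to the node, while `update` merges the new entries in. A minimal illustration (the `GraphInfo` stub here is a hypothetical stand-in for the real dataclass):

```python
from dataclasses import asdict, dataclass


@dataclass
class GraphInfo:
    fwd_time: float = 0.0


node_meta = {'info': 'set-by-an-earlier-pass'}
node_meta = {**asdict(GraphInfo())}      # old style: rebuilds the dict, 'info' is lost
assert 'info' not in node_meta

node_meta = {'info': 'set-by-an-earlier-pass'}
node_meta.update(asdict(GraphInfo()))    # new style: merges keys, 'info' survives
assert node_meta['info'] == 'set-by-an-earlier-pass'
```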
@@ -157,9 +157,22 @@ class NodeHandler(ABC):
         Register different sharding strategies for the current node.
         """
         strategy_generators = self.get_strategy_generator()
+        strategies_info = []
         for generator in strategy_generators:
             strategies = generator.generate()

+            for strategy in strategies:
+                shard_metainfo = ShardMetaInfo()
+                shard_metainfo.compute_cost = strategy.compute_cost
+                shard_metainfo.memory_cost = strategy.memory_cost
+                shard_metainfo.fwd_in = []
+                if isinstance(self.node._meta_data, torch.Tensor):
+                    shard_metainfo.fwd_out = [self.node._meta_data]
+                else:
+                    shard_metainfo.fwd_out = self.node._meta_data
+                shard_metainfo.fwd_buffer = []
+                strategies_info.append(shard_metainfo)
+
             # postprocess a strategy
             # postprocess can produce one strategy or multiple strategies
             post_processed_strategies_map = map(self.post_process, strategies)
@@ -102,7 +102,6 @@ class StrategiesConstructor:

             if _check_no_strategy_for_node(node):
                 self.no_strategy_nodes.append(node)
-                pass

             # placeholder node
             elif node.op == 'placeholder':
@@ -0,0 +1,174 @@
+from time import time
+
+import psutil
+import pytest
+import torch
+import transformers
+
+try:
+    from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen
+    NON_CODEGEN = False
+except:
+    NON_CODEGEN = True
+
+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes import shape_prop_pass
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
+from colossalai.auto_parallel.checkpoint.ckpt_solver_rotor import CheckpointSolverRotor
+from colossalai.auto_parallel.passes.comm_metainfo_pass import comm_metainfo_pass
+from colossalai.auto_parallel.passes.meta_info_prop import MetaInfoProp
+from colossalai.auto_parallel.tensor_shard.initialize import (
+    ModuleWrapper,
+    build_strategy_constructor,
+    solve_solution,
+    transform_to_sharded_model,
+)
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.initialize import launch, launch_from_torch
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
+from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
+
+BATCH_SIZE = 16
+SEQ_LENGTH = 1024
+HIDDEN_DIM = 2048
+NUM_HEADS = 16
+NUM_LAYERS = 2
+VOCAB_SIZE = 50257
+NUM_STEPS = 10
+
+
+def get_cpu_mem():
+    return psutil.Process().memory_info().rss / 1024**2
+
+
+def get_gpu_mem():
+    return torch.cuda.memory_allocated() / 1024**2
+
+
+def get_mem_info(prefix=''):
+    return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'
+
+
+def data_gen_mlp(model_cls):
+    """
+    Generate random data for gpt benchmarking
+    """
+    input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+    hidden_states = torch.rand((BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM), dtype=torch.float32)
+
+    if model_cls == GPT2MLP:
+        input_sample = (hidden_states.to('cuda'),)
+        meta_input_sample = {
+            'hidden_states': hidden_states.to('meta'),
+        }
+    elif model_cls in (GPT2Attention, GPT2Block):
+        attention_mask = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
+        input_sample = (
+            hidden_states.to('cuda'),
+            attention_mask.to('cuda'),
+        )
+        meta_input_sample = {
+            'hidden_states': hidden_states.to('meta'),
+            'attention_mask': attention_mask.to('meta'),
+        }
+    else:
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        input_sample = (
+            input_ids.to('cuda'),
+            attention_mask.to('cuda'),
+        )
+        meta_input_sample = {
+            'input_ids': input_ids.to('meta'),
+            'attention_mask': attention_mask.to('meta'),
+        }
+    return input_sample, meta_input_sample
+
+
+def check_2stage_solver_on_gpt(rank, world_size, port, model_cls):
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl')
+    logger = get_dist_logger()
+    config = transformers.GPT2Config(n_position=SEQ_LENGTH, n_layer=NUM_LAYERS, n_head=NUM_HEADS, n_embd=HIDDEN_DIM)
+    if model_cls == GPT2MLP:
+        model = model_cls(intermediate_size=4 * config.hidden_size, config=config).to('cuda')
+    else:
+        model = model_cls(config=config).to('cuda')
+
+    input_sample, meta_input_sample = data_gen_mlp(model_cls)
+
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    # [[0, 1]
+    #  [2, 3]]
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+
+    tracer = ColoTracer(bias_addition_split=True)
+
+    graph = tracer.trace(root=model, meta_args=meta_input_sample)
+    graph.set_codegen(ActivationCheckpointCodeGen())
+    gm = ColoGraphModule(model, graph, model.__class__.__name__)
+    shape_prop_pass(gm, *meta_input_sample.values())
+    gm.recompile()
+
+    strategies_constructor = build_strategy_constructor(graph, device_mesh, 'standard', 'replicated', 'standard')
+    solution = solve_solution(gm, strategies_constructor, memory_budget=-1)
+    gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_input_sample, solution, device_mesh,
+                                                         strategies_constructor)
+    comm_metainfo_pass(gm, *sharding_spec_dicts)
+
+    MetaInfoProp(gm).run()
+
+    gm = ModuleWrapper(gm, *sharding_spec_dicts)
+    ckpt_solver = CheckpointSolverRotor(gm.module.graph, 8 * 1024**3)
+    gm.module.graph = ckpt_solver.solve()
+    ckpt_solver.print_sequence()
+    gm.module.recompile()
+    print(gm.module)
+    logger.info("*******************strategy selected*******************", ranks=[0])
+    strategies_list = solution
+
+    nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
+    for index, node in enumerate(nodes):
+        logger.info(node.name, ranks=[0])
+        logger.info(node.strategies_vector[strategies_list[index]].name, ranks=[0])
+
+    optimizer = torch.optim.Adam(gm.parameters(), lr=0.01)
+    logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
+    torch.cuda.synchronize()
+    gm.train()
+
+    for n in range(10):
+        # we just use randomly generated data here
+        input_sample, _ = data_gen_mlp(model_cls)
+        mem_stamp0 = torch.cuda.memory_allocated(device='cuda:0') / 1024**2
+        optimizer.zero_grad()
+        start = time()
+        loss = gm(*input_sample)
+        loss.sum().backward()
+        optimizer.step()
+        # prof.step()
+        torch.cuda.synchronize()
+        step_time = time() - start
+        logger.info(f"===============Round {n}===============", ranks=[0])
+        logger.info(
+            f"Peak Memory: {torch.cuda.max_memory_allocated(device='cuda:0') / 1024**2 - mem_stamp0} MB, Step Time: {step_time:.3f}s",
+            ranks=[0])
+    torch.cuda.synchronize()
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@pytest.mark.skipif(NON_CODEGEN, reason='codegen is not available')
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@parameterize('model_cls', [GPT2MLP, GPT2Attention, GPT2Block, GPT2Model])
+def test_2stage_solver_on_gpt(model_cls):
+    spawn(
+        check_2stage_solver_on_gpt,
+        4,
+        model_cls=model_cls,
+    )
+
+
+if __name__ == '__main__':
+    test_2stage_solver_on_gpt()
@@ -0,0 +1,138 @@
+from time import time
+
+import psutil
+import pytest
+import torch
+import torchvision.models as tm
+
+try:
+    from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen
+    NON_CODEGEN = False
+except:
+    NON_CODEGEN = True
+
+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes import shape_prop_pass
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
+from colossalai.auto_parallel.checkpoint.ckpt_solver_rotor import CheckpointSolverRotor
+from colossalai.auto_parallel.passes.comm_metainfo_pass import comm_metainfo_pass
+from colossalai.auto_parallel.passes.meta_info_prop import MetaInfoProp
+from colossalai.auto_parallel.tensor_shard.initialize import (
+    ModuleWrapper,
+    build_strategy_constructor,
+    solve_solution,
+    transform_to_sharded_model,
+)
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.initialize import launch, launch_from_torch
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+
+BATCH_SIZE = 256
+
+
+def get_cpu_mem():
+    return psutil.Process().memory_info().rss / 1024**2
+
+
+def get_gpu_mem():
+    return torch.cuda.memory_allocated() / 1024**2
+
+
+def get_mem_info(prefix=''):
+    return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'
+
+
+def data_gen_resnet(batch_size, shape):
+    """
+    Generate random data for resnet benchmarking
+    """
+    data = torch.empty(batch_size, *shape, device=torch.cuda.current_device())
+    label = torch.empty(batch_size, dtype=torch.long, device=torch.cuda.current_device()).random_(1000)
+    return data, label
+
+
+def check_2stage_solver_on_resnet(rank, world_size, port):
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    logger = get_dist_logger()
+    model = tm.resnet50().cuda()
+
+    meta_input_sample = {
+        'x': torch.randn(BATCH_SIZE * 4, 3, 224, 224, device='meta'),
+    }
+
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    # [[0, 1]
+    #  [2, 3]]
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+
+    tracer = ColoTracer(bias_addition_split=True, trace_act_ckpt=True)
+
+    graph = tracer.trace(root=model, meta_args=meta_input_sample)
+    graph.set_codegen(ActivationCheckpointCodeGen())
+    gm = ColoGraphModule(model, graph, model.__class__.__name__)
+    shape_prop_pass(gm, *meta_input_sample.values())
+    gm.recompile()
+
+    strategies_constructor = build_strategy_constructor(graph, device_mesh, 'standard', 'replicated', 'standard')
+    solution = solve_solution(gm, strategies_constructor, memory_budget=-1)
+    gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_input_sample, solution, device_mesh,
+                                                         strategies_constructor)
+    comm_metainfo_pass(gm, *sharding_spec_dicts)
+    MetaInfoProp(gm).run()
+    gm = ModuleWrapper(gm, *sharding_spec_dicts)
+
+    ckpt_solver = CheckpointSolverRotor(gm.module.graph, 8 * 1024**3)
+    gm.module.graph = ckpt_solver.solve()
+    ckpt_solver.print_sequence()
+
+    logger.info("*******************strategy selected*******************", ranks=[0])
+    strategies_list = solution
+
+    nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
+    for index, node in enumerate(nodes):
+        logger.info(node.name, ranks=[0])
+        logger.info(node.strategies_vector[strategies_list[index]].name, ranks=[0])
+
+    # build criterion
+    criterion = torch.nn.CrossEntropyLoss().cuda()
+
+    optimizer = torch.optim.Adam(gm.parameters(), lr=0.01)
+    logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
+    torch.cuda.synchronize()
+    model.train()
+    for n in range(10):
+        # we just use randomly generated data here
+        data, label = data_gen_resnet(BATCH_SIZE * 4, (3, 224, 224))
+        mem_stamp0 = torch.cuda.memory_allocated(device='cuda:0') / 1024**2
+        optimizer.zero_grad()
+        start = time()
+        outputs = gm(data)
+        loss = criterion(outputs, label)
+        loss.backward()
+        optimizer.step()
+        torch.cuda.synchronize()
+        step_time = time() - start
+        logger.info(f"===============Round {n}===============", ranks=[0])
+        logger.info(
+            f"Peak Memory: {torch.cuda.max_memory_allocated(device='cuda:0') / 1024**2 - mem_stamp0} MB, Step Time: {step_time:.3f}s",
+            ranks=[0])
+    torch.cuda.synchronize()
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@pytest.mark.skipif(NON_CODEGEN, reason='codegen is not available')
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_2stage_solver_on_resnet():
+    spawn(
+        check_2stage_solver_on_resnet,
+        4,
+    )
+
+
+if __name__ == '__main__':
+    test_2stage_solver_on_resnet()
@@ -162,7 +162,7 @@ def check_linear_module_handler(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_linear_handler():
-    spawn(check_linear_module_handler)
+    spawn(check_linear_module_handler, 4)


 if __name__ == '__main__':
@@ -151,7 +151,7 @@ def check_linear_module_handler(rank, world_size, port, bias):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_linear_handler(bias=True):
-    spawn(check_linear_module_handler, bias=bias)
+    spawn(check_linear_module_handler, 4, bias=bias)


 if __name__ == '__main__':
@@ -100,7 +100,7 @@ def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port):
 # @parameterize('getitem_index', [slice(0, 2), (slice(None), slice(None))])
 @parameterize('getitem_index', [1, (1, 4), slice(0, 2), (slice(None), slice(None))])
 def test_getitem_from_tensor_handler(getitem_index):
-    spawn(check_getitem_from_tensor_handler, 4)
+    spawn(check_getitem_from_tensor_handler, 4, getitem_index=getitem_index)


 class GetItemFromTupleModel(nn.Module):
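The three test fixes above share one pattern: judging by the handler signatures in this commit, `colossalai.testing.spawn` takes the process count as its second positional argument and forwards remaining keyword arguments to the handler, which receives `(rank, world_size, port, **kwargs)`. A minimal sketch of that calling convention (the handler and test names here are hypothetical):

```python
import pytest

from colossalai.testing import rerun_if_address_is_in_use, spawn


def check_something(rank, world_size, port, flag=False):
    # spawn injects rank, world_size and a free port; extra kwargs pass through.
    assert world_size == 4


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_something(flag=True):
    spawn(check_something, 4, flag=flag)
```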