[autoparallel] integrate 2-stage solver (#3476)

* integrate 2-stage solver

* [autoparallel] integrate 2-stage solver

* polish
YuliangLiu0306 2023-07-05 12:07:07 +08:00 committed by GitHub
parent 190a6ea9c2
commit 5458da5c3c
11 changed files with 367 additions and 19 deletions

View File

@@ -94,6 +94,8 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0):
current_region = None
for idx, node in enumerate(node_list):
if 'info' not in node.meta:
continue
if len(node.meta['info'].activation_checkpoint) > ckpt_level:
act_ckpt_label = node.meta['info'].activation_checkpoint[ckpt_level]
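The added check skips nodes that no earlier pass annotated: reading node.meta['info'] on such a node would raise a KeyError. A minimal self-contained sketch of the same guard, with a hypothetical helper name and a bare-bones stand-in for the solver's per-node metadata:

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class _InfoStub:    # stand-in for the real per-node solver metadata
    activation_checkpoint: List[Optional[int]] = field(default_factory=list)

def labels_at_level(node_list, ckpt_level: int = 0):
    """Yield the checkpoint label at `ckpt_level` for every annotated node (sketch only)."""
    for node in node_list:
        info = node.meta.get('info')    # None if the node was never annotated
        if info is None or len(info.activation_checkpoint) <= ckpt_level:
            continue    # node lies outside any checkpoint region at this nesting level
        yield info.activation_checkpoint[ckpt_level]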

View File

@@ -5,12 +5,12 @@ from typing import Any, List
import torch
from torch.fx import Graph, Node
from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen
from colossalai.auto_parallel.passes.runtime_apply_pass import (
runtime_apply,
runtime_apply_for_iterable_object,
runtime_comm_spec_apply,
)
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
__all__ = ['CheckpointSolverBase']

View File

@@ -369,7 +369,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
in_ckpt = False
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'] = [ckpt_idx]
n.meta['info'].activation_checkpoint = [ckpt_idx]
ckpt_idx += 1
ckpt_region = []
@@ -377,7 +377,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
elif isinstance(op, ForwardCheck):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'] = [ckpt_idx]
n.meta['info'].activation_checkpoint = [ckpt_idx]
ckpt_idx += 1
ckpt_region = [idx]
@@ -397,7 +397,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
elif isinstance(op, ForwardEnable):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'].append(ckpt_idx)
n.meta['info'].activation_checkpoint.append(ckpt_idx)
ckpt_idx += 1
ckpt_region = []
@@ -405,7 +405,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
elif isinstance(op, ForwardCheck):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'].append(ckpt_idx)
n.meta['info'].activation_checkpoint.append(ckpt_idx)
ckpt_idx += 1
ckpt_region = [op.index]
@@ -413,7 +413,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
elif isinstance(op, Backward):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'].append(ckpt_idx)
n.meta['info'].activation_checkpoint.append(ckpt_idx)
in_recompute = False
@@ -433,7 +433,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
ckpt_regions = _find_nested_ckpt_regions(op_list)
for (start_idx, end_idx) in ckpt_regions:
nested_length = max(
len(op_list[idx].meta['activation_checkpoint']) for idx in range(start_idx, end_idx + 1))
len(op_list[idx].meta['info'].activation_checkpoint) for idx in range(start_idx, end_idx + 1))
for idx in range(start_idx, end_idx + 1):
op_list[idx].meta['activation_checkpoint'] += [None] * (nested_length -
len(op_list[idx].meta['activation_checkpoint']))
op_list[idx].meta['info'].activation_checkpoint += [None] * (
nested_length - len(op_list[idx].meta['info'].activation_checkpoint))
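Every hunk in this file makes the same substitution: checkpoint indices are written to the solver-owned object at node.meta['info'].activation_checkpoint instead of a top-level node.meta['activation_checkpoint'] key, and the final hunk pads nested regions to a common length with None. A hedged sketch of those two steps, again with a stand-in for the real metadata class:

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class _InfoStub:    # stand-in for the real per-node solver metadata
    activation_checkpoint: List[Optional[int]] = field(default_factory=list)

def annotate_region(nodes, ckpt_idx: int) -> None:
    """Tag every node of one checkpoint region with `ckpt_idx` (mirrors the loops above)."""
    for n in nodes:
        n.meta.setdefault('info', _InfoStub())
        n.meta['info'].activation_checkpoint.append(ckpt_idx)

def pad_nested(nodes) -> None:
    """Pad shorter annotation lists with None so nested regions share one length."""
    longest = max(len(n.meta['info'].activation_checkpoint) for n in nodes)
    for n in nodes:
        ckpt = n.meta['info'].activation_checkpoint
        ckpt += [None] * (longest - len(ckpt))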

View File

@@ -68,7 +68,7 @@ class MetaInfoProp:
graph_info = GraphInfo()
out = _normalize_tuple(getattr(node, '_meta_data', None))
graph_info.fwd_out = list(out) if out[0] is not None else []
node.meta = {**asdict(graph_info)}
node.meta.update(asdict(graph_info))
@compatibility(is_backward_compatible=False)
def get_attr_handler(self, node: Node) -> None:
@@ -76,7 +76,7 @@ class MetaInfoProp:
Handle the get_attr node.
"""
graph_info = GraphInfo()
node.meta = {**asdict(graph_info)}
node.meta.update(asdict(graph_info))
@compatibility(is_backward_compatible=False)
def output_handler(self, node: Node) -> None:
@@ -89,15 +89,36 @@ class MetaInfoProp:
if par.meta:
output_tensors += par.meta["fwd_out"]
graph_info.fwd_in = output_tensors
node.meta = {**asdict(graph_info)}
node.meta.update(asdict(graph_info))
@compatibility(is_backward_compatible=False)
def node_handler(self, node: Node) -> None:
"""
Handle other kind of nodes
"""
assert hasattr(node, 'best_strategy_info'), f"Cannot find best_strategy_info in node {node}, {node.op}"
graph_info = GraphInfo()
if not hasattr(node, 'best_strategy_info'):
# attach them to graph_info
graph_info.fwd_in = []
graph_info.fwd_tmp = []
graph_info.fwd_out = []
# fetch other memory information
graph_info.fwd_mem_tmp = 10
graph_info.fwd_mem_out = 10
graph_info.bwd_mem_tmp = 10
graph_info.bwd_mem_out = 10
# fetch flop information
# here we use fwd_time and bwd_time to deal with the case that
# communication cost is a float
graph_info.fwd_time = 10
graph_info.bwd_time = 10
node.meta.update(asdict(graph_info))
# print(node.name, [isinstance(arg, torch.Tensor) for arg in node.args], isinstance(node._meta_data, torch.Tensor))
return
assert hasattr(node, 'best_strategy_info'), f"Cannot find best_strategy_info in node {node}, {node.op}"
meta_info = node.best_strategy_info
meta_info: ShardMetaInfo
@@ -124,6 +145,8 @@ class MetaInfoProp:
for par in node._input_nodes:
# set data_ptr for the input_tensor of current node from the output_tensor of its parent node
for tensor in par.meta.get("fwd_out", []):
if not isinstance(tensor, torch.Tensor):
continue
tensor: torch.Tensor
target_input_tensor = next(
(x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None)
@@ -161,5 +184,4 @@ class MetaInfoProp:
compute_cost = meta_info.compute_cost
graph_info.fwd_time = compute_cost.fwd
graph_info.bwd_time = compute_cost.bwd
node.meta = {**asdict(graph_info)}
node.meta.update(asdict(graph_info))
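The recurring change in this file, node.meta = {**asdict(graph_info)} becoming node.meta.update(asdict(graph_info)), matters because rebuilding the dict would discard entries written by earlier passes, such as the 'info' entry the checkpoint solver reads, whereas update() only overwrites the GraphInfo fields. A small self-contained illustration, with GraphInfo reduced to two fields:

from dataclasses import asdict, dataclass

@dataclass
class GraphInfo:    # simplified stand-in for the real GraphInfo
    fwd_time: float = 0.0
    bwd_time: float = 0.0

meta = {'info': 'written by an earlier pass'}
meta = {**asdict(GraphInfo())}    # old style: the 'info' entry is lost
assert 'info' not in meta

meta = {'info': 'written by an earlier pass'}
meta.update(asdict(GraphInfo()))    # new style: the 'info' entry survives
assert 'info' in meta and 'fwd_time' in meta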

View File

@@ -157,9 +157,22 @@ class NodeHandler(ABC):
Register different sharding strategies for the current node.
"""
strategy_generators = self.get_strategy_generator()
strategies_info = []
for generator in strategy_generators:
strategies = generator.generate()
for strategy in strategies:
shard_metainfo = ShardMetaInfo()
shard_metainfo.compute_cost = strategy.compute_cost
shard_metainfo.memory_cost = strategy.memory_cost
shard_metainfo.fwd_in = []
if isinstance(self.node._meta_data, torch.Tensor):
shard_metainfo.fwd_out = [self.node._meta_data]
else:
shard_metainfo.fwd_out = self.node._meta_data
shard_metainfo.fwd_buffer = []
strategies_info.append(shard_metainfo)
# postprocess a strategy
# postprocess can produce one strategy or multiple strategies
post_processed_strategies_map = map(self.post_process, strategies)
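The new block attaches a ShardMetaInfo record to every generated strategy, apparently so that the meta-info propagation pass can later read per-strategy compute and memory costs; note that a single-tensor output is wrapped in a list while tuple or list outputs are stored as-is. A hedged sketch of that construction, with a local dataclass standing in for the real ShardMetaInfo:

from dataclasses import dataclass, field
from typing import Any, List
import torch

@dataclass
class _ShardMetaInfoSketch:    # simplified stand-in for ShardMetaInfo
    compute_cost: Any = None
    memory_cost: Any = None
    fwd_in: List[Any] = field(default_factory=list)
    fwd_out: Any = None
    fwd_buffer: List[Any] = field(default_factory=list)

def metainfo_for(strategy, meta_data):
    """Build the per-strategy cost record (mirrors the loop body above)."""
    info = _ShardMetaInfoSketch(compute_cost=strategy.compute_cost,
                                memory_cost=strategy.memory_cost)
    info.fwd_out = [meta_data] if isinstance(meta_data, torch.Tensor) else meta_data
    return info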

View File

@@ -102,7 +102,6 @@ class StrategiesConstructor:
if _check_no_strategy_for_node(node):
self.no_strategy_nodes.append(node)
pass
# placeholder node
elif node.op == 'placeholder':
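The constructor now records nodes the sharding solver cannot cover instead of silently skipping them; the MetaInfoProp branch shown earlier, which assigns placeholder costs when best_strategy_info is missing, presumably handles these nodes downstream. A minimal sketch of that inferred division of labour (names simplified, behaviour not confirmed by this diff):

no_strategy_nodes = []

def register(node, has_strategy: bool) -> None:
    if not has_strategy:
        no_strategy_nodes.append(node)    # kept for later passes rather than dropped

def default_costs() -> dict:
    # mirrors the placeholder value (10) the new MetaInfoProp branch assigns
    return dict(fwd_time=10, bwd_time=10, fwd_mem_tmp=10, fwd_mem_out=10,
                bwd_mem_tmp=10, bwd_mem_out=10)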

View File

@@ -0,0 +1,174 @@
from time import time
import psutil
import pytest
import torch
import transformers
try:
from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen
NON_CODEGEN = False
except:
NON_CODEGEN = True
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.checkpoint.ckpt_solver_rotor import CheckpointSolverRotor
from colossalai.auto_parallel.passes.comm_metainfo_pass import comm_metainfo_pass
from colossalai.auto_parallel.passes.meta_info_prop import MetaInfoProp
from colossalai.auto_parallel.tensor_shard.initialize import (
ModuleWrapper,
build_strategy_constructor,
solve_solution,
transform_to_sharded_model,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch, launch_from_torch
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
BATCH_SIZE = 16
SEQ_LENGTH = 1024
HIDDEN_DIM = 2048
NUM_HEADS = 16
NUM_LAYERS = 2
VOCAB_SIZE = 50257
NUM_STEPS = 10
def get_cpu_mem():
return psutil.Process().memory_info().rss / 1024**2
def get_gpu_mem():
return torch.cuda.memory_allocated() / 1024**2
def get_mem_info(prefix=''):
return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'
def data_gen_mlp(model_cls):
"""
Generate random input data for GPT benchmarking
"""
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
hidden_states = torch.rand((BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM), dtype=torch.float32)
if model_cls == GPT2MLP:
input_sample = (hidden_states.to('cuda'),)
meta_input_sample = {
'hidden_states': hidden_states.to('meta'),
}
elif model_cls in (GPT2Attention, GPT2Block):
attention_mask = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
input_sample = (
hidden_states.to('cuda'),
attention_mask.to('cuda'),
)
meta_input_sample = {
'hidden_states': hidden_states.to('meta'),
'attention_mask': attention_mask.to('meta'),
}
else:
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
input_sample = (
input_ids.to('cuda'),
attention_mask.to('cuda'),
)
meta_input_sample = {
'input_ids': input_ids.to('meta'),
'attention_mask': attention_mask.to('meta'),
}
return input_sample, meta_input_sample
def check_2stage_solver_on_gpt(rank, world_size, port, model_cls):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl')
logger = get_dist_logger()
config = transformers.GPT2Config(n_position=SEQ_LENGTH, n_layer=NUM_LAYERS, n_head=NUM_HEADS, n_embd=HIDDEN_DIM)
if model_cls == GPT2MLP:
model = model_cls(intermediate_size=4 * config.hidden_size, config=config).to('cuda')
else:
model = model_cls(config=config).to('cuda')
input_sample, meta_input_sample = data_gen_mlp(model_cls)
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
tracer = ColoTracer(bias_addition_split=True)
graph = tracer.trace(root=model, meta_args=meta_input_sample)
graph.set_codegen(ActivationCheckpointCodeGen())
gm = ColoGraphModule(model, graph, model.__class__.__name__)
shape_prop_pass(gm, *meta_input_sample.values())
gm.recompile()
strategies_constructor = build_strategy_constructor(graph, device_mesh, 'standard', 'replicated', 'standard')
solution = solve_solution(gm, strategies_constructor, memory_budget=-1)
gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_input_sample, solution, device_mesh,
strategies_constructor)
comm_metainfo_pass(gm, *sharding_spec_dicts)
MetaInfoProp(gm).run()
gm = ModuleWrapper(gm, *sharding_spec_dicts)
ckpt_solver = CheckpointSolverRotor(gm.module.graph, 8 * 1024**3)
gm.module.graph = ckpt_solver.solve()
ckpt_solver.print_sequence()
gm.module.recompile()
print(gm.module)
logger.info("*******************strategy selected*******************", ranks=[0])
strategies_list = solution
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
for index, node in enumerate(nodes):
logger.info(node.name, ranks=[0])
logger.info(node.strategies_vector[strategies_list[index]].name, ranks=[0])
optimizer = torch.optim.Adam(gm.parameters(), lr=0.01)
logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
torch.cuda.synchronize()
gm.train()
for n in range(10):
# we just use randomly generated data here
input_sample, _ = data_gen_mlp(model_cls)
mem_stamp0 = torch.cuda.memory_allocated(device='cuda:0') / 1024**2
optimizer.zero_grad()
start = time()
loss = gm(*input_sample)
loss.sum().backward()
optimizer.step()
# prof.step()
torch.cuda.synchronize()
step_time = time() - start
logger.info(f"===============Round {n}===============", ranks=[0])
logger.info(
f"Peak Memory: {torch.cuda.max_memory_allocated(device='cuda:0') / 1024**2 - mem_stamp0} MB, Step Time: {step_time:.3f}s",
ranks=[0])
torch.cuda.synchronize()
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.skipif(NON_CODEGEN, reason='codegen is not available')
@pytest.mark.dist
@rerun_if_address_is_in_use()
@parameterize('model_cls', [GPT2MLP, GPT2Attention, GPT2Block, GPT2Model])
def test_2stage_solver_on_gpt(model_cls):
spawn(
check_2stage_solver_on_gpt,
4,
model_cls=model_cls,
)
if __name__ == '__main__':
test_2stage_solver_on_gpt()
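Taken together, the test exercises both stages named in the commit title: the intra-op sharding solver first, then the rotor activation-checkpoint solver on the already-sharded graph. A condensed outline of that flow, using only names imported above (the interpretation of the numeric arguments in the comments is inferred, not confirmed):

# stage 1: sharding — build strategies, solve, and rewrite the graph module
strategies_constructor = build_strategy_constructor(graph, device_mesh, 'standard', 'replicated', 'standard')
solution = solve_solution(gm, strategies_constructor, memory_budget=-1)    # -1: presumably no budget limit
gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_input_sample, solution,
                                                     device_mesh, strategies_constructor)

# propagate communication and memory metadata so the checkpoint solver has costs to work with
comm_metainfo_pass(gm, *sharding_spec_dicts)
MetaInfoProp(gm).run()
gm = ModuleWrapper(gm, *sharding_spec_dicts)

# stage 2: activation checkpointing — the second argument is presumably a byte budget (8 GiB)
ckpt_solver = CheckpointSolverRotor(gm.module.graph, 8 * 1024**3)
gm.module.graph = ckpt_solver.solve()
gm.module.recompile()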

View File

@@ -0,0 +1,138 @@
from time import time
import psutil
import pytest
import torch
import torchvision.models as tm
try:
from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen
NON_CODEGEN = False
except:
NON_CODEGEN = True
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.checkpoint.ckpt_solver_rotor import CheckpointSolverRotor
from colossalai.auto_parallel.passes.comm_metainfo_pass import comm_metainfo_pass
from colossalai.auto_parallel.passes.meta_info_prop import MetaInfoProp
from colossalai.auto_parallel.tensor_shard.initialize import (
ModuleWrapper,
build_strategy_constructor,
solve_solution,
transform_to_sharded_model,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch, launch_from_torch
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
BATCH_SIZE = 256
def get_cpu_mem():
return psutil.Process().memory_info().rss / 1024**2
def get_gpu_mem():
return torch.cuda.memory_allocated() / 1024**2
def get_mem_info(prefix=''):
return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'
def data_gen_resnet(batch_size, shape):
"""
Generate random data for resnet benchmarking
"""
data = torch.empty(batch_size, *shape, device=torch.cuda.current_device())
label = torch.empty(batch_size, dtype=torch.long, device=torch.cuda.current_device()).random_(1000)
return data, label
def check_2stage_solver_on_resnet(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
logger = get_dist_logger()
model = tm.resnet50().cuda()
meta_input_sample = {
'x': torch.randn(BATCH_SIZE * 4, 3, 224, 224, device='meta'),
}
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
tracer = ColoTracer(bias_addition_split=True, trace_act_ckpt=True)
graph = tracer.trace(root=model, meta_args=meta_input_sample)
graph.set_codegen(ActivationCheckpointCodeGen())
gm = ColoGraphModule(model, graph, model.__class__.__name__)
shape_prop_pass(gm, *meta_input_sample.values())
gm.recompile()
strategies_constructor = build_strategy_constructor(graph, device_mesh, 'standard', 'replicated', 'standard')
solution = solve_solution(gm, strategies_constructor, memory_budget=-1)
gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_input_sample, solution, device_mesh,
strategies_constructor)
comm_metainfo_pass(gm, *sharding_spec_dicts)
MetaInfoProp(gm).run()
gm = ModuleWrapper(gm, *sharding_spec_dicts)
ckpt_solver = CheckpointSolverRotor(gm.module.graph, 8 * 1024**3)
gm.module.graph = ckpt_solver.solve()
ckpt_solver.print_sequence()
logger.info("*******************strategy selected*******************", ranks=[0])
strategies_list = solution
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
for index, node in enumerate(nodes):
logger.info(node.name, ranks=[0])
logger.info(node.strategies_vector[strategies_list[index]].name, ranks=[0])
# build criterion
criterion = torch.nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.Adam(gm.parameters(), lr=0.01)
logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
torch.cuda.synchronize()
model.train()
for n in range(10):
# we just use randomly generated data here
data, label = data_gen_resnet(BATCH_SIZE * 4, (3, 224, 224))
mem_stamp0 = torch.cuda.memory_allocated(device='cuda:0') / 1024**2
optimizer.zero_grad()
start = time()
outputs = gm(data)
loss = criterion(outputs, label)
loss.backward()
optimizer.step()
torch.cuda.synchronize()
step_time = time() - start
logger.info(f"===============Round {n}===============", ranks=[0])
logger.info(
f"Peak Memory: {torch.cuda.max_memory_allocated(device='cuda:0') / 1024**2 - mem_stamp0} MB, Step Time: {step_time:.3f}s",
ranks=[0])
torch.cuda.synchronize()
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.skipif(NON_CODEGEN, reason='codegen is not available')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_2stage_solver_on_resnet():
spawn(
check_2stage_solver_on_resnet,
4,
)
if __name__ == '__main__':
test_2stage_solver_on_resnet()
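This file mirrors the GPT test with a few visible differences: the tracer additionally passes trace_act_ckpt=True, the loss goes through an explicit CrossEntropyLoss criterion, and the training loop calls model.train() where the GPT test calls gm.train(). For reference, the tracing step with the inferred purpose of the extra flag marked as an assumption:

# trace_act_ckpt=True is assumed (not confirmed here) to carry pre-existing activation-checkpoint
# regions through tracing before the rotor solver re-plans checkpointing on the sharded graph
tracer = ColoTracer(bias_addition_split=True, trace_act_ckpt=True)
graph = tracer.trace(root=model, meta_args=meta_input_sample)
graph.set_codegen(ActivationCheckpointCodeGen())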

View File

@@ -162,7 +162,7 @@ def check_linear_module_handler(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_linear_handler():
spawn(check_linear_module_handler)
spawn(check_linear_module_handler, 4)
if __name__ == '__main__':

View File

@@ -151,7 +151,7 @@ def check_linear_module_handler(rank, world_size, port, bias):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_linear_handler(bias=True):
spawn(check_linear_module_handler, bias=bias)
spawn(check_linear_module_handler, 4, bias=bias)
if __name__ == '__main__':

View File

@@ -100,7 +100,7 @@ def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port):
# @parameterize('getitem_index', [slice(0, 2), (slice(None), slice(None))])
@parameterize('getitem_index', [1, (1, 4), slice(0, 2), (slice(None), slice(None))])
def test_getitem_from_tensor_handler(getitem_index):
spawn(check_getitem_from_tensor_handler, 4)
spawn(check_getitem_from_tensor_handler, 4, getitem_index=getitem_index)
class GetItemFromTupleModel(nn.Module):