mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-04-28 11:45:23 +00:00
[autoparallel] integrate 2-stage solver (#3476)
* integrate 2-stage solver * [autoparallel] integrate 2-stage solver * polish
This commit is contained in:
parent
190a6ea9c2
commit
5458da5c3c
@ -94,6 +94,8 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0):
|
||||
current_region = None
|
||||
|
||||
for idx, node in enumerate(node_list):
|
||||
if 'info' not in node.meta:
|
||||
continue
|
||||
if len(node.meta['info'].activation_checkpoint) > ckpt_level:
|
||||
act_ckpt_label = node.meta['info'].activation_checkpoint[ckpt_level]
|
||||
|
||||
|
@ -5,12 +5,12 @@ from typing import Any, List
|
||||
import torch
|
||||
from torch.fx import Graph, Node
|
||||
|
||||
from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen
|
||||
from colossalai.auto_parallel.passes.runtime_apply_pass import (
|
||||
runtime_apply,
|
||||
runtime_apply_for_iterable_object,
|
||||
runtime_comm_spec_apply,
|
||||
)
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
|
||||
|
||||
__all___ = ['CheckpointSolverBase']
|
||||
|
||||
|
@ -369,7 +369,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
||||
in_ckpt = False
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
n.meta['activation_checkpoint'] = [ckpt_idx]
|
||||
n.meta['info'].activation_checkpoint = [ckpt_idx]
|
||||
|
||||
ckpt_idx += 1
|
||||
ckpt_region = []
|
||||
@ -377,7 +377,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
||||
elif isinstance(op, ForwardCheck):
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
n.meta['activation_checkpoint'] = [ckpt_idx]
|
||||
n.meta['info'].activation_checkpoint = [ckpt_idx]
|
||||
|
||||
ckpt_idx += 1
|
||||
ckpt_region = [idx]
|
||||
@ -397,7 +397,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
||||
elif isinstance(op, ForwardEnable):
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
n.meta['activation_checkpoint'].append(ckpt_idx)
|
||||
n.meta['info'].activation_checkpoint.append(ckpt_idx)
|
||||
|
||||
ckpt_idx += 1
|
||||
ckpt_region = []
|
||||
@ -405,7 +405,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
||||
elif isinstance(op, ForwardCheck):
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
n.meta['activation_checkpoint'].append(ckpt_idx)
|
||||
n.meta['info'].activation_checkpoint.append(ckpt_idx)
|
||||
|
||||
ckpt_idx += 1
|
||||
ckpt_region = [op.index]
|
||||
@ -413,7 +413,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
||||
elif isinstance(op, Backward):
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
n.meta['activation_checkpoint'].append(ckpt_idx)
|
||||
n.meta['info'].activation_checkpoint.append(ckpt_idx)
|
||||
|
||||
in_recompute = False
|
||||
|
||||
@ -433,7 +433,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
||||
ckpt_regions = _find_nested_ckpt_regions(op_list)
|
||||
for (start_idx, end_idx) in ckpt_regions:
|
||||
nested_length = max(
|
||||
len(op_list[idx].meta['activation_checkpoint']) for idx in range(start_idx, end_idx + 1))
|
||||
len(op_list[idx].meta['info'].activation_checkpoint) for idx in range(start_idx, end_idx + 1))
|
||||
for idx in range(start_idx, end_idx + 1):
|
||||
op_list[idx].meta['activation_checkpoint'] += [None] * (nested_length -
|
||||
len(op_list[idx].meta['activation_checkpoint']))
|
||||
op_list[idx].meta['info'].activation_checkpoint += [None] * (
|
||||
nested_length - len(op_list[idx].meta['info'].activation_checkpoint))
|
||||
|
@ -68,7 +68,7 @@ class MetaInfoProp:
|
||||
graph_info = GraphInfo()
|
||||
out = _normalize_tuple(getattr(node, '_meta_data', None))
|
||||
graph_info.fwd_out = list(out) if out[0] is not None else []
|
||||
node.meta = {**asdict(graph_info)}
|
||||
node.meta.update(asdict(graph_info))
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def get_attr_handler(self, node: Node) -> None:
|
||||
@ -76,7 +76,7 @@ class MetaInfoProp:
|
||||
Handle the get_attr node.
|
||||
"""
|
||||
graph_info = GraphInfo()
|
||||
node.meta = {**asdict(graph_info)}
|
||||
node.meta.update(asdict(graph_info))
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def output_handler(self, node: Node) -> None:
|
||||
@ -89,15 +89,36 @@ class MetaInfoProp:
|
||||
if par.meta:
|
||||
output_tensors += par.meta["fwd_out"]
|
||||
graph_info.fwd_in = output_tensors
|
||||
node.meta = {**asdict(graph_info)}
|
||||
node.meta.update(asdict(graph_info))
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def node_handler(self, node: Node) -> None:
|
||||
"""
|
||||
Handle other kind of nodes
|
||||
"""
|
||||
assert hasattr(node, 'best_strategy_info'), f"Cannot find best_strategy_info in node {node}, {node.op}"
|
||||
graph_info = GraphInfo()
|
||||
if not hasattr(node, 'best_strategy_info'):
|
||||
# attach them to graph_info
|
||||
graph_info.fwd_in = []
|
||||
graph_info.fwd_tmp = []
|
||||
graph_info.fwd_out = []
|
||||
|
||||
# fetch other memory informations
|
||||
graph_info.fwd_mem_tmp = 10
|
||||
graph_info.fwd_mem_out = 10
|
||||
graph_info.bwd_mem_tmp = 10
|
||||
graph_info.bwd_mem_out = 10
|
||||
|
||||
# fetch flop information
|
||||
# here we use fwd_time and bwd_time to deal with the case that
|
||||
# communication cost is a float
|
||||
graph_info.fwd_time = 10
|
||||
graph_info.bwd_time = 10
|
||||
node.meta.update(asdict(graph_info))
|
||||
# print(node.name, [isinstance(arg, torch.Tensor) for arg in node.args], isinstance(node._meta_data, torch.Tensor))
|
||||
return
|
||||
assert hasattr(node, 'best_strategy_info'), f"Cannot find best_strategy_info in node {node}, {node.op}"
|
||||
|
||||
meta_info = node.best_strategy_info
|
||||
meta_info: ShardMetaInfo
|
||||
|
||||
@ -124,6 +145,8 @@ class MetaInfoProp:
|
||||
for par in node._input_nodes:
|
||||
# set data_ptr for the input_tensor of current node from the output_tensor of its parent node
|
||||
for tensor in par.meta.get("fwd_out", []):
|
||||
if not isinstance(tensor, torch.Tensor):
|
||||
continue
|
||||
tensor: torch.Tensor
|
||||
target_input_tensor = next(
|
||||
(x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None)
|
||||
@ -161,5 +184,4 @@ class MetaInfoProp:
|
||||
compute_cost = meta_info.compute_cost
|
||||
graph_info.fwd_time = compute_cost.fwd
|
||||
graph_info.bwd_time = compute_cost.bwd
|
||||
|
||||
node.meta = {**asdict(graph_info)}
|
||||
node.meta.update(asdict(graph_info))
|
||||
|
@ -157,9 +157,22 @@ class NodeHandler(ABC):
|
||||
Register different sharding strategies for the current node.
|
||||
"""
|
||||
strategy_generators = self.get_strategy_generator()
|
||||
strategies_info = []
|
||||
for generator in strategy_generators:
|
||||
strategies = generator.generate()
|
||||
|
||||
for strategy in strategies:
|
||||
shard_metainfo = ShardMetaInfo()
|
||||
shard_metainfo.compute_cost = strategy.compute_cost
|
||||
shard_metainfo.memory_cost = strategy.memory_cost
|
||||
shard_metainfo.fwd_in = []
|
||||
if isinstance(self.node._meta_data, torch.Tensor):
|
||||
shard_metainfo.fwd_out = [self.node._meta_data]
|
||||
else:
|
||||
shard_metainfo.fwd_out = self.node._meta_data
|
||||
shard_metainfo.fwd_buffer = []
|
||||
strategies_info.append(shard_metainfo)
|
||||
|
||||
# postprocess a strategy
|
||||
# postprocess can produce one strategy or multiple strategies
|
||||
post_processed_strategies_map = map(self.post_process, strategies)
|
||||
|
@ -102,7 +102,6 @@ class StrategiesConstructor:
|
||||
|
||||
if _check_no_strategy_for_node(node):
|
||||
self.no_strategy_nodes.append(node)
|
||||
pass
|
||||
|
||||
# placeholder node
|
||||
elif node.op == 'placeholder':
|
||||
|
@ -0,0 +1,174 @@
|
||||
from time import time
|
||||
|
||||
import psutil
|
||||
import pytest
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
try:
|
||||
from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen
|
||||
NON_CODEGEN = False
|
||||
except:
|
||||
NON_CODEGEN = True
|
||||
|
||||
from colossalai._analyzer.fx.graph_module import ColoGraphModule
|
||||
from colossalai._analyzer.fx.passes import shape_prop_pass
|
||||
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.auto_parallel.checkpoint.ckpt_solver_rotor import CheckpointSolverRotor
|
||||
from colossalai.auto_parallel.passes.comm_metainfo_pass import comm_metainfo_pass
|
||||
from colossalai.auto_parallel.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.auto_parallel.tensor_shard.initialize import (
|
||||
ModuleWrapper,
|
||||
build_strategy_constructor,
|
||||
solve_solution,
|
||||
transform_to_sharded_model,
|
||||
)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.initialize import launch, launch_from_torch
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
|
||||
|
||||
BATCH_SIZE = 16
|
||||
SEQ_LENGTH = 1024
|
||||
HIDDEN_DIM = 2048
|
||||
NUM_HEADS = 16
|
||||
NUM_LAYERS = 2
|
||||
VOCAB_SIZE = 50257
|
||||
NUM_STEPS = 10
|
||||
|
||||
|
||||
def get_cpu_mem():
|
||||
return psutil.Process().memory_info().rss / 1024**2
|
||||
|
||||
|
||||
def get_gpu_mem():
|
||||
return torch.cuda.memory_allocated() / 1024**2
|
||||
|
||||
|
||||
def get_mem_info(prefix=''):
|
||||
return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'
|
||||
|
||||
|
||||
def data_gen_mlp(model_cls):
|
||||
"""
|
||||
Generate random data for resnet benchmarking
|
||||
"""
|
||||
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
|
||||
hidden_states = torch.rand((BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM), dtype=torch.float32)
|
||||
|
||||
if model_cls == GPT2MLP:
|
||||
input_sample = (hidden_states.to('cuda'),)
|
||||
meta_input_sample = {
|
||||
'hidden_states': hidden_states.to('meta'),
|
||||
}
|
||||
elif model_cls in (GPT2Attention, GPT2Block):
|
||||
attention_mask = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
|
||||
input_sample = (
|
||||
hidden_states.to('cuda'),
|
||||
attention_mask.to('cuda'),
|
||||
)
|
||||
meta_input_sample = {
|
||||
'hidden_states': hidden_states.to('meta'),
|
||||
'attention_mask': attention_mask.to('meta'),
|
||||
}
|
||||
else:
|
||||
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
|
||||
input_sample = (
|
||||
input_ids.to('cuda'),
|
||||
attention_mask.to('cuda'),
|
||||
)
|
||||
meta_input_sample = {
|
||||
'input_ids': input_ids.to('meta'),
|
||||
'attention_mask': attention_mask.to('meta'),
|
||||
}
|
||||
return input_sample, meta_input_sample
|
||||
|
||||
|
||||
def check_2stage_solver_on_gpt(rank, world_size, port, model_cls):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl')
|
||||
logger = get_dist_logger()
|
||||
config = transformers.GPT2Config(n_position=SEQ_LENGTH, n_layer=NUM_LAYERS, n_head=NUM_HEADS, n_embd=HIDDEN_DIM)
|
||||
if model_cls == GPT2MLP:
|
||||
model = model_cls(intermediate_size=4 * config.hidden_size, config=config).to('cuda')
|
||||
else:
|
||||
model = model_cls(config=config).to('cuda')
|
||||
|
||||
input_sample, meta_input_sample = data_gen_mlp(model_cls)
|
||||
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
# [[0, 1]
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
|
||||
tracer = ColoTracer(bias_addition_split=True)
|
||||
|
||||
graph = tracer.trace(root=model, meta_args=meta_input_sample)
|
||||
graph.set_codegen(ActivationCheckpointCodeGen())
|
||||
gm = ColoGraphModule(model, graph, model.__class__.__name__)
|
||||
shape_prop_pass(gm, *meta_input_sample.values())
|
||||
gm.recompile()
|
||||
|
||||
strategies_constructor = build_strategy_constructor(graph, device_mesh, 'standard', 'replicated', 'standard')
|
||||
solution = solve_solution(gm, strategies_constructor, memory_budget=-1)
|
||||
gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_input_sample, solution, device_mesh,
|
||||
strategies_constructor)
|
||||
comm_metainfo_pass(gm, *sharding_spec_dicts)
|
||||
|
||||
MetaInfoProp(gm).run()
|
||||
|
||||
gm = ModuleWrapper(gm, *sharding_spec_dicts)
|
||||
ckpt_solver = CheckpointSolverRotor(gm.module.graph, 8 * 1024**3)
|
||||
gm.module.graph = ckpt_solver.solve()
|
||||
ckpt_solver.print_sequence()
|
||||
gm.module.recompile()
|
||||
print(gm.module)
|
||||
logger.info("*******************strategy selected*******************", ranks=[0])
|
||||
strategies_list = solution
|
||||
|
||||
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
|
||||
for index, node in enumerate(nodes):
|
||||
logger.info(node.name, ranks=[0])
|
||||
logger.info(node.strategies_vector[strategies_list[index]].name, ranks=[0])
|
||||
|
||||
optimizer = torch.optim.Adam(gm.parameters(), lr=0.01)
|
||||
logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
|
||||
torch.cuda.synchronize()
|
||||
gm.train()
|
||||
|
||||
for n in range(10):
|
||||
# we just use randomly generated data here
|
||||
input_sample, _ = data_gen_mlp(model_cls)
|
||||
mem_stamp0 = torch.cuda.memory_allocated(device='cuda:0') / 1024**2
|
||||
optimizer.zero_grad()
|
||||
start = time()
|
||||
loss = gm(*input_sample)
|
||||
loss.sum().backward()
|
||||
optimizer.step()
|
||||
# prof.step()
|
||||
torch.cuda.synchronize()
|
||||
step_time = time() - start
|
||||
logger.info(f"===============Round {n}===============", ranks=[0])
|
||||
logger.info(
|
||||
f"Peak Memory: {torch.cuda.max_memory_allocated(device='cuda:0') / 1024**2 - mem_stamp0} MB, Step Time: {step_time:.3f}s",
|
||||
ranks=[0])
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@pytest.mark.skipif(NON_CODEGEN, reason='codegen is not available')
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@parameterize('model_cls', [GPT2MLP, GPT2Attention, GPT2Block, GPT2Model])
|
||||
def test_2stage_solver_on_gpt(model_cls):
|
||||
spawn(
|
||||
check_2stage_solver_on_gpt,
|
||||
4,
|
||||
model_cls=model_cls,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_2stage_solver_on_gpt()
|
@ -0,0 +1,138 @@
|
||||
from time import time
|
||||
|
||||
import psutil
|
||||
import pytest
|
||||
import torch
|
||||
import torchvision.models as tm
|
||||
|
||||
try:
|
||||
from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen
|
||||
NON_CODEGEN = False
|
||||
except:
|
||||
NON_CODEGEN = True
|
||||
|
||||
from colossalai._analyzer.fx.graph_module import ColoGraphModule
|
||||
from colossalai._analyzer.fx.passes import shape_prop_pass
|
||||
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.auto_parallel.checkpoint.ckpt_solver_rotor import CheckpointSolverRotor
|
||||
from colossalai.auto_parallel.passes.comm_metainfo_pass import comm_metainfo_pass
|
||||
from colossalai.auto_parallel.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.auto_parallel.tensor_shard.initialize import (
|
||||
ModuleWrapper,
|
||||
build_strategy_constructor,
|
||||
solve_solution,
|
||||
transform_to_sharded_model,
|
||||
)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.initialize import launch, launch_from_torch
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
|
||||
BATCH_SIZE = 256
|
||||
|
||||
|
||||
def get_cpu_mem():
|
||||
return psutil.Process().memory_info().rss / 1024**2
|
||||
|
||||
|
||||
def get_gpu_mem():
|
||||
return torch.cuda.memory_allocated() / 1024**2
|
||||
|
||||
|
||||
def get_mem_info(prefix=''):
|
||||
return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'
|
||||
|
||||
|
||||
def data_gen_resnet(batch_size, shape):
|
||||
"""
|
||||
Generate random data for resnet benchmarking
|
||||
"""
|
||||
data = torch.empty(batch_size, *shape, device=torch.cuda.current_device())
|
||||
label = torch.empty(batch_size, dtype=torch.long, device=torch.cuda.current_device()).random_(1000)
|
||||
return data, label
|
||||
|
||||
|
||||
def check_2stage_solver_on_resnet(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
logger = get_dist_logger()
|
||||
model = tm.resnet50().cuda()
|
||||
|
||||
meta_input_sample = {
|
||||
'x': torch.randn(BATCH_SIZE * 4, 3, 224, 224, device='meta'),
|
||||
}
|
||||
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
# [[0, 1]
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
|
||||
tracer = ColoTracer(bias_addition_split=True, trace_act_ckpt=True)
|
||||
|
||||
graph = tracer.trace(root=model, meta_args=meta_input_sample)
|
||||
graph.set_codegen(ActivationCheckpointCodeGen())
|
||||
gm = ColoGraphModule(model, graph, model.__class__.__name__)
|
||||
shape_prop_pass(gm, *meta_input_sample.values())
|
||||
gm.recompile()
|
||||
|
||||
strategies_constructor = build_strategy_constructor(graph, device_mesh, 'standard', 'replicated', 'standard')
|
||||
solution = solve_solution(gm, strategies_constructor, memory_budget=-1)
|
||||
gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_input_sample, solution, device_mesh,
|
||||
strategies_constructor)
|
||||
comm_metainfo_pass(gm, *sharding_spec_dicts)
|
||||
MetaInfoProp(gm).run()
|
||||
gm = ModuleWrapper(gm, *sharding_spec_dicts)
|
||||
|
||||
ckpt_solver = CheckpointSolverRotor(gm.module.graph, 8 * 1024**3)
|
||||
gm.module.graph = ckpt_solver.solve()
|
||||
ckpt_solver.print_sequence()
|
||||
|
||||
logger.info("*******************strategy selected*******************", ranks=[0])
|
||||
strategies_list = solution
|
||||
|
||||
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
|
||||
for index, node in enumerate(nodes):
|
||||
logger.info(node.name, ranks=[0])
|
||||
logger.info(node.strategies_vector[strategies_list[index]].name, ranks=[0])
|
||||
|
||||
# build criterion
|
||||
criterion = torch.nn.CrossEntropyLoss().cuda()
|
||||
|
||||
optimizer = torch.optim.Adam(gm.parameters(), lr=0.01)
|
||||
logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
|
||||
torch.cuda.synchronize()
|
||||
model.train()
|
||||
for n in range(10):
|
||||
# we just use randomly generated data here
|
||||
data, label = data_gen_resnet(BATCH_SIZE * 4, (3, 224, 224))
|
||||
mem_stamp0 = torch.cuda.memory_allocated(device='cuda:0') / 1024**2
|
||||
optimizer.zero_grad()
|
||||
start = time()
|
||||
outputs = gm(data)
|
||||
loss = criterion(outputs, label)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
torch.cuda.synchronize()
|
||||
step_time = time() - start
|
||||
logger.info(f"===============Round {n}===============", ranks=[0])
|
||||
logger.info(
|
||||
f"Peak Memory: {torch.cuda.max_memory_allocated(device='cuda:0') / 1024**2 - mem_stamp0} MB, Step Time: {step_time:.3f}s",
|
||||
ranks=[0])
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@pytest.mark.skipif(NON_CODEGEN, reason='codegen is not available')
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_2stage_solver_on_resnet():
|
||||
spawn(
|
||||
check_2stage_solver_on_resnet,
|
||||
4,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_2stage_solver_on_resnet()
|
@ -162,7 +162,7 @@ def check_linear_module_handler(rank, world_size, port):
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_linear_handler():
|
||||
spawn(check_linear_module_handler)
|
||||
spawn(check_linear_module_handler, 4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -151,7 +151,7 @@ def check_linear_module_handler(rank, world_size, port, bias):
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_linear_handler(bias=True):
|
||||
spawn(check_linear_module_handler, bias=bias)
|
||||
spawn(check_linear_module_handler, 4, bias=bias)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -100,7 +100,7 @@ def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port):
|
||||
# @parameterize('getitem_index', [slice(0, 2), (slice(None), slice(None))])
|
||||
@parameterize('getitem_index', [1, (1, 4), slice(0, 2), (slice(None), slice(None))])
|
||||
def test_getitem_from_tensor_handler(getitem_index):
|
||||
spawn(check_getitem_from_tensor_handler, 4)
|
||||
spawn(check_getitem_from_tensor_handler, 4, getitem_index=getitem_index)
|
||||
|
||||
|
||||
class GetItemFromTupleModel(nn.Module):
|
||||
|
Loading…
Reference in New Issue
Block a user