mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-23 06:00:44 +00:00
[examples] update autoparallel tutorial demo (#2449)
* [examples] update autoparallel tutorial demo * add test_ci.sh * polish * add conda yaml
This commit is contained in:
parent
9358262992
commit
c20529fe78
@ -4,23 +4,14 @@ from pathlib import Path
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from titans.utils import barrier_context
|
from titans.utils import barrier_context
|
||||||
from torch.fx import GraphModule
|
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from torchvision.datasets import CIFAR10
|
from torchvision.datasets import CIFAR10
|
||||||
from torchvision.models import resnet50
|
from torchvision.models import resnet50
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
|
from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize
|
||||||
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
|
|
||||||
from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
|
|
||||||
from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
|
|
||||||
from colossalai.auto_parallel.tensor_shard.solver.options import DataloaderOption, SolverOptions
|
|
||||||
from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
|
|
||||||
from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
|
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.device.device_mesh import DeviceMesh
|
|
||||||
from colossalai.fx.tracer.tracer import ColoTracer
|
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
from colossalai.nn.lr_scheduler import CosineAnnealingLR
|
from colossalai.nn.lr_scheduler import CosineAnnealingLR
|
||||||
from colossalai.utils import get_dataloader
|
from colossalai.utils import get_dataloader
|
||||||
@ -28,12 +19,6 @@ from colossalai.utils import get_dataloader
|
|||||||
DATA_ROOT = Path(os.environ.get('DATA', '../data')).absolute()
|
DATA_ROOT = Path(os.environ.get('DATA', '../data')).absolute()
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument('-s', '--synthetic', action="store_true", help="use synthetic dataset instead of CIFAR10")
|
|
||||||
return parser.parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
def synthesize_data():
|
def synthesize_data():
|
||||||
img = torch.rand(gpc.config.BATCH_SIZE, 3, 32, 32)
|
img = torch.rand(gpc.config.BATCH_SIZE, 3, 32, 32)
|
||||||
label = torch.randint(low=0, high=10, size=(gpc.config.BATCH_SIZE,))
|
label = torch.randint(low=0, high=10, size=(gpc.config.BATCH_SIZE,))
|
||||||
@ -41,82 +26,15 @@ def synthesize_data():
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
|
||||||
colossalai.launch_from_torch(config='./config.py')
|
colossalai.launch_from_torch(config='./config.py')
|
||||||
|
|
||||||
logger = get_dist_logger()
|
logger = get_dist_logger()
|
||||||
|
|
||||||
if not args.synthetic:
|
|
||||||
with barrier_context():
|
|
||||||
# build dataloaders
|
|
||||||
train_dataset = CIFAR10(root=DATA_ROOT,
|
|
||||||
download=True,
|
|
||||||
transform=transforms.Compose([
|
|
||||||
transforms.RandomCrop(size=32, padding=4),
|
|
||||||
transforms.RandomHorizontalFlip(),
|
|
||||||
transforms.ToTensor(),
|
|
||||||
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
|
|
||||||
std=[0.2023, 0.1994, 0.2010]),
|
|
||||||
]))
|
|
||||||
|
|
||||||
test_dataset = CIFAR10(root=DATA_ROOT,
|
|
||||||
train=False,
|
|
||||||
transform=transforms.Compose([
|
|
||||||
transforms.ToTensor(),
|
|
||||||
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
|
|
||||||
]))
|
|
||||||
|
|
||||||
train_dataloader = get_dataloader(
|
|
||||||
dataset=train_dataset,
|
|
||||||
add_sampler=True,
|
|
||||||
shuffle=True,
|
|
||||||
batch_size=gpc.config.BATCH_SIZE,
|
|
||||||
pin_memory=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
test_dataloader = get_dataloader(
|
|
||||||
dataset=test_dataset,
|
|
||||||
add_sampler=True,
|
|
||||||
batch_size=gpc.config.BATCH_SIZE,
|
|
||||||
pin_memory=True,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
train_dataloader, test_dataloader = None, None
|
|
||||||
|
|
||||||
# initialize device mesh
|
|
||||||
physical_mesh_id = torch.arange(0, 4)
|
|
||||||
mesh_shape = (2, 2)
|
|
||||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
|
||||||
|
|
||||||
# trace the model with meta data
|
# trace the model with meta data
|
||||||
tracer = ColoTracer()
|
|
||||||
model = resnet50(num_classes=10).cuda()
|
model = resnet50(num_classes=10).cuda()
|
||||||
input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')}
|
input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')}
|
||||||
graph = tracer.trace(root=model, meta_args=input_sample)
|
|
||||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
|
||||||
gm.recompile()
|
|
||||||
|
|
||||||
# prepare info for solver
|
|
||||||
solver_options = SolverOptions(dataloader_option=DataloaderOption.DISTRIBUTED)
|
|
||||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
|
||||||
strategies_constructor.build_strategies_and_cost()
|
|
||||||
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
|
|
||||||
cost_graph.simplify_graph()
|
|
||||||
graph_analyser = GraphAnalyser(gm)
|
|
||||||
|
|
||||||
# solve the solution
|
|
||||||
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
|
|
||||||
ret = solver.call_solver_serialized_args()
|
|
||||||
solution = list(ret[0])
|
|
||||||
if gpc.get_global_rank() == 0:
|
|
||||||
for index, node in enumerate(graph.nodes):
|
|
||||||
print(node.name, node.strategies_vector[solution[index]].name)
|
|
||||||
|
|
||||||
# process the graph for distributed training ability
|
|
||||||
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
|
|
||||||
gm = runtime_apply_pass(gm)
|
|
||||||
gm.recompile()
|
|
||||||
|
|
||||||
|
model = autoparallelize(model, input_sample)
|
||||||
# build criterion
|
# build criterion
|
||||||
criterion = torch.nn.CrossEntropyLoss()
|
criterion = torch.nn.CrossEntropyLoss()
|
||||||
|
|
||||||
@ -127,65 +45,47 @@ def main():
|
|||||||
lr_scheduler = CosineAnnealingLR(optimizer, total_steps=gpc.config.NUM_EPOCHS)
|
lr_scheduler = CosineAnnealingLR(optimizer, total_steps=gpc.config.NUM_EPOCHS)
|
||||||
|
|
||||||
for epoch in range(gpc.config.NUM_EPOCHS):
|
for epoch in range(gpc.config.NUM_EPOCHS):
|
||||||
gm.train()
|
model.train()
|
||||||
|
|
||||||
if args.synthetic:
|
# if we use synthetic data
|
||||||
# if we use synthetic data
|
# we assume it only has 30 steps per epoch
|
||||||
# we assume it only has 30 steps per epoch
|
num_steps = range(30)
|
||||||
num_steps = range(30)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# we use the actual number of steps for training
|
|
||||||
num_steps = range(len(train_dataloader))
|
|
||||||
data_iter = iter(train_dataloader)
|
|
||||||
progress = tqdm(num_steps)
|
progress = tqdm(num_steps)
|
||||||
|
|
||||||
for _ in progress:
|
for _ in progress:
|
||||||
if args.synthetic:
|
# generate fake data
|
||||||
# generate fake data
|
img, label = synthesize_data()
|
||||||
img, label = synthesize_data()
|
|
||||||
else:
|
|
||||||
# get the real data
|
|
||||||
img, label = next(data_iter)
|
|
||||||
|
|
||||||
img = img.cuda()
|
img = img.cuda()
|
||||||
label = label.cuda()
|
label = label.cuda()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
output = gm(img, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
|
output = model(img)
|
||||||
train_loss = criterion(output, label)
|
train_loss = criterion(output, label)
|
||||||
train_loss.backward(train_loss)
|
train_loss.backward(train_loss)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
lr_scheduler.step()
|
lr_scheduler.step()
|
||||||
|
|
||||||
# run evaluation
|
# run evaluation
|
||||||
gm.eval()
|
model.eval()
|
||||||
correct = 0
|
correct = 0
|
||||||
total = 0
|
total = 0
|
||||||
|
|
||||||
if args.synthetic:
|
# if we use synthetic data
|
||||||
# if we use synthetic data
|
# we assume it only has 10 steps for evaluation
|
||||||
# we assume it only has 10 steps for evaluation
|
num_steps = range(30)
|
||||||
num_steps = range(30)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# we use the actual number of steps for training
|
|
||||||
num_steps = range(len(test_dataloader))
|
|
||||||
data_iter = iter(test_dataloader)
|
|
||||||
progress = tqdm(num_steps)
|
progress = tqdm(num_steps)
|
||||||
|
|
||||||
for _ in progress:
|
for _ in progress:
|
||||||
if args.synthetic:
|
# generate fake data
|
||||||
# generate fake data
|
img, label = synthesize_data()
|
||||||
img, label = synthesize_data()
|
|
||||||
else:
|
|
||||||
# get the real data
|
|
||||||
img, label = next(data_iter)
|
|
||||||
|
|
||||||
img = img.cuda()
|
img = img.cuda()
|
||||||
label = label.cuda()
|
label = label.cuda()
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output = gm(img, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
|
output = model(img)
|
||||||
test_loss = criterion(output, label)
|
test_loss = criterion(output, label)
|
||||||
pred = torch.argmax(output, dim=-1)
|
pred = torch.argmax(output, dim=-1)
|
||||||
correct += torch.sum(pred == label)
|
correct += torch.sum(pred == label)
|
||||||
|
32
examples/tutorial/auto_parallel/environment.yaml
Normal file
32
examples/tutorial/auto_parallel/environment.yaml
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
name: auto
|
||||||
|
channels:
|
||||||
|
- pytorch
|
||||||
|
- conda-forge
|
||||||
|
- defaults
|
||||||
|
dependencies:
|
||||||
|
- _libgcc_mutex=0.1=conda_forge
|
||||||
|
- _openmp_mutex=4.5=2_kmp_llvm
|
||||||
|
- blas=1.0=mkl
|
||||||
|
- brotlipy=0.7.0=py38h27cfd23_1003
|
||||||
|
- bzip2=1.0.8=h7b6447c_0
|
||||||
|
- ca-certificates=2022.12.7=ha878542_0
|
||||||
|
- certifi=2022.12.7=pyhd8ed1ab_0
|
||||||
|
- cffi=1.15.1=py38h74dc2b5_0
|
||||||
|
- charset-normalizer=2.0.4=pyhd3eb1b0_0
|
||||||
|
- coin-or-cbc=2.10.8=h3786ebc_0
|
||||||
|
- coin-or-cgl=0.60.6=h6f57e76_2
|
||||||
|
- coin-or-clp=1.17.7=hc56784d_2
|
||||||
|
- coin-or-osi=0.108.7=h2720bb7_2
|
||||||
|
- coin-or-utils=2.11.6=h202d8b1_2
|
||||||
|
- python=3.8.13
|
||||||
|
- pip=22.2.2
|
||||||
|
- cudatoolkit=11.3
|
||||||
|
- pytorch=1.12.1
|
||||||
|
- torchvision=0.13.1
|
||||||
|
- numpy=1.23.1
|
||||||
|
- pip:
|
||||||
|
- titans
|
||||||
|
- torch==1.12.1
|
||||||
|
- pulp==2.7.0
|
||||||
|
- datasets
|
||||||
|
- colossalai
|
13
examples/tutorial/auto_parallel/setup.py
Normal file
13
examples/tutorial/auto_parallel/setup.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
from setuptools import find_packages, setup
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name='auto_parallel',
|
||||||
|
version='0.0.1',
|
||||||
|
description='',
|
||||||
|
packages=find_packages(),
|
||||||
|
install_requires=[
|
||||||
|
'torch',
|
||||||
|
'numpy',
|
||||||
|
'tqdm',
|
||||||
|
],
|
||||||
|
)
|
11
examples/tutorial/auto_parallel/test_ci.sh
Normal file
11
examples/tutorial/auto_parallel/test_ci.sh
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -euxo pipefail
|
||||||
|
|
||||||
|
conda init bash
|
||||||
|
conda env create -f environment.yaml
|
||||||
|
conda activate auto
|
||||||
|
cd ../../..
|
||||||
|
pip uninstall colossalai
|
||||||
|
pip install -v .
|
||||||
|
cd ./examples/tutorial/auto_parallel
|
||||||
|
colossalai run --nproc_per_node 4 auto_parallel_with_resnet.py -s
|
Loading…
Reference in New Issue
Block a user