mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-25 19:55:03 +00:00
Develop/experiments (#59)
* Add gradient accumulation, fix lr scheduler * fix FP16 optimizer and adapted torch amp with tensor parallel (#18) * fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes * fixed trainer * Revert "fixed trainer" This reverts commit2e0b0b7699
. * improved consistency between trainer, engine and schedule (#23) Co-authored-by: 1SAA <c2h214748@gmail.com> * Split conv2d, class token, positional embedding in 2d, Fix random number in ddp Fix convergence in cifar10, Imagenet1000 * Integrate 1d tensor parallel in Colossal-AI (#39) * fixed 1D and 2D convergence (#38) * optimized 2D operations * fixed 1D ViT convergence problem * Feature/ddp (#49) * remove redundancy func in setup (#19) (#20) * use env to control the language of doc (#24) (#25) * Support TP-compatible Torch AMP and Update trainer API (#27) * Add gradient accumulation, fix lr scheduler * fix FP16 optimizer and adapted torch amp with tensor parallel (#18) * fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes * fixed trainer * Revert "fixed trainer" This reverts commit2e0b0b7699
. * improved consistency between trainer, engine and schedule (#23) Co-authored-by: 1SAA <c2h214748@gmail.com> Co-authored-by: 1SAA <c2h214748@gmail.com> Co-authored-by: ver217 <lhx0217@gmail.com> * add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29) * add explanation for ViT example (#35) (#36) * support torch ddp * fix loss accumulation * add log for ddp * change seed * modify timing hook Co-authored-by: Frank Lee <somerlee.9@gmail.com> Co-authored-by: 1SAA <c2h214748@gmail.com> Co-authored-by: binmakeswell <binmakeswell@gmail.com> * Feature/pipeline (#40) * remove redundancy func in setup (#19) (#20) * use env to control the language of doc (#24) (#25) * Support TP-compatible Torch AMP and Update trainer API (#27) * Add gradient accumulation, fix lr scheduler * fix FP16 optimizer and adapted torch amp with tensor parallel (#18) * fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes * fixed trainer * Revert "fixed trainer" This reverts commit2e0b0b7699
. * improved consistency between trainer, engine and schedule (#23) Co-authored-by: 1SAA <c2h214748@gmail.com> Co-authored-by: 1SAA <c2h214748@gmail.com> Co-authored-by: ver217 <lhx0217@gmail.com> * add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29) * add explanation for ViT example (#35) (#36) * optimize communication of pipeline parallel * fix grad clip for pipeline Co-authored-by: Frank Lee <somerlee.9@gmail.com> Co-authored-by: 1SAA <c2h214748@gmail.com> Co-authored-by: binmakeswell <binmakeswell@gmail.com> * optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51) * Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset * update api for better usability (#58) update api for better usability Co-authored-by: 1SAA <c2h214748@gmail.com> Co-authored-by: ver217 <lhx0217@gmail.com> Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com> Co-authored-by: binmakeswell <binmakeswell@gmail.com> Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com> Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
This commit is contained in:
232
tests/test_trainer/test_pipeline/debug_schedule.py
Normal file
232
tests/test_trainer/test_pipeline/debug_schedule.py
Normal file
@@ -0,0 +1,232 @@
|
||||
# referenced from Megatron and used to testify communication
|
||||
import os.path as osp
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from colossalai.builder import ModelInitializer, build_dataset, build_optimizer, build_loss
|
||||
from colossalai.communication import p2p as p2p_communication
|
||||
from colossalai.communication.utils import send_tensor_meta, recv_tensor_meta
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.initialize import initialize
|
||||
from colossalai.utils import print_rank_0, get_current_device
|
||||
|
||||
NUM_BATCH = 128
|
||||
NUM_MICRO = 6
|
||||
|
||||
|
||||
def get_num_microbatches():
|
||||
return NUM_MICRO
|
||||
|
||||
|
||||
def to_cuda(data):
|
||||
if isinstance(data, (tuple, list)):
|
||||
data = data[0].to(get_current_device())
|
||||
else:
|
||||
data = data.to(get_current_device())
|
||||
return data
|
||||
|
||||
|
||||
def step_func(loss):
|
||||
def _step_func(input_tensor, model):
|
||||
output = model(input_tensor)
|
||||
if isinstance(output, (tuple, list)):
|
||||
if len(output) > 1:
|
||||
raise NotImplementedError("Multiple output!!!")
|
||||
else:
|
||||
output = output[0]
|
||||
return output, loss
|
||||
|
||||
return _step_func
|
||||
|
||||
|
||||
def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
|
||||
"""Forward step for passed-in model.
|
||||
If first stage, input tensor is obtained from data_iterator, otherwise
|
||||
passed-in input_tensor is used.
|
||||
Returns output tensor."""
|
||||
|
||||
if input_tensor is None:
|
||||
data, label = data_iterator.next()
|
||||
input_tensor = to_cuda(data)
|
||||
|
||||
output_tensor, loss_func = forward_step_func(input_tensor, model)
|
||||
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||
data, label = data_iterator.next()
|
||||
label = to_cuda(label)
|
||||
output_tensor = loss_func(output_tensor, label) / get_num_microbatches()
|
||||
losses_reduced.append(output_tensor)
|
||||
|
||||
return output_tensor
|
||||
|
||||
|
||||
def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
|
||||
"""Backward step through passed-in output tensor.
|
||||
If last stage, output_tensor_grad is None, otherwise gradient of loss
|
||||
with respect to stage's output tensor.
|
||||
Returns gradient of loss with respect to input tensor (None if first
|
||||
stage)."""
|
||||
|
||||
# Retain the grad on the input_tensor.
|
||||
if input_tensor is not None:
|
||||
input_tensor.retain_grad()
|
||||
|
||||
# Backward pass.
|
||||
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
|
||||
|
||||
# Collect the grad of the input_tensor.
|
||||
input_tensor_grad = None
|
||||
if input_tensor is not None:
|
||||
input_tensor_grad = input_tensor.grad
|
||||
|
||||
return input_tensor_grad
|
||||
|
||||
|
||||
def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
|
||||
model, optimizer, forward_only):
|
||||
"""Run non-interleaved 1F1B schedule, with communication between pipeline
|
||||
stages.
|
||||
Returns dictionary with losses if the last stage, empty dict otherwise."""
|
||||
|
||||
# Compute number of warmup microbatches.
|
||||
num_microbatches = get_num_microbatches()
|
||||
num_warmup_microbatches = \
|
||||
(gpc.get_world_size(ParallelMode.PIPELINE) -
|
||||
gpc.get_local_rank(ParallelMode.PIPELINE) - 1)
|
||||
num_warmup_microbatches = min(
|
||||
num_warmup_microbatches,
|
||||
num_microbatches)
|
||||
num_microbatches_remaining = \
|
||||
num_microbatches - num_warmup_microbatches
|
||||
|
||||
# Input, output tensors only need to be saved when doing backward passes
|
||||
input_tensors = None
|
||||
output_tensors = None
|
||||
if not forward_only:
|
||||
input_tensors = []
|
||||
output_tensors = []
|
||||
losses_reduced = []
|
||||
|
||||
# Used for tensor meta information communication
|
||||
ft_shape = None
|
||||
bt_shape = None
|
||||
fs_checker = True
|
||||
|
||||
# Run warmup forward passes.
|
||||
for i in range(num_warmup_microbatches):
|
||||
if not gpc.is_first_rank(ParallelMode.PIPELINE):
|
||||
ft_shape = recv_tensor_meta(ft_shape)
|
||||
input_tensor = p2p_communication.recv_forward(ft_shape)
|
||||
output_tensor = forward_step(forward_step_func, data_iterator, model,
|
||||
input_tensor, losses_reduced)
|
||||
if not gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||
bt_shape = output_tensor.shape
|
||||
fs_checker = send_tensor_meta(output_tensor, fs_checker)
|
||||
p2p_communication.send_forward(output_tensor)
|
||||
|
||||
if not forward_only:
|
||||
input_tensors.append(input_tensor)
|
||||
output_tensors.append(output_tensor)
|
||||
|
||||
# Before running 1F1B, need to receive first forward tensor.
|
||||
# If all microbatches are run in warmup / cooldown phase, then no need to
|
||||
# receive this tensor here.
|
||||
if num_microbatches_remaining > 0:
|
||||
if not gpc.is_first_rank(ParallelMode.PIPELINE):
|
||||
ft_shape = recv_tensor_meta(ft_shape)
|
||||
input_tensor = p2p_communication.recv_forward(ft_shape)
|
||||
|
||||
# Run 1F1B in steady state.
|
||||
for i in range(num_microbatches_remaining):
|
||||
last_iteration = (i == (num_microbatches_remaining - 1))
|
||||
|
||||
output_tensor = forward_step(forward_step_func, data_iterator, model,
|
||||
input_tensor, losses_reduced)
|
||||
if forward_only:
|
||||
p2p_communication.send_forward(output_tensor)
|
||||
|
||||
if not last_iteration:
|
||||
input_tensor = p2p_communication.recv_forward(ft_shape)
|
||||
|
||||
else:
|
||||
output_tensor_grad = \
|
||||
p2p_communication.send_forward_recv_backward(output_tensor, bt_shape)
|
||||
|
||||
# Add input_tensor and output_tensor to end of list.
|
||||
input_tensors.append(input_tensor)
|
||||
output_tensors.append(output_tensor)
|
||||
|
||||
# Pop input_tensor and output_tensor from the start of the list for
|
||||
# the backward pass.
|
||||
input_tensor = input_tensors.pop(0)
|
||||
output_tensor = output_tensors.pop(0)
|
||||
|
||||
input_tensor_grad = \
|
||||
backward_step(optimizer, input_tensor, output_tensor,
|
||||
output_tensor_grad)
|
||||
|
||||
if last_iteration:
|
||||
input_tensor = None
|
||||
p2p_communication.send_backward(input_tensor_grad)
|
||||
else:
|
||||
input_tensor = \
|
||||
p2p_communication.send_backward_recv_forward(input_tensor_grad, ft_shape)
|
||||
|
||||
# Run cooldown backward passes.
|
||||
if not forward_only:
|
||||
for i in range(num_warmup_microbatches):
|
||||
input_tensor = input_tensors.pop(0)
|
||||
output_tensor = output_tensors.pop(0)
|
||||
|
||||
output_tensor_grad = p2p_communication.recv_backward(bt_shape)
|
||||
|
||||
input_tensor_grad = \
|
||||
backward_step(optimizer, input_tensor, output_tensor,
|
||||
output_tensor_grad)
|
||||
|
||||
p2p_communication.send_backward(input_tensor_grad)
|
||||
|
||||
return losses_reduced
|
||||
|
||||
|
||||
DIR_PATH = osp.dirname(osp.realpath(__file__))
|
||||
CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_vit.py')
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="This is only for debugging purpose, please ignore this test")
|
||||
@pytest.mark.dist
|
||||
def test_schedule():
|
||||
initialize(CONFIG_PATH)
|
||||
|
||||
# build model
|
||||
model = ModelInitializer(gpc.config.model, 1).model_initialize()
|
||||
print_rank_0('model is created')
|
||||
|
||||
# keep the same sampler for all process
|
||||
torch.manual_seed(1331)
|
||||
|
||||
dataset = build_dataset(gpc.config.data.dataset)
|
||||
dataloader = DataLoader(dataset=dataset, **gpc.config.data.dataloader)
|
||||
print_rank_0('train data is created')
|
||||
|
||||
# build optimizer and loss
|
||||
optim = build_optimizer(gpc.config.optimizer, model)
|
||||
loss = build_loss(gpc.config.loss)
|
||||
print_rank_0('optim and loss is created')
|
||||
|
||||
forward_backward_pipelining_without_interleaving(
|
||||
step_func(loss),
|
||||
iter(dataloader),
|
||||
model,
|
||||
optim,
|
||||
False
|
||||
)
|
||||
|
||||
gpc.destroy()
|
||||
print_rank_0('training finished')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_schedule()
|
149
tests/test_trainer/test_pipeline/test_p2p.py
Normal file
149
tests/test_trainer/test_pipeline/test_p2p.py
Normal file
@@ -0,0 +1,149 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.communication import (recv_backward, recv_forward,
|
||||
recv_tensor_meta, send_backward,
|
||||
send_backward_recv_forward, send_forward,
|
||||
send_forward_recv_backward,
|
||||
send_tensor_meta)
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.initialize import init_dist, parse_args
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
BATCH_SIZE = 32
|
||||
SEQ_LENGTH = 128
|
||||
HIDDEN_SIZE = 512
|
||||
|
||||
CONFIG = dict(
|
||||
parallel=dict(
|
||||
pipeline=dict(size=4),
|
||||
tensor=dict(size=1, mode=None)
|
||||
),
|
||||
seed=1024
|
||||
)
|
||||
|
||||
|
||||
def check_equal(A, B):
|
||||
return torch.allclose(A, B, rtol=1e-5, atol=1e-3)
|
||||
|
||||
|
||||
def check_forward(output_tensor, rank, logger):
|
||||
dist.barrier()
|
||||
if gpc.is_first_rank(ParallelMode.PIPELINE):
|
||||
tensor = output_tensor.clone()
|
||||
else:
|
||||
tensor = recv_forward(output_tensor.shape)
|
||||
logger.info('Rank {} received forward. Correct tensor: {}'.format(
|
||||
rank, check_equal(tensor, output_tensor)))
|
||||
if not gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||
send_forward(tensor)
|
||||
logger.info('Rank {} sent forward.'.format(rank))
|
||||
|
||||
|
||||
def check_backward(output_grad, rank, logger):
|
||||
dist.barrier()
|
||||
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||
grad = output_grad.clone()
|
||||
else:
|
||||
grad = recv_backward(output_grad.shape)
|
||||
logger.info('Rank {} received backward. Correct grad: {}'.format(
|
||||
rank, check_equal(grad, output_grad)))
|
||||
if not gpc.is_first_rank(ParallelMode.PIPELINE):
|
||||
send_backward(grad)
|
||||
logger.info('Rank {} sent backward.'.format(rank))
|
||||
|
||||
|
||||
def check_forward_backward(output_tensor, output_grad, rank, logger):
|
||||
dist.barrier()
|
||||
if not gpc.is_first_rank(ParallelMode.PIPELINE):
|
||||
tensor = send_backward_recv_forward(output_grad, output_tensor.shape)
|
||||
logger.info(
|
||||
'Rank {} sent backward received forward. Correct tensor: {}'.
|
||||
format(rank, check_equal(tensor, output_tensor)))
|
||||
if not gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||
grad = send_forward_recv_backward(output_tensor, output_grad.shape)
|
||||
logger.info(
|
||||
'Rank {} sent forward received backward. Correct grad: {}'.format(
|
||||
rank, check_equal(grad, output_grad)))
|
||||
|
||||
|
||||
def check_op(size, rank, prev_rank, next_rank, up_group, down_group, logger):
|
||||
dtype = torch.float32
|
||||
device = get_current_device()
|
||||
tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
|
||||
# recv_tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
|
||||
grad_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
|
||||
tensor = torch.randn(tensor_shape, dtype=dtype, device=device)
|
||||
dist.all_reduce(tensor)
|
||||
grad = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
dist.all_reduce(grad)
|
||||
if rank % 2 == 0:
|
||||
need_meta = True
|
||||
need_meta = send_tensor_meta(tensor, need_meta)
|
||||
logger.info('Rank {} shape sent (need meta: {}).'.format(
|
||||
rank, need_meta))
|
||||
req = dist.broadcast(tensor, src=rank, group=down_group, async_op=True)
|
||||
req.wait()
|
||||
out = tensor.clone()
|
||||
logger.info('Rank {} test op: tensor sent.'.format(rank))
|
||||
else:
|
||||
recv_tensor_shape = recv_tensor_meta(None)
|
||||
logger.info('Rank {} shape received. Correct shape: {}'.format(
|
||||
rank, tensor_shape == recv_tensor_shape))
|
||||
out = torch.empty(recv_tensor_shape, dtype=dtype, device=device)
|
||||
req = dist.broadcast(out, src=prev_rank, group=up_group, async_op=True)
|
||||
req.wait()
|
||||
logger.info('Rank {} test op: received tensor ({})'.format(
|
||||
rank, out.shape))
|
||||
|
||||
logger.info('Rank {} test op. Correct tensor: {}'.format(
|
||||
rank, check_equal(tensor, out)))
|
||||
|
||||
|
||||
def test_comm(size, rank, prev_rank, next_rank, up_group, down_group, logger):
|
||||
dtype = torch.float32
|
||||
device = get_current_device()
|
||||
tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
|
||||
grad_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
|
||||
tensor = torch.randn(tensor_shape, dtype=dtype, device=device)
|
||||
dist.all_reduce(tensor)
|
||||
grad = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
dist.all_reduce(grad)
|
||||
check_op(size, rank, prev_rank, next_rank, up_group, down_group, logger)
|
||||
check_forward(tensor, rank, logger)
|
||||
check_backward(grad, rank, logger)
|
||||
check_forward_backward(tensor, grad, rank, logger)
|
||||
|
||||
|
||||
@pytest.mark.skip("This test should be invoked using the test.sh provided")
|
||||
@pytest.mark.dist
|
||||
def test_main():
|
||||
args = parse_args()
|
||||
world_size = args.world_size
|
||||
|
||||
init_dist(CONFIG)
|
||||
logger = get_dist_logger()
|
||||
rank = gpc.get_global_rank()
|
||||
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
|
||||
up_ranks = gpc.get_ranks_in_group(ParallelMode.PIPELINE_PREV)
|
||||
up_group = gpc.get_group(ParallelMode.PIPELINE_PREV)
|
||||
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
|
||||
down_ranks = gpc.get_ranks_in_group(ParallelMode.PIPELINE_NEXT)
|
||||
down_group = gpc.get_group(ParallelMode.PIPELINE_NEXT)
|
||||
logger.info(
|
||||
'Rank {0}: prev rank {1} (up: {2}), next rank {3} (down: {4})'.format(
|
||||
rank, prev_rank, up_ranks, next_rank, down_ranks))
|
||||
logger.info('Distributed environment is initialzied.')
|
||||
|
||||
test_comm(world_size, rank, prev_rank, next_rank, up_group, down_group,
|
||||
logger)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_main()
|
37
tests/test_trainer/test_pipeline/test_partition.py
Normal file
37
tests/test_trainer/test_pipeline/test_partition.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import os.path as osp
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from colossalai.builder import build_dataset, ModelInitializer
|
||||
from colossalai.core import global_context
|
||||
from colossalai.initialize import init_dist
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
DIR_PATH = osp.dirname(osp.realpath(__file__))
|
||||
CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py')
|
||||
|
||||
|
||||
@pytest.mark.skip("This test should be invoked using the test.sh provided")
|
||||
@pytest.mark.dist
|
||||
def test_partition():
|
||||
init_dist(CONFIG_PATH)
|
||||
logger = get_dist_logger()
|
||||
logger.info('finished initialization')
|
||||
|
||||
# build model
|
||||
model = ModelInitializer(global_context.config.model, 1, verbose=True).model_initialize()
|
||||
logger.info('model is created')
|
||||
|
||||
dataset = build_dataset(global_context.config.train_data.dataset)
|
||||
dataloader = DataLoader(dataset=dataset, **global_context.config.train_data.dataloader)
|
||||
logger.info('train data is created')
|
||||
|
||||
global_context.destroy()
|
||||
torch.cuda.synchronize()
|
||||
logger.info('training finished')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_partition()
|
51
tests/test_trainer/test_pipeline/test_schedule.py
Normal file
51
tests/test_trainer/test_pipeline/test_schedule.py
Normal file
@@ -0,0 +1,51 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import os.path as osp
|
||||
|
||||
import pytest
|
||||
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.initialize import initialize
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
NUM_BATCH = 128
|
||||
|
||||
BATCH_SIZE = 32
|
||||
SEQ_LENGTH = 128
|
||||
HIDDEN_SIZE = 512
|
||||
|
||||
DIR_PATH = osp.dirname(osp.realpath(__file__))
|
||||
CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py')
|
||||
|
||||
|
||||
@pytest.mark.skip("This test should be invoked using the test.sh provided")
|
||||
@pytest.mark.dist
|
||||
def test_schedule():
|
||||
engine, train_dataloader, test_dataloader = initialize(CONFIG_PATH)
|
||||
logger = get_dist_logger()
|
||||
|
||||
model = engine.model
|
||||
optimizer = engine.optimizer
|
||||
criterion = engine.criterion
|
||||
schedule = engine._schedule
|
||||
|
||||
output, label, loss = schedule.forward_backward_step(
|
||||
data_iter=iter(train_dataloader),
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
forward_only=False
|
||||
)
|
||||
schedule.optimizer_step(model, optimizer)
|
||||
|
||||
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||
logger.info('losses: {}'.format(loss))
|
||||
|
||||
gpc.destroy()
|
||||
logger.info('training finished')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_schedule()
|
Reference in New Issue
Block a user