Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-09 04:50:17 +00:00
moved env variables to global variables; (#215)
added branch context; added vocab parallel layers; moved split_batch from load_batch to tensor parallel embedding layers; updated gpt model; updated unit test cases; fixed a few collective communicator bugs
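The first point, moving env variables to global variables, means the tensor-parallel settings now live in a global object rather than in process environment variables, so their state can be snapshotted and restored. A minimal sketch of that pattern, assuming only the save()/load() interface that appears in the diff below; everything else here is illustrative:

    from colossalai.global_variables import tensor_parallel_env as env

    state = env.save()   # snapshot the current tensor-parallel settings
    # ... temporarily reconfigure the environment, e.g. for another model branch ...
    env.load(**state)    # restore the snapshot afterwards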
@@ -1,8 +1,9 @@
 from .activation_checkpoint import checkpoint
 from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32,
                      free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage, is_tp_rank_0,
-                     is_using_ddp, is_using_pp, is_using_sequence, multi_tensor_applier, param_is_not_tensor_parallel_duplicate,
-                     print_rank_0, switch_virtual_pipeline_parallel_rank, sync_model_param)
+                     is_using_ddp, is_using_pp, is_using_sequence, model_branch_context, multi_tensor_applier,
+                     param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank,
+                     sync_model_param)
 from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
 from .data_sampler import DataParallelSampler, get_dataloader
 from .gradient_accumulation import accumulate_gradient
@@ -11,9 +12,9 @@ from .timer import MultiTimer, Timer

 __all__ = [
     'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param', 'is_dp_rank_0', 'is_tp_rank_0',
-    'is_no_pp_or_last_stage', 'is_using_ddp', 'is_using_pp', 'is_using_sequence', 'conditional_context',
-    'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes',
-    'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
-    'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler',
-    'get_dataloader', 'switch_virtual_pipeline_parallel_rank'
+    'is_no_pp_or_last_stage', 'is_using_ddp', 'is_using_pp', 'is_using_sequence', 'model_branch_context',
+    'conditional_context', 'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32',
+    'copy_tensor_parallel_attributes', 'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize',
+    'empty_cache', 'set_to_cuda', 'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier',
+    'accumulate_gradient', 'DataParallelSampler', 'get_dataloader', 'switch_virtual_pipeline_parallel_rank'
 ]

@@ -6,8 +6,6 @@ import socket
 import torch
 from torch._six import inf

-import colossalai.context.parallel_mode
-
 try:
     import colossal_C
 except:
@@ -20,6 +18,7 @@ from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARA
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.global_variables import moe_env
+from colossalai.global_variables import tensor_parallel_env as env

 from .multi_tensor_apply import multi_tensor_applier

@@ -62,8 +61,7 @@ def sync_model_param(model, parallel_mode):
     if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
         for param in model.parameters():
             ranks = gpc.get_ranks_in_group(parallel_mode)
-            dist.broadcast(
-                param, src=ranks[0], group=gpc.get_group(parallel_mode))
+            dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode))


 def is_dp_rank_0():
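For context, sync_model_param broadcasts every parameter from the first rank of the given group, so all ranks in that group start from identical weights. A minimal usage sketch; model stands for an already constructed torch.nn.Module and the choice of ParallelMode.DATA is illustrative:

    from colossalai.context.parallel_mode import ParallelMode
    from colossalai.utils import sync_model_param

    # make every data-parallel rank hold the same initial weights
    sync_model_param(model, ParallelMode.DATA)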
@@ -99,6 +97,15 @@ def conditional_context(context_manager, enable=True):
         yield


+class model_branch_context(object):
+
+    def __enter__(self):
+        self.env_status = env.save()
+
+    def __exit__(self, *exc_info):
+        env.load(**self.env_status)
+
+
 def is_model_parallel_parameter(p):
     return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)

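A usage sketch for the new model_branch_context manager: entering the block saves the global tensor-parallel environment and leaving it restores that snapshot, so code inside may reconfigure the environment freely. build_branch below is a hypothetical helper, not part of the diff:

    from colossalai.utils import model_branch_context

    with model_branch_context():
        # build a sub-module that may alter the global tensor-parallel environment
        branch = build_branch()   # hypothetical helper
    # on exit, the environment is back to its state before the block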
@@ -124,9 +131,10 @@ def _calc_lp(grads, norm_type):
     norm = 0.0
     for grad in grads:
         grad_norm = torch.norm(grad, norm_type)
-        norm += grad_norm ** norm_type
+        norm += grad_norm**norm_type
     return norm


 # ======== Gradient Clipping =========

+
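An illustrative check, not part of the diff, of what _calc_lp accumulates: it returns the sum of ||g||_p ** p over the gradient list, so taking the 1/p root of the total reproduces the p-norm of all gradients concatenated:

    import torch

    def calc_lp(grads, norm_type):
        # mirrors _calc_lp above: accumulate ||g||_p ** p per gradient tensor
        norm = 0.0
        for grad in grads:
            norm += torch.norm(grad, norm_type)**norm_type
        return norm

    grads = [torch.randn(4), torch.randn(3, 2)]
    total = calc_lp(grads, 2.0)**(1 / 2.0)
    assert torch.allclose(total, torch.cat([g.flatten() for g in grads]).norm(2))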
@@ -183,7 +191,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
         moe_parallel_grads = []  # used to collect moe tensor parallel gradients
         for p in params:
             if is_model_parallel_parameter(p):
-                reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS)) ** (1 / norm_type)
+                reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS))**(1 / norm_type)
                 tensor_parallel_grads.append(p.grad.data / reductor)
             elif is_moe_parallel_parameter(p):
                 moe_parallel_grads.append(p.grad.data)
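Why the gradient is divided by reductor (illustrative reasoning, assuming NUM_PARTITIONS counts the shards of a parameter within the tensor-parallel group): each shard is then replicated on world_size / NUM_PARTITIONS ranks, so summing ||g||^p across the whole group would over-count by that factor; scaling each gradient by reductor = (world_size / NUM_PARTITIONS)**(1/p) before taking the norm cancels the over-count exactly:

    # the duplication factor r cancels: r * (|g| / r**(1/p))**p == |g|**p
    world_size, num_partitions, p = 4, 2, 2.0
    r = world_size / num_partitions
    g = 3.0   # stand-in for one shard's gradient norm
    assert abs(r * (g / r**(1 / p))**p - g**p) < 1e-9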
@@ -191,32 +199,24 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
                 no_tensor_parallel_grads.append(p.grad.data)

         if norm_type == 2.0:
-            tensor_parallel_norm = _calc_l2_norm(
-                tensor_parallel_grads) ** norm_type
-            no_tensor_parallel_norm = _calc_l2_norm(
-                no_tensor_parallel_grads) ** norm_type
-            moe_parallel_norm = _calc_l2_norm(
-                moe_parallel_grads) ** norm_type
+            tensor_parallel_norm = _calc_l2_norm(tensor_parallel_grads)**norm_type
+            no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads)**norm_type
+            moe_parallel_norm = _calc_l2_norm(moe_parallel_grads)**norm_type
         else:
             tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
-            no_tensor_parallel_norm = _calc_lp(
-                no_tensor_parallel_grads, norm_type)
+            no_tensor_parallel_norm = _calc_lp(no_tensor_parallel_grads, norm_type)
             moe_parallel_norm = _calc_lp(moe_parallel_grads, norm_type)
         # Sum across all model-parallel GPUs.
         if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0:
-            dist.all_reduce(tensor_parallel_norm,
-                            op=dist.ReduceOp.SUM,
-                            group=gpc.get_group(ParallelMode.TENSOR))
+            dist.all_reduce(tensor_parallel_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR))
         # Sum across all moe-tensor-parallel GPUs
         if len(moe_parallel_grads) > 0:
             dist.all_reduce(moe_parallel_norm, group=gpc.get_group(ParallelMode.MOE_MODEL))
             no_tensor_parallel_norm += moe_parallel_norm
         total_norm = tensor_parallel_norm + no_tensor_parallel_norm
         if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
-            dist.all_reduce(total_norm,
-                            op=dist.ReduceOp.SUM,
-                            group=gpc.get_group(ParallelMode.PIPELINE))
-        total_norm = total_norm ** (1.0 / norm_type)
+            dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE))
+        total_norm = total_norm**(1.0 / norm_type)
         if type(total_norm) == 'torch.cuda.FloatTensor':
             total_norm = total_norm.item()

@@ -225,10 +225,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
     if clip_coeff < 1.0:
         grads = [p.grad.detach() for p in params]
         dummy_overflow_buf = torch.cuda.IntTensor([0])
-        multi_tensor_applier(colossal_C.multi_tensor_scale,
-                             dummy_overflow_buf,
-                             [grads, grads],
-                             clip_coeff)
+        multi_tensor_applier(colossal_C.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff)

     return total_norm

@@ -254,15 +251,14 @@ def count_zeros_fp32(parameters):

     # Sum across all model-parallel GPUs.
     ops = []
-    ops.append(dist.all_reduce(total_num_zeros,
-                               op=dist.ReduceOp.SUM,
-                               group=gpc.get_group(ParallelMode.TENSOR),
-                               async_op=True))
+    ops.append(
+        dist.all_reduce(total_num_zeros, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR), async_op=True))
     if gpc.is_initialized(ParallelMode.PIPELINE):
-        ops.append(dist.all_reduce(total_num_zeros,
-                                   op=dist.ReduceOp.SUM,
-                                   group=gpc.get_group(ParallelMode.PIPELINE),
-                                   async_op=True))
+        ops.append(
+            dist.all_reduce(total_num_zeros,
+                            op=dist.ReduceOp.SUM,
+                            group=gpc.get_group(ParallelMode.PIPELINE),
+                            async_op=True))

     for req in ops:
         req.wait()
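The hunk above keeps the asynchronous collective pattern: each dist.all_reduce launched with async_op=True returns a work handle, and the result is only valid after wait(). A minimal standalone sketch, assuming the default process group has already been initialised elsewhere:

    import torch
    import torch.distributed as dist

    t = torch.ones(1)
    work = dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=True)
    work.wait()   # t now holds the element-wise sum over all ranks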
@@ -279,9 +275,8 @@ def copy_tensor_parallel_attributes(src_tensor, dst_tensor):


 def param_is_not_tensor_parallel_duplicate(param):
-    return (hasattr(param, IS_TENSOR_PARALLEL) and
-            getattr(param, IS_TENSOR_PARALLEL)) or (
-            gpc.get_local_rank(ParallelMode.TENSOR) == 0)
+    return (hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL)) or (gpc.get_local_rank(
+        ParallelMode.TENSOR) == 0)


 @contextmanager