Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-08-16 07:03:40 +00:00
* init
* rename and remove useless func
* basic chunk
* add evoformer
* align evoformer
* add meta
* basic chunk
* basic memory
* finish basic inference memory estimation
* finish memory estimation
* fix bug
* finish memory estimation
* add part of index tracer
* finish basic index tracer
* add doc string
* add doc str
* polish code
* polish code
* update active log
* polish code
* add possible region search
* finish region search loop
* finish chunk define
* support new op
* rename index tracer
* finish codegen on msa
* redesign index tracer, add source and change compute
* pass outproduct mean
* code format
* code format
* work with outerproductmean and msa
* code style
* code style
* code style
* code style
* change threshold
* support check_index_duplicate
* support index duplicate and update loop
* support output
* update memory estimate
* optimise search
* fix layernorm
* move flow tracer
* refactor flow tracer
* format code
* refactor flow search
* code style
* adapt codegen to prepose node
* code style
* remove abandoned function
* remove flow tracer
* code style
* code style
* reorder nodes
* finish node reorder
* update run
* code style
* add chunk select class
* add chunk select
* code style
* add chunksize in emit, fix bug in reassign shape
* code style
* turn off print mem
* add evoformer openfold init
* init openfold
* add benchmark
* add print
* code style
* code style
* init openfold
* update openfold
* align openfold
* use max_mem to control strategy
* update source add
* add reorder in mem estimator
* improve reorder efficiency
* support ones_like, add prompt if fit mode search fails
* fix a bug in ones_like, don't gen chunk if dim size is 1
* fix bug again
* update min memory strategy, reduce mem usage by 30%
* last version of benchmark
* refactor structure
* restruct dir
* update test
* rename
* take apart chunk code gen
* close mem and code print
* code format
* rename ambiguous variable
* separate flow tracer
* separate input node dim search
* separate prepose_nodes
* separate non chunk input
* separate reorder
* rename
* add reorder graph
* separate trace flow
* code style
* code style
* fix typo
* set benchmark
* rename test
* update codegen test
* Fix state_dict key missing issue of the ZeroDDP (#2363)
* Fix state_dict output for ZeroDDP duplicated parameters
* Rewrite state_dict based on get_static_torch_model
* Modify get_static_torch_model to be compatible with the lower version (ZeroDDP)
* update codegen test
* update codegen test
* add chunk search test
* code style
* add available
* [hotfix] fix gpt gemini example (#2404)
* [hotfix] fix gpt gemini example
* [example] add new assertions
* remove autochunk_available
* [workflow] added nightly release to pypi (#2403)
* add comments
* code style
* add doc for search chunk
* [doc] updated readme regarding pypi installation (#2406)
* add doc for search
* [doc] updated kernel-related optimisers' docstring (#2385)
* [doc] updated kernel-related optimisers' docstring
* polish doc
* rename trace_index to trace_indice
* rename function from index to indice
* rename
* rename in doc
* [polish] polish code for get_static_torch_model (#2405)
* [gemini] polish code
* [testing] remove code
* [gemini] make more robust
* rename
* rename
* remove useless function
* [workflow] added coverage test (#2399)
* [workflow] added coverage test
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* add doc for trace indice
* [docker] updated Dockerfile and release workflow (#2410)
* add doc
* update doc
* add available
* change imports
* add test in import
* [workflow] refactored the example check workflow (#2411)
* [workflow] refactored the example check workflow
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* Update parallel_context.py (#2408)
* [hotfix] add DISTPAN argument for benchmark (#2412)
* change the benchmark config file
* change config
* revert config file
* rename distpan to distplan
* [workflow] added precommit check for code consistency (#2401)
* [workflow] added precommit check for code consistency
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* adapt new fx
* [workflow] added translation for non-english comments (#2414)
* [setup] refactored setup.py for dependency graph (#2413)
* change import
* update doc
* [workflow] auto comment if precommit check fails (#2417)
* [hotfix] add norm clearing for the overflow step (#2416)
* [examples] adding tflops to PaLM (#2365)
* [workflow]auto comment with test coverage report (#2419)
* [workflow]auto comment with test coverage report
* polish code
* polish yaml
* [doc] added documentation for CI/CD (#2420)
* [doc] added documentation for CI/CD
* polish markdown
* polish markdown
* polish markdown
* [example] removed duplicated stable diffusion example (#2424)
* [zero] add inference mode and its unit test (#2418)
* [workflow] report test coverage even if below threshold (#2431)
* [example] improved the clarity of the example readme (#2427)
* [example] improved the clarity of the example readme
* polish workflow
* polish workflow
* polish workflow
* polish workflow
* polish workflow
* polish workflow
* [ddp] add is_ddp_ignored (#2434)
[ddp] rename to is_ddp_ignored
* [workflow] make test coverage report collapsable (#2436)
* [autoparallel] add shard option (#2423)
* [fx] allow native ckpt trace and codegen. (#2438)
* [cli] provided more details if colossalai run fails (#2442)
* [autoparallel] integrate device mesh initialization into autoparallelize (#2393)
* [autoparallel] integrate device mesh initialization into autoparallelize
* add megatron solution
* update gpt autoparallel examples with latest api
* adapt beta value to fit the current computation cost
* [zero] fix state_dict and load_state_dict for ddp ignored parameters (#2443)
* [ddp] add is_ddp_ignored
[ddp] rename to is_ddp_ignored
* [zero] fix state_dict and load_state_dict
* fix bugs
* [zero] update unit test for ZeroDDP
* [example] updated the hybrid parallel tutorial (#2444)
* [example] updated the hybrid parallel tutorial
* polish code
* [zero] add warning for ignored parameters (#2446)
* [example] updated large-batch optimizer tutorial (#2448)
* [example] updated large-batch optimizer tutorial
* polish code
* polish code
* [example] fixed seed error in train_dreambooth_colossalai.py (#2445)
* [workflow] fixed the on-merge condition check (#2452)
* [workflow] automated the compatibility test (#2453)
* [workflow] automated the compatibility test
* polish code
* [autoparallel] update binary elementwise handler (#2451)
* [autoparallel] update binary elementwise handler
* polish
* [workflow] automated bdist wheel build (#2459)
* [workflow] automated bdist wheel build
* polish workflow
* polish readme
* polish readme
* Fix False warning in initialize.py (#2456)
* Update initialize.py
* pre-commit run check
* [examples] update autoparallel tutorial demo (#2449)
* [examples] update autoparallel tutorial demo
* add test_ci.sh
* polish
* add conda yaml
* [cli] fixed hostname mismatch error (#2465)
* [example] integrate autoparallel demo with CI (#2466)
* [example] integrate autoparallel demo with CI
* polish code
* polish code
* polish code
* polish code
* [zero] low level optim supports ProcessGroup (#2464)
* [example] update vit ci script (#2469)
* [example] update vit ci script
* [example] update requirements
* [example] update requirements
* [example] integrate seq-parallel tutorial with CI (#2463)
* [zero] polish low level optimizer (#2473)
* polish pp middleware (#2476)
Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
* [example] update gpt gemini example ci test (#2477)
* [zero] add unit test for low-level zero init (#2474)
* [workflow] fixed the skip condition of example weekly check workflow (#2481)
* [example] stable diffusion add roadmap
* add dummy test_ci.sh
* [example] stable diffusion add roadmap (#2482)
* [CI] add test_ci.sh for palm, opt and gpt (#2475)
* polish code
* [example] titans for gpt
* polish readme
* remove license
* polish code
* update readme
* [example] titans for gpt (#2484)
* [autoparallel] support origin activation ckpt on autoparallel system (#2468)
* [autochunk] support evoformer tracer (#2485)
Support the full evoformer tracer, which is a main module of AlphaFold; previously only a simplified version was supported.
1. support some evoformer's op in fx
2. support evoformer test
3. add repos for test code
* [example] fix requirements (#2488)
* [zero] add unit testings for hybrid parallelism (#2486)
* [hotfix] gpt example titans bug #2493
* polish code and fix dataloader bugs
* [hotfix] gpt example titans bug #2493 (#2494)
* [fx] allow control of ckpt_codegen init (#2498)
* [fx] allow control of ckpt_codegen init
Currently in ColoGraphModule, ActivationCheckpointCodeGen is set automatically in __init__, so no other codegen can be set.
This adds an argument to control whether ActivationCheckpointCodeGen is set in __init__.
* code style
* [example] dreambooth example
* add test_ci.sh to dreambooth
* [autochunk] support autochunk on evoformer (#2497)
* Revert "Update parallel_context.py (#2408)"
This reverts commit 7d5640b9db.
* add avg partition (#2483)
Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
* [auto-chunk] support extramsa (#3) (#2504)
* [utils] lazy init. (#2148)
* [utils] lazy init.
* [utils] remove description.
* [utils] complete.
* [utils] finalize.
* [utils] fix names.
* [autochunk] support parsing blocks (#2506)
* [zero] add strict ddp mode (#2508)
* [zero] add strict ddp mode
* [polish] add comments for strict ddp mode
* [zero] fix test error
* [doc] update opt and tutorial links (#2509)
* [workflow] fixed changed file detection (#2515)
Co-authored-by: oahzxl <xuanlei.zhao@gmail.com>
Co-authored-by: eric8607242 <e0928021388@gmail.com>
Co-authored-by: HELSON <c2h214748@gmail.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: Haofan Wang <haofanwang.ai@gmail.com>
Co-authored-by: Jiarui Fang <fangjiarui123@gmail.com>
Co-authored-by: ZijianYY <119492445+ZijianYY@users.noreply.github.com>
Co-authored-by: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Co-authored-by: Super Daniel <78588128+super-dainiu@users.noreply.github.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: Ziyue Jiang <ziyue.jiang97@gmail.com>
Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
Co-authored-by: oahzxl <43881818+oahzxl@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Fazzie-Maqianli <55798671+Fazziekey@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
265 lines
9.3 KiB
Python
import math
from typing import Optional

import torch
import torch.distributed as dist
from torch._six import inf
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.tensor import ProcessGroup
from colossalai.utils import is_model_parallel_parameter


def flatten(input_):
    return _flatten_dense_tensors(input_)


def unflatten(flat, tensors):
    return _unflatten_dense_tensors(flat, tensors)


def count_numel(tensor_list):
    res = 0
    for tensor in tensor_list:
        res += tensor.numel()
    return res


def calculate_padding(numel, unit_size):
    remainder = numel % unit_size
    return unit_size - remainder if remainder else remainder


def shuffle_by_round_robin(tensor_list, num_partitions):
    partitions = dict()

    for tensor_idx, tensor in enumerate(tensor_list):
        partition_to_go = tensor_idx % num_partitions
        if partition_to_go not in partitions:
            partitions[partition_to_go] = []
        partitions[partition_to_go].append(dict(tensor=tensor, index=tensor_idx))

    partitions_count = len(partitions)
    new_tensor_list = []
    tensor_index_mapping = dict()

    for partition_id in range(partitions_count):
        partition_tensors = partitions[partition_id]
        for item in partition_tensors:
            tensor_index_mapping[item['index']] = len(new_tensor_list)
            new_tensor_list.append(item['tensor'])

    return new_tensor_list, tensor_index_mapping


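# Illustrative usage sketch (editor-added, not part of the original module): shows how the
# round-robin shuffle reorders tensors across partitions and how `tensor_index_mapping`
# maps an original index to its new position. Uses only the helpers defined above.
def _demo_shuffle_by_round_robin():
    tensors = [torch.empty(i + 1) for i in range(5)]
    new_list, index_mapping = shuffle_by_round_robin(tensors, num_partitions=2)
    # partition 0 receives indices 0, 2, 4 and partition 1 receives indices 1, 3,
    # so the new order is [0, 2, 4, 1, 3]
    assert [t.numel() for t in new_list] == [1, 3, 5, 2, 4]
    assert index_mapping == {0: 0, 2: 1, 4: 2, 1: 3, 3: 4}

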
# create a flat tensor aligned at the alignment boundary
def flatten_dense_tensors_with_padding(tensor_list, unit_size):
    num_elements = count_numel(tensor_list)
    padding = calculate_padding(num_elements, unit_size=unit_size)

    if padding > 0:
        pad_tensor = torch.zeros(padding, device=tensor_list[0].device, dtype=tensor_list[0].dtype)
        padded_tensor_list = tensor_list + [pad_tensor]
    else:
        padded_tensor_list = tensor_list

    return flatten(padded_tensor_list)


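# Illustrative usage sketch (editor-added, not part of the original module): the flat buffer
# is padded so its element count is a multiple of `unit_size`, which keeps per-rank
# partitions of the buffer evenly sized.
def _demo_flatten_dense_tensors_with_padding():
    tensors = [torch.ones(3), torch.ones(4)]    # 7 elements in total
    flat = flatten_dense_tensors_with_padding(tensors, unit_size=4)
    assert flat.numel() == 8                    # padded from 7 up to the next multiple of 4
    assert flat[:7].eq(1).all()
    assert flat[7:].eq(0).all()                 # padding is zero-filled

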
def is_nccl_aligned(tensor):
    return tensor.data_ptr() % 4 == 0


def get_grad_accumulate_object(tensor):
    """
    Return the AccumulateGrad of the input tensor
    """

    # grad_fn reference:
    # https://discuss.pytorch.org/t/in-the-grad-fn-i-find-a-next-functions-but-i-dont-understand-the-meaning-of-the-attribute/24463
    # expand_as reference: https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html#torch.Tensor.expand
    #
    # `next_functions` will return the backward graph where
    # the first element is the AccumulateGrad of the leaf nodes.
    # we want to get the AccumulateGrad of the input tensor instead of the leaf
    # node in the whole computation graph.
    # Therefore, we call expand_as to create a dummy graph
    # where tensor_tmp and tensor indeed point to the same object.
    # You can check this by print(tensor.data_ptr() == tensor_tmp.data_ptr())
    tensor_tmp = tensor.expand_as(tensor)
    grad_acc_obj = tensor_tmp.grad_fn.next_functions[0][0]
    return grad_acc_obj


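# Illustrative usage sketch (editor-added, not part of the original module): the expand_as
# trick builds a one-node graph whose `next_functions[0][0]` is the AccumulateGrad node of
# the leaf tensor, while the expanded view still shares storage with the original tensor.
def _demo_get_grad_accumulate_object():
    param = torch.randn(4, requires_grad=True)
    grad_acc = get_grad_accumulate_object(param)
    assert type(grad_acc).__name__ == 'AccumulateGrad'
    # the dummy view shares memory with the original tensor
    assert param.expand_as(param).data_ptr() == param.data_ptr()

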
def split_half_float_double(tensor_list):
    dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"]
    buckets = []
    for i, dtype in enumerate(dtypes):
        bucket = [t for t in tensor_list if t.type() == dtype]
        if bucket:
            buckets.append(bucket)
    return buckets


def reduce_tensor_dp_group(tensor: torch.Tensor,
                           dtype: Optional[torch.dtype] = None,
                           dst_local_rank: Optional[int] = None,
                           dst_global_rank: Optional[int] = None,
                           group: Optional[dist.ProcessGroup] = None):
    """
    Reduce the tensor in the data parallel process group.

    :param tensor: A tensor object to reduce/all-reduce
    :param dtype: The data type used in communication
    :param dst_local_rank: The destination rank within the group. If None,
        all-reduce is used instead of reduce. Default is None.
    :param dst_global_rank: The global rank corresponding to ``dst_local_rank``
    :param group: The process group used for communication

    :type tensor: torch.Tensor
    :type dtype: torch.dtype, optional
    :type dst_local_rank: int, optional
    :type dst_global_rank: int, optional
    :type group: ProcessGroup, optional
    """
    # use the original dtype
    if dtype is None:
        dtype = tensor.dtype

    # cast the data to specified dtype for reduce/all-reduce
    if tensor.dtype != dtype:
        tensor_to_reduce = tensor.to(dtype)
    else:
        tensor_to_reduce = tensor

    world_size = dist.get_world_size(group=group)
    tensor_to_reduce.div_(world_size)

    # if rank is None, all-reduce will be used
    # else, reduce is used
    use_all_reduce = dst_local_rank is None

    if use_all_reduce:
        dist.all_reduce(tensor_to_reduce, group=group)
    else:
        dist.reduce(tensor=tensor_to_reduce, dst=dst_global_rank, group=group)

    # recover the original dtype
    if tensor.dtype != dtype and tensor is not tensor_to_reduce:
        local_rank = dist.get_rank(group=group)
        if use_all_reduce or dst_local_rank == local_rank:
            tensor.copy_(tensor_to_reduce)

    return tensor


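# Illustrative usage sketch (editor-added, not part of the original module): assumes
# torch.distributed has already been initialized (e.g. via dist.init_process_group).
# Without a destination rank the tensor is averaged across the group with all-reduce;
# with a destination rank only that rank receives the averaged result.
def _demo_reduce_tensor_dp_group():
    grad = torch.ones(4) * dist.get_rank()
    # average `grad` across all ranks of the default group
    reduce_tensor_dp_group(grad)
    # alternatively, reduce onto rank 0 only (local and global rank coincide here):
    # reduce_tensor_dp_group(grad, dst_local_rank=0, dst_global_rank=0)

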
def has_inf_or_nan(tensor):
    try:
        # if tensor is half, the .float() incurs an additional deep copy, but it's necessary if
        # Pytorch's .sum() creates a one-element tensor of the same type as tensor
        # (which is true for some recent version of pytorch).
        tensor_sum = float(tensor.float().sum())
        # More efficient version that can be used if .sum() returns a Python scalar
        # tensor_sum = float(tensor.sum())
    except RuntimeError as instance:
        # We want to check if inst is actually an overflow exception.
        # RuntimeError could come from a different error.
        # If so, we still want the exception to propagate.
        if "value cannot be converted" not in instance.args[0]:
            raise
        return True
    else:
        if tensor_sum == float('inf') or tensor_sum == -float('inf') or tensor_sum != tensor_sum:
            return True
        return False


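# Illustrative usage sketch (editor-added, not part of the original module): both infinities
# and NaNs count as overflow, which is what the optimizer checks before deciding to skip a step.
def _demo_has_inf_or_nan():
    assert not has_inf_or_nan(torch.tensor([1.0, 2.0]))
    assert has_inf_or_nan(torch.tensor([1.0, float('inf')]))
    assert has_inf_or_nan(torch.tensor([1.0, float('nan')]))

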
def release_param_grad(tensor_list):
    for tensor in tensor_list:
        tensor.grad = None


def calculate_global_norm_from_list(norm_list):
    """Compute the total norm from a list of norms.
    """
    total_norm = 0.0
    for norm in norm_list:
        total_norm += norm**2.0
    return math.sqrt(total_norm)


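# Illustrative usage sketch (editor-added, not part of the original module): the global norm
# is the square root of the sum of squared partition norms, e.g. sqrt(3^2 + 4^2) = 5.
def _demo_calculate_global_norm_from_list():
    assert math.isclose(calculate_global_norm_from_list([3.0, 4.0]), 5.0)

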
def compute_norm(gradients, params, dp_group, mp_group, norm_type=2):
    """Computes the gradient norm of an iterable of parameters.

    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
    adds functionality to handle model parallel parameters. Note that
    the gradients are modified in place.

    Arguments:
        gradients (Iterable[Tensor]): the gradients whose norm is computed
        params (Iterable[Tensor]): the parameters the gradients belong to
        dp_group (ProcessGroup): the data parallel process group
        mp_group (ProcessGroup): the model parallel process group, or None
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.

    Returns:
        Total norm of the parameters (viewed as a single vector).
    """

    if mp_group is None:
        mp_rank = 0
    else:
        mp_rank = dist.get_rank(mp_group)

    norm_type = float(norm_type)
    if norm_type == inf:
        total_norm = max(g.data.abs().max() for g in gradients)
        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
        dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_group)

        # Take max across all GPUs.
        if mp_group is not None:
            dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.MAX)
        total_norm = total_norm_cuda[0].item()
    else:
        total_norm = 0.0
        # if dist.get_rank() == 0:
        #    logger.info(f"Total Norm beginning {total_norm}")

        for g, p in zip(gradients, params):
            # Pipeline parallelism may replicate parameters. Avoid multi-counting.
            if is_model_parallel_parameter(p) or mp_rank == 0:
                param_norm = g.data.double().norm(2)
                total_norm += param_norm.item()**2

        # Sum across all model parallel GPUs.
        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
        torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=dp_group)

        if mp_group is not None:
            dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM)

        total_norm = total_norm_cuda[0].item()**(1. / norm_type)

    if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
        total_norm = -1

    return total_norm


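# Illustrative usage sketch (editor-added, not part of the original module): assumes a CUDA
# device and an initialized torch.distributed process group. `dp_group=None` falls back to
# the default group and `mp_group=None` disables the model parallel reduction, so every
# gradient is counted exactly once.
def _demo_compute_norm():
    params = [torch.randn(4, device='cuda', requires_grad=True) for _ in range(2)]
    grads = [torch.ones_like(p) for p in params]
    # with a single rank this equals the 2-norm of all gradient elements, i.e. sqrt(8)
    return compute_norm(grads, params, dp_group=None, mp_group=None)

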
def sync_param(flat_tensor, tensor_list):
    """
    Synchronize the flattened tensor and unflattened tensor list. When
    a list of tensors is flattened with `torch._utils._flatten_dense_tensors`,
    a new tensor is created. Thus, the flat tensor and original tensor list do not
    share the same memory space. This function will update the tensor list so that
    they point to the same value.

    :param flat_tensor: A flat tensor obtained by calling `torch._utils._flatten_dense_tensors` on a tensor list
    :param tensor_list: A list of tensors corresponding to the flattened tensor
    :type flat_tensor: torch.Tensor
    :type tensor_list: List[torch.Tensor]
    """
    updated_params = unflatten(flat_tensor, tensor_list)

    # update the tensor data
    for p, q in zip(tensor_list, updated_params):
        p.data = q.data


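# Illustrative usage sketch (editor-added, not part of the original module): flattening copies
# the data into a new buffer, so sync_param is used afterwards to repoint the original tensors
# at the corresponding slices of the flat buffer.
def _demo_sync_param():
    params = [torch.zeros(2), torch.zeros(3)]
    flat = flatten(params)
    sync_param(flat, params)
    flat.fill_(7.0)
    # the original tensors now view the flat buffer, so they observe the update
    assert all(p.eq(7).all() for p in params)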