mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-04-28 19:55:29 +00:00)
* [branch rebase] rebase main to Feature/resize_embedding (#5554) * fix * [release] update version (#5411) * [hotfix] fix typo s/keywrods/keywords etc. (#5429) * [devops] fix compatibility (#5444) * [devops] fix compatibility * [hotfix] update compatibility test on pr * [devops] fix compatibility * [devops] record duration during comp test * [test] decrease test duration * fix falcon * [shardformer] fix gathering output when using tensor parallelism (#5431) * fix * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix * fix fix fix * fix gather output * fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * revert * [doc] release Open-Sora 1.0 with model weights (#5468) * [doc] release Open-Sora 1.0 with model weights * [doc] release Open-Sora 1.0 with model weights * [doc] release Open-Sora 1.0 with model weights * [doc] update open-sora demo (#5479) * [doc] update open-sora demo * [doc] update open-sora demo * [doc] update open-sora demo * [example] add grok-1 inference (#5485) * [misc] add submodule * remove submodule * [example] support grok-1 tp inference * [example] add grok-1 inference script * [example] refactor code * [example] add grok-1 readme * [exmaple] add test ci * [exmaple] update readme --------- Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: binmakeswell <binmakeswell@gmail.com> * [CI] run pre-commit (#5577) * fix * [release] update version (#5411) * [hotfix] fix typo s/keywrods/keywords etc. 
(#5429) * [devops] fix compatibility (#5444) * [devops] fix compatibility * [hotfix] update compatibility test on pr * [devops] fix compatibility * [devops] record duration during comp test * [test] decrease test duration * fix falcon * [shardformer] fix gathering output when using tensor parallelism (#5431) * fix * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix * fix fix fix * fix gather output * fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * revert * [doc] release Open-Sora 1.0 with model weights (#5468) * [doc] release Open-Sora 1.0 with model weights * [doc] release Open-Sora 1.0 with model weights * [doc] release Open-Sora 1.0 with model weights * [doc] update open-sora demo (#5479) * [doc] update open-sora demo * [doc] update open-sora demo * [doc] update open-sora demo * [example] add grok-1 inference (#5485) * [misc] add submodule * remove submodule * [example] support grok-1 tp inference * [example] add grok-1 inference script * [example] refactor code * [example] add grok-1 readme * [exmaple] add test ci * [exmaple] update readme * run pre-commit --------- Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: binmakeswell <binmakeswell@gmail.com> * [rebase] rebase main to resize-embedding (#5581) * [release] grok-1 314b inference (#5490) * [release] grok-1 inference * [release] grok-1 inference * [release] grok-1 inference * [example] update Grok-1 inference (#5495) * revise grok-1 example * remove unused arg in scripts * prevent re-installing torch * update readme * revert modifying colossalai requirements * add perf * trivial * add tokenizer url * [hotfix] set return_outputs=False in examples and polish code (#5404) * fix: simplify merge_batch * fix: use return_outputs=False to eliminate extra memory consumption * feat: add return_outputs warning * style: remove 
`return_outputs=False` as it is the default value * [release] grok-1 inference benchmark (#5500) * [release] grok-1 inference benchmark * [release] grok-1 inference benchmark * [release] grok-1 inference benchmark * [release] grok-1 inference benchmark * [release] grok-1 inference benchmark * [shardformer]Fix lm parallel. (#5480) * fix * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix * fix fix fix * fix gather output * fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * revert * fix lm forward distribution * fix * test ci * fix * [fix] fix grok-1 example typo (#5506) * [devops] fix example test ci (#5504) * Fix ColoTensorSpec for py11 (#5440) * fixed layout converter caching and updated tester * Empty-Commit * [shardformer] update colo attention to support custom mask (#5510) * [feature] refactor colo attention (#5462) * [extension] update api * [feature] add colo attention * [feature] update sdpa * [feature] update npu attention * [feature] update flash-attn * [test] add flash attn test * [test] update flash attn test * [shardformer] update modeling to fit colo attention (#5465) * [misc] refactor folder structure * [shardformer] update llama flash-attn * [shardformer] fix llama policy * [devops] update tensornvme install * [test] update llama test * [shardformer] update colo attn kernel dispatch * [shardformer] update blip2 * [shardformer] update chatglm * [shardformer] update gpt2 * [shardformer] update gptj * [shardformer] update opt * [shardformer] update vit * [shardformer] update colo attention mask prep * [shardformer] update whisper * [test] fix shardformer tests (#5514) * [test] fix shardformer tests * [test] fix shardformer tests * [format] applied code formatting on changed files in pull request 5510 (#5517) Co-authored-by: github-actions <github-actions@github.com> * [shardformer] fix pipeline forward error if custom layer 
distribution is used (#5189) * Use self.[distribute_layers|get_stage_index] to exploit custom layer distribution * Change static methods for t5 layer distribution to member functions * Change static methods for whisper layer distribution to member functions * Replace whisper policy usage with self one * Fix test case to use non-static layer distribution methods * fix: fix typo --------- Co-authored-by: Wenhao Chen <cwher@outlook.com> * [Fix] Grok-1 use tokenizer from the same pretrained path (#5532) * [fix] use tokenizer from the same pretrained path * trust remote code * [ColossalChat] Update RLHF V2 (#5286) * Add dpo. Fix sft, ppo, lora. Refactor all * fix and tested ppo * 2 nd round refactor * add ci tests * fix ci * fix ci * fix readme, style * fix readme style * fix style, fix benchmark * reproduce benchmark result, remove useless files * rename to ColossalChat * use new image * fix ci workflow * fix ci * use local model/tokenizer for ci tests * fix ci * fix ci * fix ci * fix ci timeout * fix rm progress bar. fix ci timeout * fix ci * fix ci typo * remove 3d plugin from ci temporary * test environment * cannot save optimizer * support chat template * fix readme * fix path * test ci locally * restore build_or_pr * fix ci data path * fix benchmark * fix ci, move ci tests to 3080, disable fast tokenizer * move ci to 85 * support flash attention 2 * add all-in-one data preparation script. 
Fix colossal-llama2-chat chat template * add hardware requirements * move ci test data * fix save_model, add unwrap * fix missing bos * fix missing bos; support grad accumulation with gemini * fix ci * fix ci * fix ci * fix llama2 chat template config * debug sft * debug sft * fix colossalai version requirement * fix ci * add sanity check to prevent NaN loss * fix requirements * add dummy data generation script * add dummy data generation script * add dummy data generation script * add dummy data generation script * update readme * update readme * update readme and ignore * fix logger bug * support parallel_output * modify data preparation logic * fix tokenization * update lr * fix inference * run pre-commit --------- Co-authored-by: Tong Li <tong.li352711588@gmail.com> * [shardformer, pipeline] add `gradient_checkpointing_ratio` and heterogenous shard policy for llama (#5508) * feat: add `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig` * feat: apply `GradientCheckpointConfig` to policy and llama_forward * feat: move `distribute_layer` and `get_stage_index` to PipelineStageManager * fix: add optional args for `distribute_layer` and `get_stage_index` * fix: fix changed API calls * test: update llama tests * style: polish `GradientCheckpointConfig` * fix: fix pipeline utils tests * fix incorrect sharding without zero (#5545) Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [shardformer] Sequence Parallelism Optimization (#5533) * sequence parallel optimization * validate sequence parallel in llama (code to be polished) * shardformer api writing * integrate sequence parallel in ShardFormer * fix pp bugs and sp bugs for LlaMa model * integrating ring-based sequence parallelism into ShardFormer * [sequence parallelism]: Add fused megatron function * integrating ring-based sequence parallelism into ShardFormer --------- Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn> * fix bugs when useing sp and flashattention together * fix operation function name 
* support flash attention for ulysses-style sp * clarify sp process group * fix compatibility bugs in moe plugin * fix fused linear bugs * fix linear layer test * support gpt model all-to-all sp * modify shard data dimension (meant to be dim=-1) * support megtron-style sp and distributed attn for llama model * [shardformer] add megatron sp to llama * support llama7B 128k with distributed attention * [shardformer] robustness enhancement * add block attn * sp mode 1: keep input as a complete sequence * fix sp compatability * finish sp mode 3 support for gpt * using all_to_all_single when batch size is 1 * support mode 2 sp in gpt2 (#5) * [shardformer] add megatron sp to llama * support llama7B 128k with distributed attention * [shardformer] robustness enhancement * add block attn * sp mode 1: keep input as a complete sequence * fix sp compatability * refactor ring implementation * support mode 2 sp in gpt2 * polish code * enable distributed attn mask when using sp mode 2 and 3 in llama * automatically enable flash attn when using sp mode 2 and 3 in llama * inplace attn mask * add zero2 support for sequence parallel * polish code * fix bugs * fix gemini checkpoint io * loose tensor checking atol and rtol * add comment * fix llama layernorm grad * fix zero grad * fix zero grad * fix conflict * update split and gather auto grad func * sequence parallel: inside text split (#6) * polish code (part 1) * polish code (part 2) * polish code (part 2.5) * polish code (part 3) * sequence parallel: inside text split * miscellaneous minor fixes * polish code * fix ulysses style ZeRO * sequence parallel: inside text split * miscellaneous minor fixes * disaggregate sp group and dp group for sp * fix llama and gpt sp * polish code * move ulysses grad sync to ddp (#9) * remove zero_stage and unbind the grad sync for alltoall sp * add 2d group creation test * move ulysses grad sync to ddp * add 2d group creation test * remove useless code * change shard config not to enable sp when 
enable_all_optimizations * add sp warnings for several model * remove useless code --------- Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn> * [hotfix] quick fixes to make legacy tutorials runnable (#5559) Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [fix] fix typo s/muiti-node /multi-node etc. (#5448) * [hotfix] fix typo s/get_defualt_parser /get_default_parser (#5548) * [devops] remove post commit ci (#5566) * [devops] remove post commit ci * [misc] run pre-commit on all files * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --------- Co-authored-by: binmakeswell <binmakeswell@gmail.com> Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Co-authored-by: Wenhao Chen <cwher@outlook.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: Rocky Duan <dementrock@users.noreply.github.com> Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions <github-actions@github.com> Co-authored-by: Insu Jang <insujang@umich.edu> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Co-authored-by: Tong Li <tong.li352711588@gmail.com> Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com> Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [shardformer]enable padding vocabulary size. 
(#5489) * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix * fix fix fix * fix gather output * fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * revert * padding vocab * padding vocabe * fix * fix * fxi * test ci * fix fix fix fix * fix fix * fix * fix * Update hybrid_parallel_plugin.py fix fix fix * fix fix * fix fix * fix * resolve super init resolve super init resolve super init resolve super init * resolve comments * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * vocab checkpointio * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix fix fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * padding vocab * fix * fix fix * fix fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * cherry-pick * revert moe modify * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix fix fix fix fix fix fix fix * resolve comments resolve comments resolve comments resolve comments resolve comments * ptensor ptensor resolve comments fix fix fix fix fix resolve comments resolve comments resolve comments resolve comments resolve comments --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix rebase * fix rebase --------- Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: binmakeswell 
<binmakeswell@gmail.com> Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Co-authored-by: Wenhao Chen <cwher@outlook.com> Co-authored-by: Rocky Duan <dementrock@users.noreply.github.com> Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions <github-actions@github.com> Co-authored-by: Insu Jang <insujang@umich.edu> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Co-authored-by: Tong Li <tong.li352711588@gmail.com> Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com> Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
362 lines
17 KiB
Python
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import itertools
from abc import ABC, abstractmethod
from typing import List, Optional, Union

import torch
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, Module

from colossalai.checkpoint_io.utils import gather_distributed_param
from colossalai.tensor.d_tensor import (
    distribute_tensor,
    distribute_tensor_with_customization,
    get_device_mesh,
    get_sharding_spec,
    is_customized_distributed_tensor,
    is_distributed_tensor,
    sharded_tensor_to_param,
)
from colossalai.tensor.padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor

__all__ = ["ParallelModule"]


class ParallelModule(nn.Module, ABC):
    def __init__(self, **kwargs):
        super().__init__()

    @abstractmethod
    def from_native_module(
        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]] = None
    ) -> "ParallelModule":
        """
        Convert a native PyTorch module to a parallelized module.

        Args:
            module (nn.Module): the module to be converted.
            process_group (ProcessGroup or list[ProcessGroup]): the process group(s) to be used for communication.
                If this is a list, the process group at the ith index of the list will correspond to the process group
                in the ith axis of the device mesh. Defaults to None, which means the global process group.
        """

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        r"""Saves module state to `destination` dictionary, containing a state
        of the module, but not its descendants. This is called on every
        submodule in :meth:`~torch.nn.Module.state_dict`.

        In rare cases, subclasses can achieve class-specific behavior by
        overriding this method with custom logic.

        Args:
            destination (dict): a dict where state will be stored
            prefix (str): the prefix for parameters and buffers used in this
                module
        """
        for name, param in self._parameters.items():
            if param is not None:
                destination[prefix + name] = gather_distributed_param(param, keep_vars=keep_vars).data

        for name, buf in self._buffers.items():
            if buf is not None and name not in self._non_persistent_buffers_set:
                destination[prefix + name] = buf if keep_vars else buf.detach()
        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
        if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state:
            destination[extra_state_key] = self.get_extra_state()

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        r"""Copies parameters and buffers from :attr:`state_dict` into only
        this module, but not its descendants. This is called on every submodule
        in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
        module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
        For state dicts without metadata, :attr:`local_metadata` is empty.
        Subclasses can achieve class-specific backward compatible loading using
        the version number at `local_metadata.get("version", None)`.

        .. note::
            :attr:`state_dict` is not the same object as the input
            :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
            it can be modified.

        Args:
            state_dict (dict): a dict containing parameters and
                persistent buffers.
            prefix (str): the prefix for parameters and buffers used in this
                module
            local_metadata (dict): a dict containing the metadata for this module.
                See
            strict (bool): whether to strictly enforce that the keys in
                :attr:`state_dict` with :attr:`prefix` match the names of
                parameters and buffers in this module
            missing_keys (list of str): if ``strict=True``, add missing keys to
                this list
            unexpected_keys (list of str): if ``strict=True``, add unexpected
                keys to this list
            error_msgs (list of str): error messages should be added to this
                list, and will be reported together in
                :meth:`~torch.nn.Module.load_state_dict`
        """
        for hook in self._load_state_dict_pre_hooks.values():
            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
        local_state = {k: v for k, v in local_name_params if v is not None}

        for name, param in local_state.items():
            key = prefix + name

            if key in state_dict:
                input_param = state_dict[key]
                if not torch.overrides.is_tensor_like(input_param):
                    error_msgs.append(
                        'While copying the parameter named "{}", '
                        "expected torch.Tensor or Tensor-like object from checkpoint but "
                        "received {}".format(key, type(input_param))
                    )
                    continue

                if is_distributed_tensor(param):
                    # shard the input param
                    device_mesh = get_device_mesh(param)
                    sharding_spec = get_sharding_spec(param)
                    sharded_tensor = distribute_tensor(input_param, device_mesh, sharding_spec)
                    input_param = sharded_tensor_to_param(sharded_tensor)
                elif is_customized_distributed_tensor(param):
                    input_param = distribute_tensor_with_customization(input_param, param.shard_fn, param.gather_fn)

                # This is used to avoid copying uninitialized parameters into
                # non-lazy modules, since they don't have the hook to do the checks;
                # in such a case, it will error when accessing the .shape attribute.
                is_param_lazy = torch.nn.parameter.is_lazy(param)
                # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
                if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
                    input_param = input_param[0]

                if not is_param_lazy and input_param.shape != param.shape:
                    # local shape should match the one in checkpoint
                    error_msgs.append(
                        "size mismatch for {}: copying a param with shape {} from checkpoint, "
                        "the shape in current model is {}.".format(key, input_param.shape, param.shape)
                    )
                    continue

                try:
                    with torch.no_grad():
                        param.copy_(input_param)
                except Exception as ex:
                    error_msgs.append(
                        'While copying the parameter named "{}", '
                        "whose dimensions in the model are {} and "
                        "whose dimensions in the checkpoint are {}, "
                        "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
                    )
            elif strict:
                missing_keys.append(key)

        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
        if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
            if extra_state_key in state_dict:
                self.set_extra_state(state_dict[extra_state_key])
            elif strict:
                missing_keys.append(extra_state_key)
        elif strict and (extra_state_key in state_dict):
            unexpected_keys.append(extra_state_key)

        if strict:
            for key in state_dict.keys():
                if key.startswith(prefix) and key != extra_state_key:
                    input_name = key[len(prefix) :]
                    input_name = input_name.split(".", 1)[0]  # get the name of param/buffer/child
                    if input_name not in self._modules and input_name not in local_state:
                        unexpected_keys.append(key)


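The strict-mode bookkeeping at the end of `ParallelModule._load_from_state_dict` can be hard to follow inside the full method, so here is a minimal pure-Python sketch of just that step. The helper `check_keys` and the sample checkpoint are hypothetical and are not part of ColossalAI or PyTorch; the sketch also leaves out the extra-state key handling that the real method performs.

```python
from typing import Dict, Iterable, List, Tuple


def check_keys(
    state_dict: Dict[str, object],
    prefix: str,
    local_names: Iterable[str],
    child_names: Iterable[str],
) -> Tuple[List[str], List[str]]:
    """Return (missing_keys, unexpected_keys) for one module under strict loading.

    Hypothetical helper: mirrors only the key checks, not the tensor copying.
    """
    local_names = list(local_names)
    child_names = set(child_names)

    # A local parameter/buffer is missing if its prefixed key is absent.
    missing = [prefix + name for name in local_names if prefix + name not in state_dict]

    unexpected = []
    for key in state_dict:
        if not key.startswith(prefix):
            continue  # this key belongs to some other module
        # First path component after the prefix: a param, buffer, or child name.
        head = key[len(prefix):].split(".", 1)[0]
        if head not in child_names and head not in local_names:
            unexpected.append(key)
    return missing, unexpected


# Made-up checkpoint keys for a module registered under the prefix "embed."
checkpoint = {"embed.weight": 1, "embed.extra": 2, "embed.proj.weight": 3, "head.weight": 4}
missing, unexpected = check_keys(checkpoint, "embed.", ["weight", "bias"], ["proj"])
print(missing)     # ['embed.bias']
print(unexpected)  # ['embed.extra']
```

Keys that belong to a child module (`embed.proj.weight` here) are not flagged, because each child runs the same check on its own prefix.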
class PaddingParallelModule(ParallelModule):
    def __init__(
        self,
        new_num_embeddings: int,
        old_num_embeddings: int,
        weight: Optional[nn.Parameter],
        bias_: Optional[nn.Parameter] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.new_num_embeddings = new_num_embeddings
        self.old_num_embeddings = old_num_embeddings
        self.weight = weight
        self.bias = bias_

        if not (is_distributed_tensor(self.weight) or self.weight.shape[0] == self.new_num_embeddings):
            self.resize_embedding_weight()

        if self.bias is not None and not (
            is_distributed_tensor(self.bias) or self.bias.shape[0] == self.new_num_embeddings
        ):
            self.resize_embedding_bias()

    @abstractmethod
    def from_native_module(
        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]] = None
    ) -> "PaddingParallelModule":
        """
        Convert a native PyTorch module to a parallelized module.

        Args:
            module (nn.Module): the module to be converted.
            process_group (ProcessGroup or list[ProcessGroup]): the process group(s) to be used for communication.
                If this is a list, the process group at the ith index of the list will correspond to the process group
                in the ith axis of the device mesh. Defaults to None, which means the global process group.
        """
        raise NotImplementedError

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        r"""Saves module state to `destination` dictionary, containing a state
        of the module, but not its descendants. This is called on every
        submodule in :meth:`~torch.nn.Module.state_dict`.

        In rare cases, subclasses can achieve class-specific behavior by
        overriding this method with custom logic.

        Args:
            destination (dict): a dict where state will be stored
            prefix (str): the prefix for parameters and buffers used in this
                module
        """
        for name, param in self._parameters.items():
            if param is not None:
                param = gather_distributed_param(param, keep_vars=keep_vars)
                if is_padded_tensor(param):
                    param = to_unpadded_tensor(param)
                destination[prefix + name] = param.data

        for name, buf in self._buffers.items():
            if buf is not None and name not in self._non_persistent_buffers_set:
                destination[prefix + name] = buf if keep_vars else buf.detach()
        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
        if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state:
            destination[extra_state_key] = self.get_extra_state()

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        r"""Copies parameters and buffers from :attr:`state_dict` into only
        this module, but not its descendants. This is called on every submodule
        in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
        module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
        For state dicts without metadata, :attr:`local_metadata` is empty.
        Subclasses can achieve class-specific backward compatible loading using
        the version number at `local_metadata.get("version", None)`.

        .. note::
            :attr:`state_dict` is not the same object as the input
            :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
            it can be modified.

        Args:
            state_dict (dict): a dict containing parameters and
                persistent buffers.
            prefix (str): the prefix for parameters and buffers used in this
                module
            local_metadata (dict): a dict containing the metadata for this module.
                See
            strict (bool): whether to strictly enforce that the keys in
                :attr:`state_dict` with :attr:`prefix` match the names of
                parameters and buffers in this module
            missing_keys (list of str): if ``strict=True``, add missing keys to
                this list
            unexpected_keys (list of str): if ``strict=True``, add unexpected
                keys to this list
            error_msgs (list of str): error messages should be added to this
                list, and will be reported together in
                :meth:`~torch.nn.Module.load_state_dict`
        """
        for hook in self._load_state_dict_pre_hooks.values():
            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
        local_state = {k: v for k, v in local_name_params if v is not None}

        for name, param in local_state.items():
            key = prefix + name

            if key in state_dict:
                input_param = state_dict[key]
                if not torch.overrides.is_tensor_like(input_param):
                    error_msgs.append(
                        'While copying the parameter named "{}", '
                        "expected torch.Tensor or Tensor-like object from checkpoint but "
                        "received {}".format(key, type(input_param))
                    )
                    continue

                # pad the checkpoint tensor so it matches the padded local parameter
                if is_padded_tensor(param):
                    input_param = to_padded_tensor(input_param, param._current_length, param._padding_dim)

                if is_distributed_tensor(param):
                    # shard the input param
                    device_mesh = get_device_mesh(param)
                    sharding_spec = get_sharding_spec(param)
                    sharded_tensor = distribute_tensor(input_param, device_mesh, sharding_spec)
                    input_param = sharded_tensor_to_param(sharded_tensor)
                elif is_customized_distributed_tensor(param):
                    input_param = distribute_tensor_with_customization(input_param, param.shard_fn, param.gather_fn)

                # This is used to avoid copying uninitialized parameters into
                # non-lazy modules, since they don't have the hook to do the checks;
                # in such a case, it will error when accessing the .shape attribute.
                is_param_lazy = torch.nn.parameter.is_lazy(param)
                # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
                if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
                    input_param = input_param[0]

                if not is_param_lazy and input_param.shape != param.shape:
                    # local shape should match the one in checkpoint
                    error_msgs.append(
                        "size mismatch for {}: copying a param with shape {} from checkpoint, "
                        "the shape in current model is {}.".format(key, input_param.shape, param.shape)
                    )
                    continue

                try:
                    with torch.no_grad():
                        param.copy_(input_param)
                except Exception as ex:
                    error_msgs.append(
                        'While copying the parameter named "{}", '
                        "whose dimensions in the model are {} and "
                        "whose dimensions in the checkpoint are {}, "
                        "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
                    )
            elif strict:
                missing_keys.append(key)

        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
        if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
            if extra_state_key in state_dict:
                self.set_extra_state(state_dict[extra_state_key])
            elif strict:
                missing_keys.append(extra_state_key)
        elif strict and (extra_state_key in state_dict):
            unexpected_keys.append(extra_state_key)

        if strict:
            for key in state_dict.keys():
                if key.startswith(prefix) and key != extra_state_key:
                    input_name = key[len(prefix) :]
                    input_name = input_name.split(".", 1)[0]  # get the name of param/buffer/child
                    if input_name not in self._modules and input_name not in local_state:
                        unexpected_keys.append(key)

    def resize_embedding_weight(self):
        self.weight = to_padded_tensor(self.weight, self.new_num_embeddings, 0)

    def resize_embedding_bias(self):
        self.bias = to_padded_tensor(self.bias, self.new_num_embeddings, 0)
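`PaddingParallelModule` pads its weight and bias along dimension 0 up to `new_num_embeddings`, unpads them when saving, and re-pads checkpoint tensors when loading. The sketch below illustrates that round trip on plain nested lists, so it runs without PyTorch or ColossalAI. The helpers `pad_rows` and `unpad_rows` are hypothetical stand-ins for `to_padded_tensor` and `to_unpadded_tensor`, and filling the padding with zeros is an assumption rather than something the listing above confirms.

```python
from typing import List

Matrix = List[List[float]]


def pad_rows(weight: Matrix, new_num_embeddings: int) -> Matrix:
    """Append zero rows until the matrix has new_num_embeddings rows (assumed fill value)."""
    if new_num_embeddings < len(weight):
        raise ValueError("new_num_embeddings must not be smaller than the current row count")
    width = len(weight[0]) if weight else 0
    extra = [[0.0] * width for _ in range(new_num_embeddings - len(weight))]
    return [row[:] for row in weight] + extra


def unpad_rows(weight: Matrix, old_num_embeddings: int) -> Matrix:
    """Drop the padding rows, keeping only the first old_num_embeddings rows."""
    return [row[:] for row in weight[:old_num_embeddings]]


original = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # vocabulary of 3, hidden size 2
padded = pad_rows(original, 4)       # what the module would hold in memory
saved = unpad_rows(padded, 3)        # what a checkpoint would store
reloaded = pad_rows(saved, 4)        # what loading would copy back into the module

print(len(padded), len(saved), len(reloaded))
```

Because the checkpoint holds the unpadded shape, it stays interchangeable with a checkpoint from the unpadded model; only the in-memory parameter carries the extra rows.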