mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-31 16:40:41 +00:00
[shardformer] refactor embedding resize (#5603)
* [branch rebase] rebase main to Feature/resize_embedding (#5554) * fix * [release] update version (#5411) * [hotfix] fix typo s/keywrods/keywords etc. (#5429) * [devops] fix compatibility (#5444) * [devops] fix compatibility * [hotfix] update compatibility test on pr * [devops] fix compatibility * [devops] record duration during comp test * [test] decrease test duration * fix falcon * [shardformer] fix gathering output when using tensor parallelism (#5431) * fix * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix * fix fix fix * fix gather output * fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * revert * [doc] release Open-Sora 1.0 with model weights (#5468) * [doc] release Open-Sora 1.0 with model weights * [doc] release Open-Sora 1.0 with model weights * [doc] release Open-Sora 1.0 with model weights * [doc] update open-sora demo (#5479) * [doc] update open-sora demo * [doc] update open-sora demo * [doc] update open-sora demo * [example] add grok-1 inference (#5485) * [misc] add submodule * remove submodule * [example] support grok-1 tp inference * [example] add grok-1 inference script * [example] refactor code * [example] add grok-1 readme * [exmaple] add test ci * [exmaple] update readme --------- Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: binmakeswell <binmakeswell@gmail.com> * [CI] run pre-commit (#5577) * fix * [release] update version (#5411) * [hotfix] fix typo s/keywrods/keywords etc. (#5429) * [devops] fix compatibility (#5444) * [devops] fix compatibility * [hotfix] update compatibility test on pr * [devops] fix compatibility * [devops] record duration during comp test * [test] decrease test duration * fix falcon * [shardformer] fix gathering output when using tensor parallelism (#5431) * fix * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix * fix fix fix * fix gather output * fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * revert * [doc] release Open-Sora 1.0 with model weights (#5468) * [doc] release Open-Sora 1.0 with model weights * [doc] release Open-Sora 1.0 with model weights * [doc] release Open-Sora 1.0 with model weights * [doc] update open-sora demo (#5479) * [doc] update open-sora demo * [doc] update open-sora demo * [doc] update open-sora demo * [example] add grok-1 inference (#5485) * [misc] add submodule * remove submodule * [example] support grok-1 tp inference * [example] add grok-1 inference script * [example] refactor code * [example] add grok-1 readme * [exmaple] add test ci * [exmaple] update readme * run pre-commit --------- Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: binmakeswell <binmakeswell@gmail.com> * [rebase] rebase main to resize-embedding (#5581) * [release] grok-1 314b inference (#5490) * [release] grok-1 inference * [release] grok-1 inference * [release] grok-1 inference * [example] update Grok-1 inference (#5495) * revise grok-1 example * remove unused arg in scripts * prevent re-installing torch * update readme * revert modifying colossalai requirements * add perf * trivial * add tokenizer url * [hotfix] set return_outputs=False in examples and polish code (#5404) * fix: simplify merge_batch * fix: use return_outputs=False to eliminate extra memory consumption * feat: add return_outputs warning * style: remove `return_outputs=False` as it is the default value * [release] grok-1 inference benchmark (#5500) * [release] grok-1 inference benchmark * [release] grok-1 inference benchmark * [release] grok-1 inference benchmark * [release] grok-1 inference benchmark * [release] grok-1 inference benchmark * [shardformer]Fix lm parallel. (#5480) * fix * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix * fix fix fix * fix gather output * fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * revert * fix lm forward distribution * fix * test ci * fix * [fix] fix grok-1 example typo (#5506) * [devops] fix example test ci (#5504) * Fix ColoTensorSpec for py11 (#5440) * fixed layout converter caching and updated tester * Empty-Commit * [shardformer] update colo attention to support custom mask (#5510) * [feature] refactor colo attention (#5462) * [extension] update api * [feature] add colo attention * [feature] update sdpa * [feature] update npu attention * [feature] update flash-attn * [test] add flash attn test * [test] update flash attn test * [shardformer] update modeling to fit colo attention (#5465) * [misc] refactor folder structure * [shardformer] update llama flash-attn * [shardformer] fix llama policy * [devops] update tensornvme install * [test] update llama test * [shardformer] update colo attn kernel dispatch * [shardformer] update blip2 * [shardformer] update chatglm * [shardformer] update gpt2 * [shardformer] update gptj * [shardformer] update opt * [shardformer] update vit * [shardformer] update colo attention mask prep * [shardformer] update whisper * [test] fix shardformer tests (#5514) * [test] fix shardformer tests * [test] fix shardformer tests * [format] applied code formatting on changed files in pull request 5510 (#5517) Co-authored-by: github-actions <github-actions@github.com> * [shardformer] fix pipeline forward error if custom layer distribution is used (#5189) * Use self.[distribute_layers|get_stage_index] to exploit custom layer distribution * Change static methods for t5 layer distribution to member functions * Change static methods for whisper layer distribution to member functions * Replace whisper policy usage with self one * Fix test case to use non-static layer distribution methods * fix: fix typo --------- Co-authored-by: Wenhao Chen <cwher@outlook.com> * [Fix] Grok-1 use tokenizer from the same pretrained path (#5532) * [fix] use tokenizer from the same pretrained path * trust remote code * [ColossalChat] Update RLHF V2 (#5286) * Add dpo. Fix sft, ppo, lora. Refactor all * fix and tested ppo * 2 nd round refactor * add ci tests * fix ci * fix ci * fix readme, style * fix readme style * fix style, fix benchmark * reproduce benchmark result, remove useless files * rename to ColossalChat * use new image * fix ci workflow * fix ci * use local model/tokenizer for ci tests * fix ci * fix ci * fix ci * fix ci timeout * fix rm progress bar. fix ci timeout * fix ci * fix ci typo * remove 3d plugin from ci temporary * test environment * cannot save optimizer * support chat template * fix readme * fix path * test ci locally * restore build_or_pr * fix ci data path * fix benchmark * fix ci, move ci tests to 3080, disable fast tokenizer * move ci to 85 * support flash attention 2 * add all-in-one data preparation script. Fix colossal-llama2-chat chat template * add hardware requirements * move ci test data * fix save_model, add unwrap * fix missing bos * fix missing bos; support grad accumulation with gemini * fix ci * fix ci * fix ci * fix llama2 chat template config * debug sft * debug sft * fix colossalai version requirement * fix ci * add sanity check to prevent NaN loss * fix requirements * add dummy data generation script * add dummy data generation script * add dummy data generation script * add dummy data generation script * update readme * update readme * update readme and ignore * fix logger bug * support parallel_output * modify data preparation logic * fix tokenization * update lr * fix inference * run pre-commit --------- Co-authored-by: Tong Li <tong.li352711588@gmail.com> * [shardformer, pipeline] add `gradient_checkpointing_ratio` and heterogenous shard policy for llama (#5508) * feat: add `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig` * feat: apply `GradientCheckpointConfig` to policy and llama_forward * feat: move `distribute_layer` and `get_stage_index` to PipelineStageManager * fix: add optional args for `distribute_layer` and `get_stage_index` * fix: fix changed API calls * test: update llama tests * style: polish `GradientCheckpointConfig` * fix: fix pipeline utils tests * fix incorrect sharding without zero (#5545) Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [shardformer] Sequence Parallelism Optimization (#5533) * sequence parallel optimization * validate sequence parallel in llama (code to be polished) * shardformer api writing * integrate sequence parallel in ShardFormer * fix pp bugs and sp bugs for LlaMa model * integrating ring-based sequence parallelism into ShardFormer * [sequence parallelism]: Add fused megatron function * integrating ring-based sequence parallelism into ShardFormer --------- Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn> * fix bugs when useing sp and flashattention together * fix operation function name * support flash attention for ulysses-style sp * clarify sp process group * fix compatibility bugs in moe plugin * fix fused linear bugs * fix linear layer test * support gpt model all-to-all sp * modify shard data dimension (meant to be dim=-1) * support megtron-style sp and distributed attn for llama model * [shardformer] add megatron sp to llama * support llama7B 128k with distributed attention * [shardformer] robustness enhancement * add block attn * sp mode 1: keep input as a complete sequence * fix sp compatability * finish sp mode 3 support for gpt * using all_to_all_single when batch size is 1 * support mode 2 sp in gpt2 (#5) * [shardformer] add megatron sp to llama * support llama7B 128k with distributed attention * [shardformer] robustness enhancement * add block attn * sp mode 1: keep input as a complete sequence * fix sp compatability * refactor ring implementation * support mode 2 sp in gpt2 * polish code * enable distributed attn mask when using sp mode 2 and 3 in llama * automatically enable flash attn when using sp mode 2 and 3 in llama * inplace attn mask * add zero2 support for sequence parallel * polish code * fix bugs * fix gemini checkpoint io * loose tensor checking atol and rtol * add comment * fix llama layernorm grad * fix zero grad * fix zero grad * fix conflict * update split and gather auto grad func * sequence parallel: inside text split (#6) * polish code (part 1) * polish code (part 2) * polish code (part 2.5) * polish code (part 3) * sequence parallel: inside text split * miscellaneous minor fixes * polish code * fix ulysses style ZeRO * sequence parallel: inside text split * miscellaneous minor fixes * disaggregate sp group and dp group for sp * fix llama and gpt sp * polish code * move ulysses grad sync to ddp (#9) * remove zero_stage and unbind the grad sync for alltoall sp * add 2d group creation test * move ulysses grad sync to ddp * add 2d group creation test * remove useless code * change shard config not to enable sp when enable_all_optimizations * add sp warnings for several model * remove useless code --------- Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn> * [hotfix] quick fixes to make legacy tutorials runnable (#5559) Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [fix] fix typo s/muiti-node /multi-node etc. (#5448) * [hotfix] fix typo s/get_defualt_parser /get_default_parser (#5548) * [devops] remove post commit ci (#5566) * [devops] remove post commit ci * [misc] run pre-commit on all files * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --------- Co-authored-by: binmakeswell <binmakeswell@gmail.com> Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Co-authored-by: Wenhao Chen <cwher@outlook.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: Rocky Duan <dementrock@users.noreply.github.com> Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions <github-actions@github.com> Co-authored-by: Insu Jang <insujang@umich.edu> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Co-authored-by: Tong Li <tong.li352711588@gmail.com> Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com> Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [shardformer]enable padding vocabulary size. (#5489) * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix * fix fix fix * fix gather output * fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * revert * padding vocab * padding vocabe * fix * fix * fxi * test ci * fix fix fix fix * fix fix * fix * fix * Update hybrid_parallel_plugin.py fix fix fix * fix fix * fix fix * fix * resolve super init resolve super init resolve super init resolve super init * resolve comments * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * vocab checkpointio * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix fix fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * padding vocab * fix * fix fix * fix fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * cherry-pick * revert moe modify * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix fix fix fix fix fix fix fix * resolve comments resolve comments resolve comments resolve comments resolve comments * ptensor ptensor resolve comments fix fix fix fix fix resolve comments resolve comments resolve comments resolve comments resolve comments --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix rebase * fix rebase --------- Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: binmakeswell <binmakeswell@gmail.com> Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Co-authored-by: Wenhao Chen <cwher@outlook.com> Co-authored-by: Rocky Duan <dementrock@users.noreply.github.com> Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions <github-actions@github.com> Co-authored-by: Insu Jang <insujang@umich.edu> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Co-authored-by: Tong Li <tong.li352711588@gmail.com> Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com> Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -27,6 +27,12 @@ from colossalai.tensor.d_tensor import (
|
||||
is_customized_distributed_tensor,
|
||||
is_distributed_tensor,
|
||||
)
|
||||
from colossalai.tensor.padded_tensor import (
|
||||
init_as_padded_tensor,
|
||||
is_padded_tensor,
|
||||
to_padded_tensor,
|
||||
to_unpadded_tensor,
|
||||
)
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
||||
from colossalai.utils import _cast_float, free_storage, is_ddp_ignored
|
||||
|
||||
@@ -460,6 +466,11 @@ class GeminiDDP(ModelWrapper):
|
||||
record_tensor, shard_fn=tensor.shard_fn, gather_fn=tensor.gather_fn
|
||||
)
|
||||
record_tensor = gather_distributed_param(record_tensor, keep_vars=False).cpu()
|
||||
if is_padded_tensor(tensor):
|
||||
record_tensor = init_as_padded_tensor(
|
||||
record_tensor, tensor._current_length, tensor._origin_length, tensor._padding_dim
|
||||
)
|
||||
record_tensor = to_unpadded_tensor(record_tensor)
|
||||
|
||||
assert tensor not in chunk_to_save_data
|
||||
chunk_to_save_data[tensor] = record_tensor
|
||||
@@ -520,6 +531,8 @@ class GeminiDDP(ModelWrapper):
|
||||
# deal with ddp ignored parameters
|
||||
destination[prefix + name] = param if keep_vars else param.detach()
|
||||
else:
|
||||
if is_padded_tensor(p_mapping[param]):
|
||||
p_mapping[param] = to_unpadded_tensor(p_mapping[param])
|
||||
destination[prefix + name] = p_mapping[param]
|
||||
del p_mapping
|
||||
del param_to_save_data
|
||||
@@ -627,6 +640,7 @@ class GeminiDDP(ModelWrapper):
|
||||
list, and will be reported together in
|
||||
:meth:`~torch.nn.Module.load_state_dict`
|
||||
"""
|
||||
|
||||
for hook in self._load_state_dict_pre_hooks.values():
|
||||
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
@@ -647,6 +661,14 @@ class GeminiDDP(ModelWrapper):
|
||||
if state_key in state_dict:
|
||||
input_param = state_dict[state_key]
|
||||
|
||||
global_shape = dest_tensor.shape
|
||||
if source_device_mesh is not None and source_sharding_spec is not None:
|
||||
global_shape = get_global_shape(dest_tensor)
|
||||
|
||||
if is_padded_tensor(dest_tensor):
|
||||
padding_dim = dest_tensor._padding_dim
|
||||
input_param = to_padded_tensor(input_param, global_shape[padding_dim], padding_dim)
|
||||
|
||||
if source_device_mesh is not None and source_sharding_spec is not None:
|
||||
input_param = distribute_tensor(input_param, source_device_mesh, source_sharding_spec)
|
||||
elif shard_fn is not None and gather_fn is not None:
|
||||
|
@@ -21,12 +21,19 @@ from colossalai.tensor.d_tensor import (
|
||||
distribute_tensor,
|
||||
distribute_tensor_with_customization,
|
||||
get_device_mesh,
|
||||
get_global_shape,
|
||||
get_sharding_spec,
|
||||
init_as_dtensor,
|
||||
init_tensor_as_customization_distributed,
|
||||
is_customized_distributed_tensor,
|
||||
is_distributed_tensor,
|
||||
)
|
||||
from colossalai.tensor.padded_tensor import (
|
||||
init_as_padded_tensor,
|
||||
is_padded_tensor,
|
||||
to_padded_tensor,
|
||||
to_unpadded_tensor,
|
||||
)
|
||||
from colossalai.utils import disposable, is_ddp_ignored
|
||||
|
||||
from .chunk import Chunk, ChunkManager
|
||||
@@ -106,7 +113,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
max_norm: float = 0.0,
|
||||
norm_type: float = 2.0,
|
||||
tp_group: ProcessGroup = None,
|
||||
optimizer_params_info=None,
|
||||
params_info=None,
|
||||
verbose: bool = False,
|
||||
**defaults: Any,
|
||||
):
|
||||
@@ -124,7 +131,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
self.clipping_flag = max_norm > 0.0
|
||||
self.max_norm = max_norm
|
||||
self.tp_group = tp_group
|
||||
self.optimizer_params_info = optimizer_params_info
|
||||
self.params_info = params_info
|
||||
self.tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1
|
||||
self.tp_rank = dist.get_rank(tp_group) if tp_group is not None else 0
|
||||
self.verbose = verbose
|
||||
@@ -459,7 +466,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
is_customized_distributed = is_customized_distributed_tensor(param)
|
||||
shard_spec = get_sharding_spec(param) if is_dtensor else None
|
||||
device_mesh = get_device_mesh(param) if is_dtensor else None
|
||||
global_shape = self.optimizer_params_info["id2shape"][param_id]
|
||||
global_shape = self.params_info["id2shape"][param_id]
|
||||
|
||||
# If the chunk is kept gathered,
|
||||
# the parameters are treated the same as that of those in strict DDP during training.
|
||||
@@ -477,6 +484,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
else:
|
||||
state_tensor = states[state_name].detach().clone().to(torch.float32).cpu()
|
||||
if is_dtensor:
|
||||
global_shape = get_global_shape(param)
|
||||
state_tensor = torch.reshape(state_tensor, param.shape).to(param.device)
|
||||
state_tensor = init_as_dtensor(
|
||||
state_tensor,
|
||||
@@ -490,8 +498,13 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn
|
||||
)
|
||||
state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
|
||||
|
||||
collected_states[state_name] = state_tensor.reshape(global_shape)
|
||||
state_tensor = state_tensor.reshape(global_shape)
|
||||
if is_padded_tensor(param):
|
||||
state_tensor = init_as_padded_tensor(
|
||||
state_tensor, param._current_length, param._origin_length, param._padding_dim
|
||||
)
|
||||
state_tensor = to_unpadded_tensor(state_tensor)
|
||||
collected_states[state_name] = state_tensor
|
||||
return collected_states
|
||||
|
||||
# Check whether the param with given id is managed by current process.
|
||||
@@ -535,6 +548,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
if state_tensor.numel() == param.numel():
|
||||
collected_states[state_name] = torch.reshape(state_tensor, param.shape)
|
||||
if is_dtensor:
|
||||
global_shape = get_global_shape(param)
|
||||
state_tensor = state_tensor.to(param.device)
|
||||
state_tensor = init_as_dtensor(
|
||||
state_tensor, sharding_spec=shard_spec, device_mesh=device_mesh, global_shape=global_shape
|
||||
@@ -545,6 +559,11 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn
|
||||
)
|
||||
state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
|
||||
if is_padded_tensor(param):
|
||||
state_tensor = init_as_padded_tensor(
|
||||
state_tensor, param._current_length, param._origin_length, param._padding_dim
|
||||
)
|
||||
state_tensor = to_unpadded_tensor(state_tensor)
|
||||
|
||||
return collected_states
|
||||
|
||||
@@ -698,7 +717,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
Load saved optimizer states into parameter with given id.
|
||||
"""
|
||||
|
||||
def cast(param, state_range, value, key=None):
|
||||
def cast(param, state_range, value, global_shape, origin_shape, key=None):
|
||||
"""
|
||||
Make a copy of the needed segment of value and cast it to device of param.
|
||||
"""
|
||||
@@ -714,7 +733,14 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
)
|
||||
|
||||
if is_dtensor:
|
||||
value = torch.reshape(value, global_shape)
|
||||
global_shape = get_global_shape(real_param)
|
||||
|
||||
if is_padded_tensor(real_param):
|
||||
value = torch.reshape(value, origin_shape)
|
||||
padding_dim = real_param._padding_dim
|
||||
value = to_padded_tensor(value, global_shape[padding_dim], padding_dim)
|
||||
|
||||
if is_dtensor:
|
||||
value = distribute_tensor(value, sharding_spec=shard_spec, device_mesh=device_mesh)
|
||||
elif is_customized_distributed:
|
||||
value = torch.reshape(value, global_shape)
|
||||
@@ -737,10 +763,11 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
is_customized_distributed = is_customized_distributed_tensor(real_param)
|
||||
shard_spec = get_sharding_spec(real_param) if is_dtensor else None
|
||||
device_mesh = get_device_mesh(real_param) if is_dtensor else None
|
||||
global_shape = self.optimizer_params_info["id2shape"][param_id]
|
||||
global_shape = self.params_info["id2shape"][param_id]
|
||||
origin_shape = global_shape
|
||||
|
||||
for k, v in saved_states.items():
|
||||
updated_states[k] = cast(fake_param, state_range, v, k)
|
||||
updated_states[k] = cast(fake_param, state_range, v, global_shape, origin_shape, k)
|
||||
del v # clean loaded states
|
||||
self.optim.state[fake_param].update(updated_states)
|
||||
|
||||
|
Reference in New Issue
Block a user