mirror of https://github.com/hpcaitech/ColossalAI.git synced 2025-09-18 16:00:49 +00:00
Hongxin Liu 079bf3cb26 [misc] update pre-commit and run all files ()
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
2023-09-19 14:20:26 +08:00


import copy

import torch

from colossalai.legacy.zero.sharded_model import ShardedModelV2


def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Module):
    """
    Copy the parameters of a ShardedModelV2 into other_model.
    Note that other_model must have the same architecture as the module wrapped by sharded_model.
    """
    for zero_param, param in zip(sharded_model.parameters(), other_model.parameters()):
        assert hasattr(zero_param, "colo_attr")
        shard_flag = zero_param.colo_attr.sharded_data_tensor.is_sharded
        if shard_flag:
            # Gather the full tensor across ranks before copying a sharded parameter.
            sharded_model.shard_strategy.gather([zero_param.colo_attr.sharded_data_tensor])
        param.data = copy.deepcopy(zero_param.colo_attr.data_payload)
        if shard_flag:
            # Re-shard to restore the original memory footprint.
            sharded_model.shard_strategy.shard([zero_param.colo_attr.sharded_data_tensor])
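
For context, a minimal usage sketch (not part of this file): it builds an unsharded replica of the wrapped architecture and uses col_model_deepcopy to pull the gathered parameters into it, for example to save a plain state_dict. The helper name export_full_model and the model_builder callable are illustrative, and the import path assumes this file is colossalai/legacy/zero/sharded_model/utils.py.

import torch

from colossalai.legacy.zero.sharded_model import ShardedModelV2
# Assumed module path for the function defined above.
from colossalai.legacy.zero.sharded_model.utils import col_model_deepcopy


def export_full_model(zero_model: ShardedModelV2, model_builder) -> torch.nn.Module:
    """Materialize an unsharded copy of zero_model's parameters in a fresh module."""
    # model_builder (hypothetical) must re-create the same architecture that zero_model wraps.
    full_model = model_builder()
    # col_model_deepcopy gathers each shard, deep-copies the payload, then re-shards.
    col_model_deepcopy(zero_model, full_model)
    return full_model


# Example (sketch):
# torch.save(export_full_model(zero_model, build_model).state_dict(), "full_model.pt")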