Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-18 16:00:49 +00:00)
.github
applications
colossalai
    _C
    _analyzer
    accelerator
    amp
    auto_parallel
    autochunk
    booster
    checkpoint_io
    cli
    cluster
    context
    device
    fx
    inference
    interface
    kernel
    lazy
    legacy
        amp
        builder
        communication
        context
        engine
        inference
        nn
        pipeline
        registry
        tensor
        trainer
        utils
        zero
            gemini
            init_ctx
            shard_utils
            sharded_model
                __init__.py
                _utils.py
                reduce_scatter.py
                sharded_model_v2.py
                utils.py
                zero_hook.py
            sharded_optim
            sharded_param
            __init__.py
        __init__.py
        constants.py
        core.py
        global_variables.py
        initialize.py
    logging
    moe
    nn
    pipeline
    shardformer
    tensor
    testing
    utils
    zero
    __init__.py
    initialize.py
docker
docs
examples
extensions
requirements
tests
.clang-format
.compatibility
.coveragerc
.cuda_ext.json
.gitignore
.gitmodules
.isort.cfg
.pre-commit-config.yaml
CHANGE_LOG.md
CONTRIBUTING.md
LICENSE
MANIFEST.in
README.md
pytest.ini
setup.py
version.txt
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
21 lines · 808 B · Python
import copy

import torch

from colossalai.legacy.zero.sharded_model import ShardedModelV2


def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Module):
    """
    Copy the parameters of a ShardedModelV2 into ``other_model``.

    Note that ``other_model`` must have the same structure as ``sharded_model``.
    """
    for zero_param, param in zip(sharded_model.parameters(), other_model.parameters()):
        assert hasattr(zero_param, "colo_attr")
        shard_flag = zero_param.colo_attr.sharded_data_tensor.is_sharded
        if shard_flag:
            # Temporarily gather the sharded tensor so the full payload is available.
            sharded_model.shard_strategy.gather([zero_param.colo_attr.sharded_data_tensor])
        param.data = copy.deepcopy(zero_param.colo_attr.data_payload)
        if shard_flag:
            # Re-shard to restore the original memory layout.
            sharded_model.shard_strategy.shard([zero_param.colo_attr.sharded_data_tensor])
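A minimal usage sketch, assuming a distributed/ZeRO context is already initialized and that ShardedModelV2 wraps a module together with a shard strategy; the constructor arguments and the Linear example below are assumptions for illustration, not taken from this file:

import copy

import torch.nn as nn

# Assumed import paths, following the legacy zero package layout shown in the tree above.
from colossalai.legacy.zero.shard_utils import TensorShardStrategy
from colossalai.legacy.zero.sharded_model import ShardedModelV2
from colossalai.legacy.zero.sharded_model.utils import col_model_deepcopy

# Two structurally identical modules: one to be wrapped and sharded, one to receive the copy.
template = nn.Linear(8, 8)
target = copy.deepcopy(template)

# Hypothetical construction; in practice ShardedModelV2 expects an initialized
# distributed environment (and typically a module built under ZeroInitContext),
# which is elided here.
sharded_model = ShardedModelV2(template, TensorShardStrategy())

# Gather each sharded parameter, deep-copy its payload into `target`, then re-shard.
col_model_deepcopy(sharded_model, target)

Because the gather/shard pair brackets each individual copy inside the loop, only one parameter is fully materialized at a time, which keeps the peak memory overhead of the copy small.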