[example] update examples related to zero/gemini (#3431)

* [zero] update legacy import

* [zero] update examples

* [example] fix opt tutorial

* [example] fix opt tutorial

* [example] fix opt tutorial

* [example] fix opt tutorial

* [example] fix import
This commit is contained in:
ver217
2023-04-04 17:32:51 +08:00
committed by GitHub
parent 773955abfa
commit 573af84184
8 changed files with 50 additions and 6 deletions

View File

@@ -1,4 +1,9 @@
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.nn.optimizer import FusedAdam
try:
from colossalai.zero.shard_utils import TensorShardStrategy
except ImportError:
# colossalai > 0.2.8
from colossalai.zero.legacy import TensorShardStrategy
clip_grad_norm = 1.0

View File

@@ -1,6 +1,11 @@
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.nn.optimizer import FusedAdam
try:
from colossalai.zero.shard_utils import TensorShardStrategy
except ImportError:
# colossalai > 0.2.8
from colossalai.zero.legacy import TensorShardStrategy
# fp16 = dict(
# mode=AMP_TYPE.TORCH,
# )
@@ -29,4 +34,4 @@ optimizer = dict(
weight_decay=1e-2,
)
# 64433
# 64433

View File

@@ -1,4 +1,8 @@
from colossalai.zero.shard_utils import TensorShardStrategy
try:
from colossalai.zero.shard_utils import TensorShardStrategy
except ImportError:
# colossalai > 0.2.8
from colossalai.zero.legacy import TensorShardStrategy
zero = dict(model_config=dict(shard_strategy=TensorShardStrategy(),
tensor_placement_policy="auto",

View File

@@ -4,3 +4,4 @@ datasets >= 1.8.0
sentencepiece != 0.1.92
protobuf
accelerate == 0.13.2
transformers

View File

@@ -413,7 +413,11 @@ def main():
cai_version = colossalai.__version__
logger.info(f'using Colossal-AI version {cai_version}')
if version.parse(cai_version) > version.parse("0.1.10"):
from colossalai.nn.parallel import GeminiDDP
try:
from colossalai.nn.parallel import GeminiDDP
except ImportError:
# this works for unreleased main branch, and this may be released on 0.2.9
from colossalai.zero import GeminiDDP
model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True)
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
from colossalai.gemini import ChunkManager, GeminiManager

View File

@@ -0,0 +1,21 @@
#!/bin/bash
set -xue
pip install -r requirements.txt
BS=8
MEMCAP=0
GPUNUM=2
MODLE="facebook/opt-125m"
torchrun \
--nproc_per_node ${GPUNUM} \
--master_port 19198 \
run_clm.py \
-s \
--output_dir $PWD \
--mem_cap ${MEMCAP} \
--model_name_or_path ${MODLE} \
--per_device_train_batch_size ${BS} \
--num_train_epochs 1

View File

@@ -0,0 +1,3 @@
#!/bin/bash
cd opt && bash test_ci.sh