[zero] reorganize zero/gemini folder structure (#3424)

* [zero] refactor low-level zero folder structure

* [zero] fix legacy zero import path

* [zero] fix legacy zero import path

* [zero] remove useless import

* [zero] refactor gemini folder structure

* [zero] refactor gemini folder structure

* [zero] refactor legacy zero import path

* [zero] refactor gemini folder structure

* [zero] refactor gemini folder structure

* [zero] refactor gemini folder structure

* [zero] refactor legacy zero import path

* [zero] fix test import path

* [zero] fix test

* [zero] fix circular import

* [zero] update import
This commit is contained in:
ver217
2023-04-04 13:48:16 +08:00
committed by GitHub
parent b09adff724
commit 26b7aac0be
142 changed files with 1435 additions and 1404 deletions

View File

@@ -34,12 +34,9 @@ from transformers.utils.versions import require_version
import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.parallel import GeminiDDP
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP
def get_data(batch_size, seq_len, vocab_size):
@@ -179,13 +176,15 @@ def main():
# build model
if args.model_name_or_path is None:
logger.info("Train a new model from scratch", ranks=[0])
with ColoInitContext(device=init_dev, dtype=torch.half,
with ColoInitContext(device=init_dev,
dtype=torch.half,
default_dist_spec=default_dist_spec,
default_pg=shard_pg):
model = OPTForCausalLM(config)
else:
logger.info("Finetune a pre-trained model", ranks=[0])
with ColoInitContext(device=init_dev, dtype=torch.half,
with ColoInitContext(device=init_dev,
dtype=torch.half,
default_dist_spec=default_dist_spec,
default_pg=shard_pg):
model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
@@ -198,8 +197,11 @@ def main():
numel = sum([p.numel() for p in model.parameters()])
PLACEMENT_POLICY = 'cpu'
model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY,
pin_memory=True, strict_ddp_mode=args.shardinit)
model = GeminiDDP(model,
device=get_current_device(),
placement_policy=PLACEMENT_POLICY,
pin_memory=True,
strict_ddp_mode=args.shardinit)
optimizer = GeminiAdamOptimizer(model, lr=args.learning_rate, initial_scale=2**14, gpu_margin_mem_ratio=0.0)
SEQ_LEN = 1024