Mirror of https://github.com/hpcaitech/ColossalAI.git
[zero] reorganize zero/gemini folder structure (#3424)
* [zero] refactor low-level zero folder structure
* [zero] fix legacy zero import path
* [zero] fix legacy zero import path
* [zero] remove useless import
* [zero] refactor gemini folder structure
* [zero] refactor gemini folder structure
* [zero] refactor legacy zero import path
* [zero] refactor gemini folder structure
* [zero] refactor gemini folder structure
* [zero] refactor gemini folder structure
* [zero] refactor legacy zero import path
* [zero] fix test import path
* [zero] fix test
* [zero] fix circular import
* [zero] update import
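In user code, the practical effect of this reorganization is that the Gemini entry points previously imported from scattered legacy locations are now exposed through colossalai.zero. A minimal before/after sketch, using only the paths visible in the diff below:

    # Before (legacy import paths removed by this refactor)
    # from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
    # from colossalai.nn.parallel import GeminiDDP
    # from colossalai.utils.model.colo_init_context import ColoInitContext

    # After (single consolidated entry point)
    from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP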
@@ -34,12 +34,9 @@ from transformers.utils.versions import require_version
 
 import colossalai
 from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
-from colossalai.nn.parallel import GeminiDDP
 from colossalai.tensor import ProcessGroup, ShardSpec
 from colossalai.utils import get_current_device
-from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP
 
 
 def get_data(batch_size, seq_len, vocab_size):
@@ -179,13 +176,15 @@ def main():
     # build model
     if args.model_name_or_path is None:
         logger.info("Train a new model from scratch", ranks=[0])
-        with ColoInitContext(device=init_dev, dtype=torch.half,
+        with ColoInitContext(device=init_dev,
+                             dtype=torch.half,
                              default_dist_spec=default_dist_spec,
                              default_pg=shard_pg):
             model = OPTForCausalLM(config)
     else:
         logger.info("Finetune a pre-trained model", ranks=[0])
-        with ColoInitContext(device=init_dev, dtype=torch.half,
+        with ColoInitContext(device=init_dev,
+                             dtype=torch.half,
                              default_dist_spec=default_dist_spec,
                              default_pg=shard_pg):
             model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
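For context, the init_dev, default_dist_spec and default_pg arguments passed to ColoInitContext above are built earlier in the example script; that setup is outside this hunk, so the following is only a sketch under the assumption that the script shards parameters across all ranks when sharded initialization is requested (the shardinit flag below is a hypothetical stand-in for args.shardinit):

    import torch.distributed as dist

    from colossalai.tensor import ProcessGroup, ShardSpec
    from colossalai.utils import get_current_device

    shardinit = True                     # hypothetical stand-in for args.shardinit
    world_size = dist.get_world_size()

    # Shard each parameter along its last dimension across all ranks while
    # the model is materialized inside ColoInitContext; otherwise keep defaults.
    shard_pg = ProcessGroup(tp_degree=world_size) if shardinit else None
    default_dist_spec = ShardSpec([-1], [world_size]) if shardinit else None
    init_dev = get_current_device()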
@@ -198,8 +197,11 @@ def main():
 
     numel = sum([p.numel() for p in model.parameters()])
     PLACEMENT_POLICY = 'cpu'
-    model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY,
-                      pin_memory=True, strict_ddp_mode=args.shardinit)
+    model = GeminiDDP(model,
+                      device=get_current_device(),
+                      placement_policy=PLACEMENT_POLICY,
+                      pin_memory=True,
+                      strict_ddp_mode=args.shardinit)
     optimizer = GeminiAdamOptimizer(model, lr=args.learning_rate, initial_scale=2**14, gpu_margin_mem_ratio=0.0)
 
     SEQ_LEN = 1024
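With the model wrapped in GeminiDDP and GeminiAdamOptimizer created as above, a training step goes through the optimizer-driven backward pattern of Gemini's ZeRO optimizer. A rough sketch continuing from the wrapped model and optimizer in this hunk; the toy batch and sizes below are assumptions standing in for the script's own data pipeline:

    import torch

    BATCH_SIZE, SEQ_LEN, VOCAB_SIZE = 4, 1024, 50272    # assumed sizes

    # Random token batch standing in for real data.
    input_ids = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN), device=get_current_device())
    attention_mask = torch.ones_like(input_ids)

    model.train()
    optimizer.zero_grad()
    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids, use_cache=False)
    loss = outputs['loss']
    # Gemini's ZeRO optimizer owns gradient scaling and chunk bookkeeping,
    # so the backward pass goes through the optimizer rather than loss.backward().
    optimizer.backward(loss)
    optimizer.step()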