[Gemini] make gemini usage simple (#1821)

This commit is contained in:
Jiarui Fang
2022-11-08 15:53:13 +08:00
committed by GitHub
parent 99870726b1
commit cd5a0d56fa
4 changed files with 49 additions and 21 deletions

View File

@@ -24,7 +24,6 @@ https://huggingface.co/models?filter=text-generation
import math
import os
import random
import time
from itertools import chain
@@ -43,7 +42,6 @@ import colossalai
import transformers
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.gemini import ChunkManager, GeminiManager
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
@@ -380,11 +378,8 @@ def main():
cai_version = colossalai.__version__
logger.info(f'using Colossal-AI version {cai_version}')
if version.parse(cai_version) > version.parse("0.1.10"):
from colossalai.gemini import GeminiManager
from colossalai.gemini.chunk import init_chunk_manager
chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=32)
gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
from colossalai.nn.parallel import GeminiDDP
model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True)
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
from colossalai.gemini import ChunkManager, GeminiManager
pg = ProcessGroup()
@@ -393,6 +388,8 @@ def main():
pg,
enable_distributed_storage=True,
init_device=GeminiManager.get_default_device(PLACEMENT_POLICY))
gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager)
model = ZeroDDP(model, gemini_manager)
logger.info(f'{model.__class__.__name__} has been created', ranks=[0])