mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 01:55:12 +00:00
[Gemini] make gemini usage simple (#1821)
This commit is contained in:
@@ -24,7 +24,6 @@ https://huggingface.co/models?filter=text-generation
|
||||
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from itertools import chain
|
||||
|
||||
@@ -43,7 +42,6 @@ import colossalai
|
||||
import transformers
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.gemini import ChunkManager, GeminiManager
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.nn.parallel import ZeroDDP
|
||||
@@ -380,11 +378,8 @@ def main():
|
||||
cai_version = colossalai.__version__
|
||||
logger.info(f'using Colossal-AI version {cai_version}')
|
||||
if version.parse(cai_version) > version.parse("0.1.10"):
|
||||
from colossalai.gemini import GeminiManager
|
||||
from colossalai.gemini.chunk import init_chunk_manager
|
||||
chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=32)
|
||||
gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager)
|
||||
model = ZeroDDP(model, gemini_manager, pin_memory=True)
|
||||
from colossalai.nn.parallel import GeminiDDP
|
||||
model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True)
|
||||
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
|
||||
from colossalai.gemini import ChunkManager, GeminiManager
|
||||
pg = ProcessGroup()
|
||||
@@ -393,6 +388,8 @@ def main():
|
||||
pg,
|
||||
enable_distributed_storage=True,
|
||||
init_device=GeminiManager.get_default_device(PLACEMENT_POLICY))
|
||||
gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager)
|
||||
model = ZeroDDP(model, gemini_manager)
|
||||
|
||||
logger.info(f'{model.__class__.__name__} has been created', ranks=[0])
|
||||
|
||||
|
Reference in New Issue
Block a user