[example] add TP to GPT example (#1828)

Author: Jiarui Fang
Date: 2022-11-08 17:17:19 +08:00
Committed by: GitHub
Parent: 49216d7ab1
Commit: a25f755331

4 changed files with 113 additions and 55 deletions


@@ -36,7 +36,6 @@ from datasets import load_dataset
 from packaging import version
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
-from utils import colo_memory_cap
 import colossalai
 import transformers
@@ -47,7 +46,6 @@ from colossalai.nn.optimizer import HybridAdam
 from colossalai.nn.parallel import ZeroDDP
 from colossalai.tensor import ProcessGroup
 from colossalai.utils import get_current_device, get_dataloader
-from colossalai.utils.checkpoint import load_checkpoint, save_checkpoint
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.zero import ZeroOptimizer
 from transformers import (
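
The headline change, tensor parallelism, rests on the ProcessGroup import kept above. A minimal sketch of how 1D TP is typically wired up with this generation of ColossalAI's ColoTensor API; the helper split_param_1d, the tp_degree value, the sharded-weight selection, and build_model() are illustrative assumptions, not taken from this commit:

from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

pg = ProcessGroup(tp_degree=2)  # assume a 2-way tensor-parallel group

# Parameters created under ColoInitContext are ColoTensors, so they can carry a shard spec.
with ColoInitContext(device=get_current_device()):
    model = build_model()  # hypothetical: any HF GPT-2-style model

def split_param_1d(param, pg, dim=0):
    # Shard the parameter along `dim` across the TP ranks and mark it for 1D TP compute.
    param.set_process_group(pg)
    param.set_tensor_spec(ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))

for name, param in model.named_parameters():
    if 'mlp' in name or 'attn' in name:  # which weights to shard is illustrative only
        split_param_1d(param, pg)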
@@ -249,12 +247,20 @@ def parse_args():
     return args
+def colo_memory_cap(size_in_GB):
+    from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
+    cuda_capacity = colo_device_memory_capacity(get_current_device())
+    if size_in_GB * (1024**3) < cuda_capacity:
+        colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
+        print("Using {} GB of GPU memory".format(size_in_GB))
 def main():
     args = parse_args()
     disable_existing_loggers()
     colossalai.launch_from_torch(config=dict())
     logger = get_dist_logger()
-    is_main_process = gpc.get_local_rank(ParallelMode.DATA) == 0
+    is_main_process = dist.get_rank() == 0
     if is_main_process:
         datasets.utils.logging.set_verbosity_warning()
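
Two things happen in this hunk: the memory cap moves into the script, and main-process detection drops the ColossalAI context object (gpc.get_local_rank(ParallelMode.DATA)) in favor of plain dist.get_rank(), which still names exactly one global rank 0 once the process grid carries both DP and TP dimensions. A hypothetical sketch of the call site; the --mem_cap flag is an assumption, not visible in this diff, and the other names come from the surrounding script:

import torch.distributed as dist

def main():
    args = parse_args()
    disable_existing_loggers()
    colossalai.launch_from_torch(config=dict())
    if args.mem_cap > 0:
        colo_memory_cap(args.mem_cap)       # cap every rank before the first large allocation
    is_main_process = dist.get_rank() == 0  # one global main process, with or without TP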


@@ -1,28 +0,0 @@
-import torch
-import torch.distributed as dist
-
-def memory_cap(size_in_GB):
-    print(f"use only {size_in_GB} GB of CUDA memory")
-    assert dist.is_initialized(), "memory_cap must be used after dist init"
-    local_rank = dist.get_rank()
-    cuda_capacity = torch.cuda.get_device_properties(local_rank).total_memory
-    size_in_B = (size_in_GB * 1024**3)
-    if size_in_B > cuda_capacity:
-        print(f'memory_cap is useless since the capacity {cuda_capacity / 1024**3} GB is less than {size_in_GB} GB')
-        return
-    fraction = (size_in_GB * 1024**3) / cuda_capacity
-    print(f'mem fraction is {fraction}')
-    torch.cuda.set_per_process_memory_fraction(fraction, local_rank)
-
-def colo_memory_cap(size_in_GB):
-    from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
-    cuda_capacity = colo_device_memory_capacity(get_current_device())
-    if size_in_GB * (1024**3) < cuda_capacity:
-        colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
-        print("Using {} GB of GPU memory".format(size_in_GB))
-
-if __name__ == '__main__':
-    memory_cap(40)
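
For reference, the plain-PyTorch path this deleted file implemented relies on torch.cuda.set_per_process_memory_fraction, which makes the caching allocator treat that fraction as the whole device, so any allocation beyond it fails with the usual CUDA out-of-memory error. A tiny self-contained check of that behavior, assuming one visible GPU:

import torch

# Cap this process at 50% of device 0's memory.
torch.cuda.set_per_process_memory_fraction(0.5, 0)

total = torch.cuda.get_device_properties(0).total_memory
try:
    # Asking for ~60% of the device should now fail.
    x = torch.empty(int(total * 0.6), dtype=torch.uint8, device='cuda:0')
except RuntimeError as e:  # CUDA OOM surfaces as a RuntimeError (torch.cuda.OutOfMemoryError on PyTorch >= 1.13)
    print('capped as expected:', type(e).__name__)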