Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 12:30:42 +00:00)
[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
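The hunks below are mechanical reformatting produced by running the refreshed hooks over the whole tree: single-quoted strings become double-quoted, long argument lists are exploded one argument per line with a trailing comma, a first-party import moves below the third-party block, and blank lines are normalized, all of which matches black's conventions. A minimal sketch of the pattern, using a hypothetical helper that is not part of this repository:

# Hypothetical example; `make_config` is illustrative only and not part of ColossalAI.
def make_config(name: str, batch_size: int) -> dict:
    return {"name": name, "batch_size": batch_size}

# Before the hooks ran, a wrapped call looked roughly like
#     cfg = make_config(name='gpt2_medium',
#                       batch_size=64)
# After `pre-commit run --all-files`, black rewrites it with double quotes,
# one argument per line, and a trailing comma:
cfg = make_config(
    name="gpt2_medium",
    batch_size=64,
)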
@@ -2,22 +2,20 @@ import torch
 import torch.nn as nn
 from transformers import GPT2Config, GPT2LMHeadModel
 
-class GPTLMModel(nn.Module):
-
-    def __init__(self,
-                 hidden_size=768,
-                 num_layers=12,
-                 num_attention_heads=12,
-                 max_seq_len=1024,
-                 vocab_size=50257):
+class GPTLMModel(nn.Module):
+    def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257):
         super().__init__()
         self.model = GPT2LMHeadModel(
-            GPT2Config(n_embd=hidden_size,
-                       n_layer=num_layers,
-                       n_head=num_attention_heads,
-                       n_positions=max_seq_len,
-                       n_ctx=max_seq_len,
-                       vocab_size=vocab_size))
+            GPT2Config(
+                n_embd=hidden_size,
+                n_layer=num_layers,
+                n_head=num_attention_heads,
+                n_positions=max_seq_len,
+                n_ctx=max_seq_len,
+                vocab_size=vocab_size,
+            )
+        )
 
     def forward(self, input_ids, attention_mask):
         # Only return lm_logits
@@ -25,7 +23,6 @@ class GPTLMModel(nn.Module):
 
 
 class GPTLMLoss(nn.Module):
-
     def __init__(self):
         super().__init__()
         self.loss_fn = nn.CrossEntropyLoss()
@@ -36,6 +33,7 @@ class GPTLMLoss(nn.Module):
         # Flatten the tokens
         return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
 
+
 def get_gpt2_components(model_type: str, batch_size: int):
     vocab_size = 1024
     seq_len = 8
@@ -62,4 +60,4 @@ def get_gpt2_components(model_type: str, batch_size: int):
         kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
         return kwargs
 
-    return gpt2_model_builder, gpt2_data_gen
+    return gpt2_model_builder, gpt2_data_gen
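The four hunks above only restyle the model_zoo module; GPTLMModel, GPTLMLoss, and get_gpt2_components keep their behavior. As a quick illustration of how those pieces fit together, a minimal sketch (small sizes chosen so it runs quickly; the GPTLMLoss forward signature of (logits, labels) is assumed from the shift logic shown above):

import torch

from model_zoo import GPTLMLoss, GPTLMModel

# Tiny configuration for the sketch; the defaults are the GPT-2 base sizes.
model = GPTLMModel(hidden_size=64, num_layers=2, num_attention_heads=2, max_seq_len=32, vocab_size=1024)
criterion = GPTLMLoss()

input_ids = torch.randint(0, 1024, (2, 16))      # (batch, seq_len)
attention_mask = torch.ones_like(input_ids)

logits = model(input_ids, attention_mask)        # forward returns only lm_logits
loss = criterion(logits, input_ids)              # labels == input_ids for LM training
loss.backward()
print(loss.item())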
@@ -1,2 +1,2 @@
 colossalai >= 0.1.12
-torch >= 1.8.1
+torch >= 1.8.1
@@ -3,7 +3,6 @@ import time
 
 import pytest
 import torch
-from model_zoo import GPTLMLoss, get_gpt2_components
 from torch.utils._pytree import tree_map
 
 import colossalai
@@ -14,18 +13,19 @@ from colossalai.fx.profiler import parameter_size
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import spawn
 from colossalai.utils import get_current_device
+from model_zoo import GPTLMLoss, get_gpt2_components
 
 
 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model_type', type=str, default="gpt2_medium")
-    parser.add_argument('--batch_size', type=int, default=64)
-    parser.add_argument('--solver_type', type=str, default='asyn')
-    parser.add_argument('--memory_budget', type=float, default=16)
+    parser.add_argument("--model_type", type=str, default="gpt2_medium")
+    parser.add_argument("--batch_size", type=int, default=64)
+    parser.add_argument("--solver_type", type=str, default="asyn")
+    parser.add_argument("--memory_budget", type=float, default=16)
     return parser.parse_args()
 
 
-@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
+@pytest.mark.skipif(NOT_NVML, reason="pynvml is not installed")
 def train_gpt(args):
     memory_budget = args.memory_budget * 1024 * 1024 * 1024
     solver_type = args.solver_type
@@ -34,10 +34,15 @@ def train_gpt(args):
 
     # build model
     model_builder, data_gen = get_gpt2_components(model_type=model_type, batch_size=batch_size)
-    label = torch.randint(low=0, high=128, size=(
-        64,
-        8,
-    ), device=get_current_device())
+    label = torch.randint(
+        low=0,
+        high=128,
+        size=(
+            64,
+            8,
+        ),
+        device=get_current_device(),
+    )
     criterion = GPTLMLoss()
 
     start_time = time.time()
@@ -80,18 +85,20 @@ def train_gpt(args):
     exec_time = sum(sorted(time_list)[:5]) / 5
     runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2
     runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2
-    print(f'solver_type: {solver_type} | model_type: {model_type}')
-    print(f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
-          f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|')
+    print(f"solver_type: {solver_type} | model_type: {model_type}")
+    print(
+        f"| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB "
+        f"| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|"
+    )
     print(time_list)
 
 
 def run(rank, world_size, port, args):
     config = {}
-    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     train_gpt(args)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     args = parse_args()
     spawn(run, 1, args=args)
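The remaining hunks touch a requirements file and what appears to be the GPT offload benchmark/test script: the local model_zoo import moves below the colossalai block, the argparse strings, pytest skip reason, and launch call switch to double quotes, and the torch.randint label call is exploded by black. Pulled out of the distributed launch, one step of the loop that the script times looks roughly like the sketch below; the zero-argument model_builder() call, the data_gen() signature, and the loss call order are assumptions based only on the code shown above.

import torch
from torch.utils._pytree import tree_map

from model_zoo import GPTLMLoss, get_gpt2_components

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_builder, data_gen = get_gpt2_components(model_type="gpt2_medium", batch_size=64)
model = model_builder().to(device)               # assumed zero-argument builder
criterion = GPTLMLoss()

# Mirror the reformatted benchmark: a random label batch of shape (64, 8).
label = torch.randint(low=0, high=128, size=(64, 8), device=device)

# data_gen() is assumed to return dict(input_ids=..., attention_mask=...) as in the
# kwargs hunk above; tree_map moves every tensor in the batch onto the target device.
batch = tree_map(lambda x: x.to(device) if torch.is_tensor(x) else x, data_gen())

logits = model(**batch)
loss = criterion(logits, label)
loss.backward()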