mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-14 05:33:23 +00:00
[auto-parallel] add auto-offload feature (#3154)
* add auto-offload feature * polish code * fix syn offload runtime pass bug * add offload example * fix offload testing bug * fix example testing bug
This commit is contained in:
86
tests/test_auto_parallel/test_offload/model_utils.py
Normal file
86
tests/test_auto_parallel/test_offload/model_utils.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import GPT2Config, GPT2LMHeadModel
|
||||
from transformers import BertConfig, BertLMHeadModel
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
||||
class GPTLMModel(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
hidden_size=768,
|
||||
num_layers=12,
|
||||
num_attention_heads=12,
|
||||
max_seq_len=1024,
|
||||
vocab_size=50257):
|
||||
super().__init__()
|
||||
self.model = GPT2LMHeadModel(
|
||||
GPT2Config(n_embd=hidden_size,
|
||||
n_layer=num_layers,
|
||||
n_head=num_attention_heads,
|
||||
n_positions=max_seq_len,
|
||||
n_ctx=max_seq_len,
|
||||
vocab_size=vocab_size))
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
# Only return lm_logits
|
||||
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0]
|
||||
|
||||
|
||||
class LMLoss(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.loss_fn = nn.CrossEntropyLoss()
|
||||
|
||||
def forward(self, logits, labels):
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||
|
||||
class BertLMModel(nn.Module):
|
||||
def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=32, vocab_size=30522):
|
||||
super().__init__()
|
||||
self.model = BertLMHeadModel(BertConfig(n_embd=hidden_size, num_hidden_layers=num_layers, hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads, max_position_embeddings=hidden_size,
|
||||
vocab_size=vocab_size))
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
# Only return lm_logits
|
||||
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0]
|
||||
|
||||
@non_distributed_component_funcs.register(name='bert_')
|
||||
def get_bert_components():
|
||||
vocab_size = 1024
|
||||
seq_len = 64
|
||||
batchSize = 64
|
||||
|
||||
def bert_model_builder():
|
||||
model = BertLMModel(hidden_size=8192, num_layers=4, num_attention_heads=32, vocab_size=vocab_size)
|
||||
return model
|
||||
|
||||
def bert_data_gen(device="meta"):
|
||||
input_ids = torch.randint(0, vocab_size, (batchSize, seq_len), device=device)
|
||||
attention_mask = torch.ones_like(input_ids, device=device)
|
||||
kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
|
||||
return kwargs
|
||||
|
||||
return bert_model_builder, bert_data_gen
|
||||
|
||||
@non_distributed_component_funcs.register(name='gpt2_')
|
||||
def get_gpt2_components():
|
||||
vocab_size = 1024
|
||||
seq_len = 8
|
||||
batchSize = 64
|
||||
|
||||
def gpt2_model_builder():
|
||||
model = GPTLMModel(hidden_size=8192, num_layers=2, num_attention_heads=32, vocab_size=vocab_size)
|
||||
return model
|
||||
|
||||
def gpt2_data_gen(device="meta"):
|
||||
input_ids = torch.randint(0, vocab_size, (batchSize, seq_len), device=device)
|
||||
attention_mask = torch.ones_like(input_ids, device=device)
|
||||
kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
|
||||
return kwargs
|
||||
|
||||
return gpt2_model_builder, gpt2_data_gen
|
Reference in New Issue
Block a user