Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-11-22 12:09:30 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -14,33 +14,40 @@ class DummyDataLoader(DummyDataGenerator):
     seq_len = 64
 
     def generate(self):
-        input_ids = torch.randint(0,
-                                  DummyDataLoader.vocab_size, (DummyDataLoader.batch_size, DummyDataLoader.seq_len),
-                                  device=get_current_device())
+        input_ids = torch.randint(
+            0,
+            DummyDataLoader.vocab_size,
+            (DummyDataLoader.batch_size, DummyDataLoader.seq_len),
+            device=get_current_device(),
+        )
         return input_ids, input_ids
 
 
 class GPTLMModel(nn.Module):
-
-    def __init__(self,
-                 hidden_size=768,
-                 num_layers=12,
-                 num_attention_heads=12,
-                 max_seq_len=1024,
-                 vocab_size=50304,
-                 checkpoint=False):
+    def __init__(
+        self,
+        hidden_size=768,
+        num_layers=12,
+        num_attention_heads=12,
+        max_seq_len=1024,
+        vocab_size=50304,
+        checkpoint=False,
+    ):
         super().__init__()
         self.checkpoint = checkpoint
         self.model = GPT2LMHeadModel(
-            GPT2Config(n_embd=hidden_size,
-                       n_layer=num_layers,
-                       n_head=num_attention_heads,
-                       n_positions=max_seq_len,
-                       n_ctx=max_seq_len,
-                       vocab_size=vocab_size,
-                       resid_pdrop=0.0,
-                       embd_pdrop=0.0,
-                       attn_pdrop=0.0))
+            GPT2Config(
+                n_embd=hidden_size,
+                n_layer=num_layers,
+                n_head=num_attention_heads,
+                n_positions=max_seq_len,
+                n_ctx=max_seq_len,
+                vocab_size=vocab_size,
+                resid_pdrop=0.0,
+                embd_pdrop=0.0,
+                attn_pdrop=0.0,
+            )
+        )
         if checkpoint:
             self.model.gradient_checkpointing_enable()
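Aside: the generate() call reformatted in the hunk above only builds a random token batch. A minimal standalone sketch of what it produces, assuming illustrative values for the class attributes not visible in this hunk (vocab_size, batch_size) and substituting a plain device pick for get_current_device():

import torch

# Assumed values: only seq_len = 64 is visible in the hunk above;
# vocab_size and batch_size here are illustrative stand-ins.
vocab_size, batch_size, seq_len = 128, 4, 64
device = "cuda" if torch.cuda.is_available() else "cpu"

# Same call shape as the black-formatted generate(): random token ids,
# reused as labels for a quick language-modeling smoke test.
input_ids = torch.randint(
    0,
    vocab_size,
    (batch_size, seq_len),
    device=device,
)
batch = (input_ids, input_ids)  # (inputs, labels)
print(input_ids.shape)  # torch.Size([4, 64])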
@@ -51,12 +58,9 @@ class GPTLMModel(nn.Module):
 
 
 def gpt2_micro(checkpoint=True):
-    return GPTLMModel(checkpoint=checkpoint,
-                      hidden_size=32,
-                      num_layers=2,
-                      num_attention_heads=4,
-                      max_seq_len=64,
-                      vocab_size=128)
+    return GPTLMModel(
+        checkpoint=checkpoint, hidden_size=32, num_layers=2, num_attention_heads=4, max_seq_len=64, vocab_size=128
+    )
 
 
 def gpt2_s(checkpoint=True):
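Aside: a hedged sanity check of the tiny configuration that gpt2_micro forwards to GPTLMModel, written directly against Hugging Face transformers. The n_ctx argument from the hunk above is left out here, since depending on the installed transformers version it may be ignored or unsupported as a GPT2Config parameter:

import torch
from transformers import GPT2Config, GPT2LMHeadModel

# Micro-sized GPT-2 matching the keyword arguments gpt2_micro passes above.
config = GPT2Config(
    n_embd=32,
    n_layer=2,
    n_head=4,
    n_positions=64,
    vocab_size=128,
    resid_pdrop=0.0,
    embd_pdrop=0.0,
    attn_pdrop=0.0,
)
model = GPT2LMHeadModel(config)

# Forward pass on a dummy batch; logits come out as (batch, seq_len, vocab_size).
input_ids = torch.randint(0, 128, (4, 64))
logits = model(input_ids=input_ids).logits
print(logits.shape)  # torch.Size([4, 64, 128])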
@@ -68,7 +72,6 @@ def gpt2_m(checkpoint=True):
 
 
 class GPTLMLoss(nn.Module):
-
     def __init__(self):
         super().__init__()
         self.loss_fn = nn.CrossEntropyLoss()
@@ -80,9 +83,8 @@ class GPTLMLoss(nn.Module):
         return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
 
 
-@non_distributed_component_funcs.register(name='gpt2')
+@non_distributed_component_funcs.register(name="gpt2")
 def get_training_components():
-
     trainloader = DummyDataLoader()
     testloader = DummyDataLoader()
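Aside: only the final return of GPTLMLoss.forward is visible in the hunk above. A self-contained sketch of the standard shifted causal-LM cross-entropy it appears to compute; the two shift_* lines are assumptions reconstructed from the usual GPT objective, not copied from this file:

import torch
import torch.nn as nn

def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Cross-entropy where the logit at position i is trained to predict token i+1."""
    loss_fn = nn.CrossEntropyLoss()
    # Assumed shift: drop the last logit and the first label so they line up.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Same call as the visible return statement in GPTLMLoss.forward.
    return loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

# Smoke test with the dummy shapes used elsewhere in this file.
logits = torch.randn(4, 64, 128)
labels = torch.randint(0, 128, (4, 64))
print(causal_lm_loss(logits, labels))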