[Gemini] add unitests to check gemini correctness (#2015)
tests/components_to_test/__init__.py
@@ -1 +1,2 @@
 from . import bert, gpt, inline_op_model, nested_model, no_leaf_module, repeated_computed_layer, resnet, simple_net
+from .utils import run_fwd_bwd
tests/components_to_test/gpt.py
@@ -1,10 +1,12 @@
 import torch
 import torch.nn as nn
-from .registry import non_distributed_component_funcs
 from transformers import GPT2Config, GPT2LMHeadModel
-from .utils.dummy_data_generator import DummyDataGenerator
 
 from colossalai.utils.cuda import get_current_device
 
+from .registry import non_distributed_component_funcs
+from .utils.dummy_data_generator import DummyDataGenerator
+
+
 class DummyDataLoader(DummyDataGenerator):
     vocab_size = 128
@@ -15,8 +17,7 @@ class DummyDataLoader(DummyDataGenerator):
         input_ids = torch.randint(0,
                                   DummyDataLoader.vocab_size, (DummyDataLoader.batch_size, DummyDataLoader.seq_len),
                                   device=get_current_device())
-        attention_mask = torch.ones_like(input_ids)
-        return input_ids, attention_mask
+        return input_ids, input_ids
 
 
 class GPTLMModel(nn.Module):
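With this change the dummy loader returns the token ids twice, once as model input and once as the language-modeling target, so no separate attention mask travels through the test harness. A criterion consuming that pair has to apply the usual causal shift itself. A minimal sketch of such a loss (the name GPTLMLoss is illustrative, not part of this diff):

import torch.nn as nn


class GPTLMLoss(nn.Module):
    # Standard causal-LM loss: predict token t+1 from tokens up to t.
    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        # Align predictions and targets by dropping the last logit and
        # the first label, then flatten both for cross-entropy.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))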
@@ -43,8 +44,9 @@ class GPTLMModel(nn.Module):
         if checkpoint:
             self.model.gradient_checkpointing_enable()
 
-    def forward(self, input_ids, attention_mask):
+    def forward(self, input_ids):
         # Only return lm_logits
+        attention_mask = torch.ones_like(input_ids)
         return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
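Building the all-ones attention mask inside forward() shrinks the model's public interface to a single tensor, which is what lets one generic helper (the run_fwd_bwd added below) drive every registered component the same way. A minimal sketch of the resulting calling convention, with TinyLM as a hypothetical stand-in for GPTLMModel:

import torch
import torch.nn as nn


class TinyLM(nn.Module):
    # Hypothetical stand-in: one embedding plus a linear LM head.
    def __init__(self, vocab_size=128, hidden=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.head = nn.Linear(hidden, vocab_size)

    def forward(self, input_ids):
        # Any mask handling stays inside the model, as in the diff above.
        return self.head(self.embed(input_ids))


model = TinyLM()
data = torch.randint(0, 128, (2, 8))
logits = model(data)    # single-tensor interface: no attention_mask argument
print(logits.shape)     # torch.Size([2, 8, 128])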
tests/components_to_test/inline_op_model.py
@@ -38,7 +38,7 @@ class DummyDataLoader(DummyDataGenerator):
         return data, label
 
 
-@non_distributed_component_funcs.register(name='inline_op_module')
+@non_distributed_component_funcs.register(name='inline_op_model')
 def get_training_components():
 
     def model_builder(checkpoint=True):
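The registration key now matches the module name imported in tests/components_to_test/__init__.py (inline_op_model, not inline_op_module), so registry lookups keyed by module name resolve. A usage sketch, assuming the registry exposes a get_callable getter and the five-tuple return convention used by the other components:

from tests.components_to_test.registry import non_distributed_component_funcs

# Fetch the component under the corrected name; the old key
# 'inline_op_module' no longer has anything registered against it.
get_components_func = non_distributed_component_funcs.get_callable('inline_op_model')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()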
tests/components_to_test/utils/__init__.py
@@ -1 +1,2 @@
 from .dummy_data_generator import DummyDataGenerator
+from .executor import run_fwd_bwd
tests/components_to_test/utils/executor.py (new file, 15 lines)
@@ -0,0 +1,15 @@
+import torch
+
+
+def run_fwd_bwd(model, data, label, criterion, enable_autocast=False, use_init_ctx=False):
+    with torch.cuda.amp.autocast(enabled=enable_autocast):
+        if criterion:
+            y = model(data)
+            loss = criterion(y, label)
+        else:
+            loss = model(data, label)
+        loss = loss.float()
+    if use_init_ctx:
+        model.backward(loss)
+    else:
+        loss.backward()
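A usage sketch tying the pieces together: fetch a registered component, build the model, and push one batch through a forward/backward step. The component name 'gpt2' and the get_callable getter follow the registry conventions above and are assumptions here; use_init_ctx=True is meant for wrapped models (e.g. Gemini/ZeRO) that own their backward pass via model.backward(loss).

from tests.components_to_test.registry import non_distributed_component_funcs
from tests.components_to_test.utils import run_fwd_bwd

# Assumed component name and five-tuple layout.
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, _, _, criterion = get_components_func()

model = model_builder(checkpoint=True).cuda()
for data, label in train_dataloader:
    run_fwd_bwd(model, data, label, criterion)    # plain PyTorch path: loss.backward()
    break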