mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-07 04:18:55 +00:00
* [elixir] add elixir * [elixir] add unit tests * remove useless code * fix python 3.8 issue * fix typo * add test skip * add docstrings * add docstrings * add readme * fix typo
24 lines
466 B
Python
24 lines
466 B
Python
import torch
|
|
import torch.nn as nn
|
|
from torchvision.models import resnet18
|
|
|
|
from tests.test_elixir.utils.registry import TEST_MODELS
|
|
|
|
|
|
def resnet_data_fn():
    """Build one random input batch for ``ResNetModel``.

    Returns:
        A dict whose single key ``'x'`` matches the argument name of
        ``ResNetModel.forward`` (presumably unpacked as keyword arguments by
        the test harness -- confirm). The value is a float tensor of
        standard-normal samples with shape ``(4, 3, 32, 32)``.
    """
    sample = torch.randn(4, 3, 32, 32)
    return {'x': sample}
|
|
|
|
|
|
class ResNetModel(nn.Module):
    """Test wrapper around a torchvision ResNet-18 built with default arguments.

    The forward pass collapses the network output into one scalar by summing
    every element. This is presumably so tests can call ``.backward()`` on the
    result directly -- confirm against the elixir test harness.
    """

    def __init__(self) -> None:
        super().__init__()
        # Keep the attribute name ``r``: it is the prefix of every parameter
        # name in ``state_dict()``.
        self.r = resnet18()

    def forward(self, x):
        """Run ResNet-18 on ``x`` and return the sum of all output elements."""
        return self.r(x).sum()
|
|
|
|
|
|
# Register the model class together with its input-batch factory under the
# name 'resnet' in the shared TEST_MODELS registry (presumably so elixir
# tests can look both up by name -- see tests.test_elixir.utils.registry).
TEST_MODELS.register('resnet', ResNetModel, resnet_data_fn)
|