mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-11-23 05:06:26 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -13,19 +13,20 @@ from .registry import non_distributed_component_funcs
|
||||
|
||||
def get_cifar10_dataloader(train):
|
||||
# build dataloaders
|
||||
dataset = CIFAR10(root=Path(os.environ['DATA']),
|
||||
download=True,
|
||||
train=train,
|
||||
transform=transforms.Compose(
|
||||
[transforms.ToTensor(),
|
||||
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]))
|
||||
dataset = CIFAR10(
|
||||
root=Path(os.environ["DATA"]),
|
||||
download=True,
|
||||
train=train,
|
||||
transform=transforms.Compose(
|
||||
[transforms.ToTensor(), transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]
|
||||
),
|
||||
)
|
||||
dataloader = get_dataloader(dataset=dataset, shuffle=True, batch_size=16, drop_last=True)
|
||||
return dataloader
|
||||
|
||||
|
||||
@non_distributed_component_funcs.register(name='resnet18')
|
||||
@non_distributed_component_funcs.register(name="resnet18")
|
||||
def get_resnet_training_components():
|
||||
|
||||
def model_builder(checkpoint=False):
|
||||
return resnet18(num_classes=10)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user