mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-11-25 00:08:32 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -33,16 +33,14 @@ class SimpleNet(CheckpointModule):
|
||||
|
||||
|
||||
class DummyDataLoader(DummyDataGenerator):
|
||||
|
||||
def generate(self):
|
||||
data = torch.randint(low=0, high=20, size=(16,), device=get_current_device())
|
||||
label = torch.randint(low=0, high=2, size=(16,), device=get_current_device())
|
||||
return data, label
|
||||
|
||||
|
||||
@non_distributed_component_funcs.register(name='simple_net')
|
||||
@non_distributed_component_funcs.register(name="simple_net")
|
||||
def get_training_components():
|
||||
|
||||
def model_builder(checkpoint=False):
|
||||
return SimpleNet(checkpoint)
|
||||
|
||||
@@ -51,4 +49,5 @@ def get_training_components():
|
||||
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
return model_builder, trainloader, testloader, HybridAdam, criterion
|
||||
|
||||
Reference in New Issue
Block a user