mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-30 22:24:21 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -17,14 +17,14 @@ def synthesize_data():
|
||||
|
||||
|
||||
def main():
|
||||
colossalai.launch_from_torch(config='./config.py')
|
||||
colossalai.launch_from_torch(config="./config.py")
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
# trace the model with meta data
|
||||
model = resnet50(num_classes=10).cuda()
|
||||
|
||||
input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')}
|
||||
input_sample = {"x": torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to("meta")}
|
||||
device_mesh = DeviceMesh(physical_mesh_id=torch.tensor([0, 1, 2, 3]), mesh_shape=[2, 2], init_process_group=True)
|
||||
model, solution = initialize_model(model, input_sample, device_mesh=device_mesh, return_solution=True)
|
||||
|
||||
@@ -88,8 +88,9 @@ def main():
|
||||
|
||||
logger.info(
|
||||
f"Epoch {epoch} - train loss: {train_loss:.5}, test loss: {test_loss:.5}, acc: {correct / total:.5}, lr: {lr_scheduler.get_last_lr()[0]:.5g}",
|
||||
ranks=[0])
|
||||
ranks=[0],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
Reference in New Issue
Block a user