mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 05:20:33 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -18,15 +18,15 @@ def data_gen():
|
||||
|
||||
def data_gen_for_image_classification():
|
||||
data = data_gen()
|
||||
data['labels'] = torch.tensor([0])
|
||||
data["labels"] = torch.tensor([0])
|
||||
return data
|
||||
|
||||
|
||||
def data_gen_for_masked_image_modeling():
|
||||
data = data_gen()
|
||||
num_patches = (config.image_size // config.patch_size)**2
|
||||
num_patches = (config.image_size // config.patch_size) ** 2
|
||||
bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
|
||||
data['bool_masked_pos'] = bool_masked_pos
|
||||
data["bool_masked_pos"] = bool_masked_pos
|
||||
return data
|
||||
|
||||
|
||||
@@ -42,23 +42,29 @@ loss_fn_for_masked_image_modeling = lambda x: x.loss
|
||||
# transformers.ViTModel,
|
||||
# transformers.ViTForMaskedImageModeling,
|
||||
# transformers.ViTForImageClassification,
|
||||
model_zoo.register(name='transformers_vit',
|
||||
model_fn=lambda: transformers.ViTModel(config),
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_vit_model,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(
|
||||
name="transformers_vit",
|
||||
model_fn=lambda: transformers.ViTModel(config),
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_vit_model,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
|
||||
model_zoo.register(name='transformers_vit_for_masked_image_modeling',
|
||||
model_fn=lambda: transformers.ViTForMaskedImageModeling(config),
|
||||
data_gen_fn=data_gen_for_masked_image_modeling,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_masked_image_modeling,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(
|
||||
name="transformers_vit_for_masked_image_modeling",
|
||||
model_fn=lambda: transformers.ViTForMaskedImageModeling(config),
|
||||
data_gen_fn=data_gen_for_masked_image_modeling,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_masked_image_modeling,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
|
||||
model_zoo.register(name='transformers_vit_for_image_classification',
|
||||
model_fn=lambda: transformers.ViTForImageClassification(config),
|
||||
data_gen_fn=data_gen_for_image_classification,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_image_classification,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(
|
||||
name="transformers_vit_for_image_classification",
|
||||
model_fn=lambda: transformers.ViTForImageClassification(config),
|
||||
data_gen_fn=data_gen_for_image_classification,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_image_classification,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
|
Reference in New Issue
Block a user