mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
fix some typo with colossalai/device colossalai/tensor/ etc. (#4171)
Co-authored-by: flybird11111 <1829166702@qq.com>
This commit is contained in:
@@ -43,7 +43,7 @@ def data_gen_for_t5_model():
|
||||
# output transform function
|
||||
output_transform_fn = lambda x: x
|
||||
|
||||
# define loss funciton
|
||||
# define loss function
|
||||
loss_fn_for_t5_model = lambda x: x.last_hidden_state.mean()
|
||||
loss_fn_for_encoder_only = lambda x: x.last_hidden_state.mean()
|
||||
loss_fn_for_conditional_generation = lambda x: x.loss
|
||||
|
@@ -64,7 +64,7 @@ def check_torch_ddp_no_sync():
|
||||
model = DummyModel()
|
||||
criterion = lambda x: x.mean()
|
||||
optimizer = SGD(model.parameters(), lr=1e-3)
|
||||
# create a custom dasetset with 0 to 10
|
||||
# create a custom dataset with 0 to 10
|
||||
dataset = torch.arange(0, 10)
|
||||
train_dataloader = plugin.prepare_dataloader(dataset, batch_size=2)
|
||||
model, optimizer, criterion, train_dataloader, _ = booster.boost(model,
|
||||
|
@@ -15,7 +15,7 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
|
||||
|
||||
# test baisc fsdp function
|
||||
# test basic fsdp function
|
||||
def run_fn(model_fn, data_gen_fn, output_transform_fn):
|
||||
plugin = TorchFSDPPlugin()
|
||||
booster = Booster(plugin=plugin)
|
||||
|
Reference in New Issue
Block a user