[chat] typo accimulation_steps -> accumulation_steps (#3662)

Author: tanitna
Date: 2023-04-28 00:42:57 -07:00
Committed by: GitHub
Parent: 816add7e7f
Commit: 1a60dc07a8
6 changed files with 18 additions and 18 deletions


@@ -41,10 +41,10 @@ class SFTTrainer(Trainer):
         train_dataloader: DataLoader,
         eval_dataloader: DataLoader = None,
         max_epochs: int = 2,
-        accimulation_steps: int = 8,
+        accumulation_steps: int = 8,
         callbacks: List[Callback] = [],
     ) -> None:
-        if accimulation_steps > 1 and isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3:
+        if accumulation_steps > 1 and isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3:
             raise ValueError("Accumulation steps are not supported in stage 3 of ColossalAI")
         super().__init__(strategy, max_epochs, callbacks=callbacks)
         self.train_dataloader = train_dataloader
@@ -52,8 +52,8 @@ class SFTTrainer(Trainer):
         self.model = model
         self.optimizer = optim
-        self.accimulation_steps = accimulation_steps
-        num_update_steps_per_epoch = len(train_dataloader) // self.accimulation_steps
+        self.accumulation_steps = accumulation_steps
+        num_update_steps_per_epoch = len(train_dataloader) // self.accumulation_steps
         max_steps = math.ceil(self.max_epochs * num_update_steps_per_epoch)
         self.scheduler = get_scheduler("cosine",
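For a sense of the arithmetic in this hunk (illustrative numbers, not taken from this commit): with 1,000 batches per epoch and accumulation_steps = 8, num_update_steps_per_epoch = 1000 // 8 = 125, so at max_epochs = 2 the cosine scheduler is sized for max_steps = ceil(2 * 125) = 250 optimizer updates rather than 2,000 batches.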
@@ -67,7 +67,7 @@ class SFTTrainer(Trainer):
             wandb.watch(self.model)
         total_loss = 0
         # epoch_bar = tqdm(range(self.epochs), desc='Epochs', disable=not is_rank_0())
-        step_bar = tqdm(range(len(self.train_dataloader) // self.accimulation_steps * self.max_epochs),
+        step_bar = tqdm(range(len(self.train_dataloader) // self.accumulation_steps * self.max_epochs),
                         desc=f'steps',
                         disable=not is_rank_0())
         for epoch in range(self.max_epochs):
@@ -85,20 +85,20 @@ class SFTTrainer(Trainer):
                 if loss >= 2.5 and is_rank_0():
                     logger.warning(f"batch_id:{batch_id}, abnormal loss: {loss}")
-                loss = loss / self.accimulation_steps
+                loss = loss / self.accumulation_steps
                 self.strategy.backward(loss, self.model, self.optimizer)
                 total_loss += loss.item()
                 # gradient accumulation
-                if (batch_id + 1) % self.accimulation_steps == 0:
+                if (batch_id + 1) % self.accumulation_steps == 0:
                     self.strategy.optimizer_step(self.optimizer)
                     self.optimizer.zero_grad()
                     self.scheduler.step()
                     if is_rank_0() and use_wandb:
                         wandb.log({
-                            "loss": total_loss / self.accimulation_steps,
+                            "loss": total_loss / self.accumulation_steps,
                             "lr": self.scheduler.get_last_lr()[0],
                             "epoch": epoch,
                             "batch_id": batch_id