[chat] typo accimulation_steps -> accumulation_steps (#3662)
@@ -41,10 +41,10 @@ class SFTTrainer(Trainer):
                  train_dataloader: DataLoader,
                  eval_dataloader: DataLoader = None,
                  max_epochs: int = 2,
-                 accimulation_steps: int = 8,
+                 accumulation_steps: int = 8,
                  callbacks: List[Callback] = [],
                  ) -> None:
-        if accimulation_steps > 1 and isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3:
+        if accumulation_steps > 1 and isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3:
             raise ValueError("Accumulation steps are not supported in stage 3 of ColossalAI")
         super().__init__(strategy, max_epochs, callbacks=callbacks)
         self.train_dataloader = train_dataloader
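For context, here is a minimal, self-contained sketch of the guard this hunk renames. The `Strategy` and `ColossalAIStrategy` classes below are hypothetical stand-ins defined only so the check runs in isolation; the real classes come from the ColossalAI/coati packages.

```python
# Hypothetical stand-ins for the real strategy classes, defined here only so
# the guard below is runnable outside the repository.
class Strategy:
    pass


class ColossalAIStrategy(Strategy):
    def __init__(self, stage: int = 2):
        self.stage = stage


def validate_accumulation(accumulation_steps: int, strategy: Strategy) -> None:
    # Same condition as in the hunk above: reject gradient accumulation when
    # the ColossalAI strategy runs ZeRO stage 3.
    if accumulation_steps > 1 and isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3:
        raise ValueError("Accumulation steps are not supported in stage 3 of ColossalAI")


validate_accumulation(8, ColossalAIStrategy(stage=2))    # passes silently
# validate_accumulation(8, ColossalAIStrategy(stage=3))  # would raise ValueError
```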
@@ -52,8 +52,8 @@ class SFTTrainer(Trainer):
         self.model = model
         self.optimizer = optim

-        self.accimulation_steps = accimulation_steps
-        num_update_steps_per_epoch = len(train_dataloader) // self.accimulation_steps
+        self.accumulation_steps = accumulation_steps
+        num_update_steps_per_epoch = len(train_dataloader) // self.accumulation_steps
         max_steps = math.ceil(self.max_epochs * num_update_steps_per_epoch)

         self.scheduler = get_scheduler("cosine",
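The renamed attribute also feeds the scheduler setup: one optimizer update happens per `accumulation_steps` micro-batches, and `max_steps` is presumably what sizes the cosine schedule. A small worked example of that arithmetic, with illustrative numbers that are not taken from the diff:

```python
import math

# Placeholder values for illustration only.
num_batches = 1000        # stands in for len(train_dataloader)
accumulation_steps = 8
max_epochs = 2

# One optimizer update per `accumulation_steps` micro-batches.
num_update_steps_per_epoch = num_batches // accumulation_steps    # 125
max_steps = math.ceil(max_epochs * num_update_steps_per_epoch)    # 250

print(num_update_steps_per_epoch, max_steps)
```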
@@ -67,7 +67,7 @@ class SFTTrainer(Trainer):
             wandb.watch(self.model)
         total_loss = 0
         # epoch_bar = tqdm(range(self.epochs), desc='Epochs', disable=not is_rank_0())
-        step_bar = tqdm(range(len(self.train_dataloader) // self.accimulation_steps * self.max_epochs),
+        step_bar = tqdm(range(len(self.train_dataloader) // self.accumulation_steps * self.max_epochs),
                         desc=f'steps',
                         disable=not is_rank_0())
         for epoch in range(self.max_epochs):
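The same quotient sizes the training progress bar, so the bar counts optimizer updates rather than micro-batches. A standalone sketch with placeholder values (the real trainer would presumably advance the bar once per optimizer update):

```python
from tqdm import tqdm

num_batches = 1000        # placeholder for len(self.train_dataloader)
accumulation_steps = 8
max_epochs = 2

# Bar length = updates per epoch (125) times epochs (2) = 250, matching the
# max_steps computed in the previous hunk.
step_bar = tqdm(range(num_batches // accumulation_steps * max_epochs),
                desc='steps')
step_bar.update()    # would be called once per optimizer update in the trainer
step_bar.close()
```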
@@ -85,20 +85,20 @@ class SFTTrainer(Trainer):
                 if loss >= 2.5 and is_rank_0():
                     logger.warning(f"batch_id:{batch_id}, abnormal loss: {loss}")

-                loss = loss / self.accimulation_steps
+                loss = loss / self.accumulation_steps

                 self.strategy.backward(loss, self.model, self.optimizer)

                 total_loss += loss.item()

                 # gradient accumulation
-                if (batch_id + 1) % self.accimulation_steps == 0:
+                if (batch_id + 1) % self.accumulation_steps == 0:
                     self.strategy.optimizer_step(self.optimizer)
                     self.optimizer.zero_grad()
                     self.scheduler.step()
                     if is_rank_0() and use_wandb:
                         wandb.log({
-                            "loss": total_loss / self.accimulation_steps,
+                            "loss": total_loss / self.accumulation_steps,
                             "lr": self.scheduler.get_last_lr()[0],
                             "epoch": epoch,
                             "batch_id": batch_id
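This hunk is the core gradient-accumulation loop. Below is a generic, runnable sketch of the same pattern with a plain PyTorch optimizer in place of the strategy object; the model, data, and learning rate are placeholders, not values from the repository.

```python
import torch

# Tiny placeholder model and synthetic data, for illustration only.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = [(torch.randn(2, 4), torch.randn(2, 1)) for _ in range(16)]
accumulation_steps = 8
total_loss = 0.0

for batch_id, (x, y) in enumerate(data):
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss = loss / accumulation_steps      # scale so the accumulated gradient is a mean
    loss.backward()                       # gradients accumulate in .grad across micro-batches
    total_loss += loss.item()

    # One optimizer update per accumulation window.
    if (batch_id + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        print(f"update at batch {batch_id}, mean loss {total_loss:.4f}")
        total_loss = 0.0
```

Dividing the loss by `accumulation_steps` before `backward()` keeps the accumulated gradient equal to the average over the window rather than the sum, so the effective update matches training with a batch that is `accumulation_steps` times larger.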