update ignore

Commit: 10b72a32b1
Parent: 4516a4ed6a
Author: Tong Li
Date: 2024-08-15 05:52:50 +00:00

3 changed files with 12 additions and 13 deletions

diff --git a/.gitignore b/.gitignore

@@ -161,3 +161,9 @@ applications/ColossalChat/sft_data
 applications/ColossalChat/prompt_data
 applications/ColossalChat/preference_data
 applications/ColossalChat/temp
+
+# Testing data
+/kto_data/
+/preference_data/
+/prompt_data/
+/sft_data/
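
A note on the added patterns (an illustration, not part of the commit): a leading slash anchors a pattern to the directory containing the .gitignore file, so these entries ignore only the top-level testing-data directories:

    /sft_data/   (anchored: matches only the root-level sft_data/ directory)
    sft_data/    (unanchored: would match a directory named sft_data/ at any depth)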

diff --git a/applications/ColossalChat/coati/trainer/sft.py b/applications/ColossalChat/coati/trainer/sft.py

@@ -114,9 +114,7 @@ class SFTTrainer(SLTrainer):
                 )
                 loss = outputs["loss"]
                 if dist.get_rank() == dist.get_world_size() - 1:
-                    global_loss = all_reduce_mean(loss, self.booster)
-                    step_bar.set_postfix({"train/loss": global_loss.item()})
-                    step_bar.update()
+                    step_bar.set_postfix({"train/loss": loss.item()})
                 self.optimizer.step()
                 self.optimizer.zero_grad()
             else:
@@ -200,9 +198,8 @@ class SFTTrainer(SLTrainer):
                 )
                 loss = outputs["loss"]
                 if dist.get_rank() == dist.get_world_size() - 1:
-                    global_loss = all_reduce_mean(loss, self.booster)
-                    step_bar.set_postfix({"eval/loss": global_loss.item()})
-                    self.accumulative_meter.add("loss", global_loss.item())
+                    step_bar.set_postfix({"eval/loss": loss.item()})
+                    self.accumulative_meter.add("loss", loss.item())
                     step_bar.update()

         if dist.get_rank() == dist.get_world_size() - 1:
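
For context, a minimal sketch (illustrative names, not the repository's code) of the logging pattern after this change: under pipeline parallelism the loss is only materialized on the last rank, so that rank reports its local loss directly instead of first averaging it across processes with all_reduce_mean:

    import torch.distributed as dist
    from tqdm import tqdm

    def train_steps(data_iter, run_step, num_steps: int):
        # run_step stands in for booster.execute_pipeline(...); it is assumed
        # to return the loss tensor on the last pipeline rank.
        is_last_rank = dist.get_rank() == dist.get_world_size() - 1
        step_bar = tqdm(range(num_steps), disable=not is_last_rank)
        for _ in step_bar:
            loss = run_step(data_iter)
            if is_last_rank:
                # Display the local loss; no collective call is issued just
                # to produce a progress-bar number.
                step_bar.set_postfix({"train/loss": loss.item()})

As the diff shows, the removed all_reduce_mean call existed only to feed the progress bar and the accumulative meter, so dropping it removes one collective communication per step.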

diff --git a/applications/ColossalChat/coati/trainer/utils.py b/applications/ColossalChat/coati/trainer/utils.py

@@ -87,7 +87,7 @@ def to_device(x: Any, device: torch.device) -> Any:
     return tree_map(_to, x)


-def all_reduce_mean(tensor: torch.Tensor, booster: Booster) -> torch.Tensor:
+def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
     """
     Perform all-reduce operation on the given tensor and compute the mean across all processes.

@@ -97,12 +97,8 @@ def all_reduce_mean(tensor: torch.Tensor, booster: Booster) -> torch.Tensor:
     Returns:
         torch.Tensor: The reduced tensor with mean computed across all processes.
     """
-    if booster is not None:
-        dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM, group=booster.plugin.dp_group)
-        tensor.div_(booster.plugin.dp_size)
-    else:
-        dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
-        tensor.div_(dist.get_world_size())
+    dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
+    tensor.div_(dist.get_world_size())
     return tensor
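
The simplified all_reduce_mean always reduces over the default process group and divides by the full world size, whereas the removed branch reduced only over the booster's data-parallel group and divided by its size. A self-contained sketch (not from the repository) that exercises the simplified behavior on two CPU processes with the gloo backend:

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp

    def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
        # Sum the tensor across all processes, then divide by the process count.
        dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
        tensor.div_(dist.get_world_size())
        return tensor

    def worker(rank: int, world_size: int):
        dist.init_process_group(
            backend="gloo",  # CPU backend so the sketch runs without GPUs
            init_method="tcp://127.0.0.1:29500",  # arbitrary free local port
            rank=rank,
            world_size=world_size,
        )
        local = torch.tensor([float(rank)])  # rank 0 holds 0.0, rank 1 holds 1.0
        mean = all_reduce_mean(local)
        print(f"rank {rank}: mean = {mean.item()}")  # prints 0.5 on both ranks
        dist.destroy_process_group()

    if __name__ == "__main__":
        mp.spawn(worker, args=(2,), nprocs=2)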