fix orpo cross entropy loss

Author: YeAnbang
Date:   2024-07-15 02:12:05 +00:00
Parent: 115c4cc5a4
Commit: b3594d4d68

5 changed files with 19 additions and 15 deletions

@@ -189,6 +189,8 @@ def train(args):
             collate_fn=eval_data_collator,
             distributed_sampler_cls=StatefulDistributedSampler,
         )
+    else:
+        logger.warning("No evaluation dataset is provided, skip evaluation")
     num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
     if args.warmup_steps is None:
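Aside: a minimal, self-contained sketch (an illustration under assumed names, not the repository's code) of the guard the two added lines complete: build the eval dataloader only when an eval dataset is given, and otherwise warn and skip evaluation.

import logging
from typing import Optional

from torch.utils.data import DataLoader, Dataset

logger = logging.getLogger(__name__)

def build_eval_dataloader(eval_dataset: Optional[Dataset], batch_size: int) -> Optional[DataLoader]:
    # A missing eval set is not an error, just a logged warning that
    # evaluation will be skipped, matching the added else branch above.
    if eval_dataset is not None:
        return DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)
    logger.warning("No evaluation dataset is provided, skip evaluation")
    return None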

@@ -176,6 +176,8 @@ def train(args):
             collate_fn=eval_data_collator,
             distributed_sampler_cls=StatefulDistributedSampler,
         )
+    else:
+        logger.warning("No evaluation dataset is provided, skip evaluation")
     num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
     if args.warmup_steps is None:
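The unchanged context line above derives optimizer updates per epoch from the dataloader length and the gradient accumulation factor; a quick worked example with assumed values:

# Assumed values, for illustration only.
num_batches_per_epoch = 1000       # stands in for len(train_dataloader)
accumulation_steps = 8             # stands in for args.accumulation_steps
num_update_steps_per_epoch = num_batches_per_epoch // accumulation_steps
print(num_update_steps_per_epoch)  # 125: one optimizer step per 8 micro-batches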

@@ -16,10 +16,13 @@ import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
+from colossalai.logging import get_dist_logger
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.shardformer.policies.auto_policy import get_autopolicy
 
+logger = get_dist_logger()
+
 def train(args):
     # check lora compatibility
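The new module-level logger comes from ColossalAI's distributed logging helper; a short usage sketch (assumes a working colossalai install; the optional ranks argument of its logger restricts output to the listed ranks in a multi-process run):

from colossalai.logging import get_dist_logger

logger = get_dist_logger()
logger.warning("No evaluation dataset is provided, skip evaluation")
logger.info("printed only on rank 0 in a distributed run", ranks=[0])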
@@ -186,6 +189,8 @@ def train(args):
             collate_fn=eval_data_collator,
             distributed_sampler_cls=StatefulDistributedSampler,
         )
+    else:
+        logger.warning("No evaluation dataset is provided, skip evaluation")
     num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
     math.ceil(args.max_epochs * num_update_steps_per_epoch)

@@ -187,6 +187,8 @@ def train(args):
             collate_fn=eval_data_collator,
             distributed_sampler_cls=StatefulDistributedSampler,
         )
+    else:
+        logger.warning("No evaluation dataset is provided, skip evaluation")
     coordinator.print_on_master(
         f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
     )