address conversation

This commit is contained in:
YeAnbang 2025-05-16 14:15:35 +08:00
parent 6abffb9100
commit 203dfb1536
3 changed files with 10 additions and 4 deletions

View File

@@ -66,7 +66,7 @@ def launch_distributed(
dataset_path = train_dataset_config["path"]
num_samples = get_jsonl_size_fast(dataset_path)
-global_inference_batch_size = inference_batch_size * num_producers # TODO: this doesn't support TP on producer
+global_inference_batch_size = inference_batch_size * num_producers
num_update_per_episode = num_samples // global_inference_batch_size
num_recv_per_update = inference_batch_size // inference_microbatch_size

View File

@@ -187,7 +187,7 @@ class BaseProducer:
for eval_task_name in self.eval_dataloaders:
if self.producer_idx == 0:
print(
-f"[P{self.producer_idx}] Evaluate episode {episode} step {self.consumer_global_step} on task {eval_task_name}"
+f"[P{self.producer_idx}] Evaluate model at training step {self.consumer_global_step} on task {eval_task_name}"
)
eval_results = []
eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
@@ -220,7 +220,7 @@ class BaseProducer:
safe_append_to_jsonl_file(
os.path.join(
self.eval_save_dir,
-f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl",
+f"{eval_task_name}_training_step_{self.consumer_global_step}.jsonl",
),
eval_results,
)

View File

@@ -104,7 +104,13 @@ if __name__ == "__main__":
choices=["think_answer_tags", "boxed"],
help="Reward type for GRPO.",
)
-parser.add_argument("-ei", "--eval-interval", type=int, default=100, help="Interval for evaluation.")
+parser.add_argument(
+"-ei",
+"--eval-interval",
+type=int,
+default=100,
+help="Interval for evaluation. Evaluate every ei training steps.",
+)
# Logging/Checkpointing parameters
parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")