mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-19 18:00:43 +00:00
address conversation
This commit is contained in:
parent
6abffb9100
commit
203dfb1536
@@ -66,7 +66,7 @@ def launch_distributed(
|
||||
|
||||
dataset_path = train_dataset_config["path"]
|
||||
num_samples = get_jsonl_size_fast(dataset_path)
|
||||
global_inference_batch_size = inference_batch_size * num_producers # TODO: this doesn't support TP on producer
|
||||
global_inference_batch_size = inference_batch_size * num_producers
|
||||
num_update_per_episode = num_samples // global_inference_batch_size
|
||||
num_recv_per_update = inference_batch_size // inference_microbatch_size
|
||||
|
||||
|
@@ -187,7 +187,7 @@ class BaseProducer:
|
||||
for eval_task_name in self.eval_dataloaders:
|
||||
if self.producer_idx == 0:
|
||||
print(
|
||||
f"[P{self.producer_idx}] Evaluate episode {episode} step {self.consumer_global_step} on task {eval_task_name}"
|
||||
f"[P{self.producer_idx}] Evaluate model at training step {self.consumer_global_step} on task {eval_task_name}"
|
||||
)
|
||||
eval_results = []
|
||||
eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
|
||||
@@ -220,7 +220,7 @@ class BaseProducer:
|
||||
safe_append_to_jsonl_file(
|
||||
os.path.join(
|
||||
self.eval_save_dir,
|
||||
f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl",
|
||||
f"{eval_task_name}_training_step_{self.consumer_global_step}.jsonl",
|
||||
),
|
||||
eval_results,
|
||||
)
|
||||
|
@@ -104,7 +104,13 @@ if __name__ == "__main__":
|
||||
choices=["think_answer_tags", "boxed"],
|
||||
help="Reward type for GRPO.",
|
||||
)
|
||||
parser.add_argument("-ei", "--eval-interval", type=int, default=100, help="Interval for evaluation.")
|
||||
parser.add_argument(
|
||||
"-ei",
|
||||
"--eval-interval",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Interval for evaluation. Evaluate every ei training steps.",
|
||||
)
|
||||
|
||||
# Logging/Checkpointing parameters
|
||||
parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
|
||||
|
Loading…
Reference in New Issue
Block a user