mirror of https://github.com/hpcaitech/ColossalAI.git
fix logging rollouts
This commit is contained in:
parent 03b41d6fb5
commit 107470a360
.gitignore (vendored): 1 line added

@@ -166,3 +166,4 @@ applications/ColossalChat/tests/logs
 applications/ColossalChat/wandb
 applications/ColossalChat/model
 applications/ColossalChat/eval
+applications/ColossalChat/rollouts
@@ -120,12 +120,16 @@ class GRPOConsumer(BaseConsumer):
                 "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
             )
         # Initialize verifiable reward.
-        response_format_tags = {
-            "think_start": {"text": "<think>", "num_occur": 1},
-            "think_end": {"text": "</think>", "num_occur": 1},
-            "answer_start": {"text": "<answer>", "num_occur": 1},
-            "answer_end": {"text": "</answer>", "num_occur": 1},
-        }
+        response_format_tags = (
+            {
+                "think_start": {"text": "<think>", "num_occur": 1},
+                "think_end": {"text": "</think>", "num_occur": 1},
+                "answer_start": {"text": "<answer>", "num_occur": 1},
+                "answer_end": {"text": "</answer>", "num_occur": 1},
+            }
+            if grpo_config.get("reward_fn_type") == "think_answer_tags"
+            else None
+        )
         reward_model_kwargs = {
             k: v
             for k, v in grpo_config.items()
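Note: with this change the tag dictionary is only built when reward_fn_type is "think_answer_tags"; otherwise the consumer passes None. The snippet below is an illustrative sketch only (not code from this commit) of how the num_occur counts could be checked against a completion string; check_tags is a hypothetical helper.

    # Illustrative sketch; `check_tags` is a hypothetical helper, not part of this commit.
    response_format_tags = {
        "think_start": {"text": "<think>", "num_occur": 1},
        "think_end": {"text": "</think>", "num_occur": 1},
        "answer_start": {"text": "<answer>", "num_occur": 1},
        "answer_end": {"text": "</answer>", "num_occur": 1},
    }

    def check_tags(completion, tags):
        # When tags are disabled (None), treat the format check as passed.
        if tags is None:
            return True
        return all(completion.count(spec["text"]) == spec["num_occur"] for spec in tags.values())

    print(check_tags("<think>reasoning</think><answer>42</answer>", response_format_tags))  # True
    print(check_tags("no tags at all", response_format_tags))                               # False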
@@ -55,6 +55,8 @@ def launch_distributed(
     eval_interval: int = 100,
     eval_save_dir: Optional[str] = None,
     eval_generation_config: Optional[Dict[str, Any]] = None,
+    log_rollout_interval: int = 20,
+    rollout_log_file: str = "./rollout_log.jsonl",
 ):
     if core_algo not in ALGO_MAP:
         raise NotImplementedError(f"{core_algo} is not supported yet.")

@@ -98,6 +100,8 @@ def launch_distributed(
             project_name=project_name,
             run_name=run_name,
             wandb_group_name=wandb_group_name,
+            log_rollout_interval=log_rollout_interval,
+            rollout_log_file=rollout_log_file,
         )
         procs.append(producer)
         generate_config_consumer = copy.deepcopy(generate_config)
@@ -1,4 +1,5 @@
 import copy
+import json
 import os
 from typing import Any, Dict, Optional
 
@@ -49,7 +50,8 @@ class BaseProducer:
         project_name: str = None,
         run_name: str = None,
         wandb_group_name: str = None,
-        wandb_log_rollout_interval: int = 20,
+        log_rollout_interval: int = 20,
+        rollout_log_file: str = "./rollout_log.jsonl",
     ):
         self.producer_idx = producer_idx
         self.num_producers = num_producers
@@ -70,9 +72,16 @@ class BaseProducer:
         self.eval_save_dir = eval_save_dir
         self.consumer_global_step = 0
         self.eval_mode = False
-        self.wandb_rollout_data = []
-        self.wandb_log_rollout_interval = wandb_log_rollout_interval
+        self.log_rollout_interval = log_rollout_interval
         self.latest_rollout_log_step = -1
+        if producer_idx == 0:
+            if os.path.exists(rollout_log_file):
+                raise ValueError(
+                    f"Rollout log file {rollout_log_file} already exists. Please delete it or change the name."
+                )
+            else:
+                os.makedirs(os.path.dirname(rollout_log_file), exist_ok=True)
+            self.rollout_log_file = open(rollout_log_file, "w", encoding="utf8")
         if self.producer_idx == 0:
             self.wandb_run = wandb.init(
                 project=project_name,
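Note: only producer rank 0 opens the rollout log, and the constructor refuses to overwrite an existing file. A minimal standalone sketch of that guard, with an illustrative path rather than a value from the commit:

    # Sketch of the guard above; the path is illustrative.
    import os

    rollout_log_file = "./rollouts/demo_run.jsonl"  # hypothetical path
    if os.path.exists(rollout_log_file):
        raise ValueError(
            f"Rollout log file {rollout_log_file} already exists. Please delete it or change the name."
        )
    os.makedirs(os.path.dirname(rollout_log_file), exist_ok=True)
    with open(rollout_log_file, "w", encoding="utf8") as f:
        pass  # file is created empty; records are written during training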
@@ -320,6 +329,8 @@ class SimpleProducer(BaseProducer):
         project_name: str = None,
         run_name: str = None,
         wandb_group_name: str = None,
+        log_rollout_interval: int = 20,
+        rollout_log_file: str = "./rollout_log.jsonl",
     ):
         super().__init__(
             producer_idx,
@@ -342,6 +353,8 @@ class SimpleProducer(BaseProducer):
             project_name=project_name,
             run_name=run_name,
             wandb_group_name=wandb_group_name,
+            log_rollout_interval=log_rollout_interval,
+            rollout_log_file=rollout_log_file,
         )
         self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
         self.eval_generation_config = copy.deepcopy(self.model.generate_config)
@@ -353,26 +366,31 @@
     def rollout(self, input_ids, attention_mask, **kwargs):
         rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
         if self.producer_idx == 0 and not self.eval_mode:
-            wandb_rollout_data = self.wandb_rollout_data + [
-                [
-                    str(self.consumer_global_step),
-                    str(self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True)),
-                ]
-            ]
             if (
-                self.consumer_global_step - self.latest_rollout_log_step >= self.wandb_log_rollout_interval
+                self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval
                 or self.latest_rollout_log_step == -1
             ):
-                self.wandb_rollout_data = wandb_rollout_data
-                self.latest_rollout_log_step = self.consumer_global_step
-                self.wandb_run.log(
-                    {
-                        "rollout/rollout_examples": wandb.Table(
-                            columns=["train_step", "rollout_examples"], data=wandb_rollout_data
-                        )
-                    }
-                )
+                new_record = (
+                    json.dumps(
+                        {
+                            "train_step": self.consumer_global_step,
+                            "rollout": self.tokenizer.batch_decode(
+                                rollouts["input_ids"][:, 0], skip_special_tokens=True
+                            ),
+                        }
+                    )
+                    + "\n"
+                )
+                self.rollout_log_file.write(new_record)
+                self.rollout_log_file.flush()
+                self.latest_rollout_log_step = self.consumer_global_step
         return rollouts
 
     def __del__(self):
         if self.producer_idx == 0:
             self.wandb_run.finish()
+        if hasattr(self, "rollout_log_file"):
+            self.rollout_log_file.close()
 
     def load_state_dict(self, state_dict):
         self.model.load_state_dict(state_dict)
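Note: each record is written as json.dumps(...) plus a newline, so the log is standard JSON Lines and can be inspected line by line. A small reader sketch, assuming the default path from this commit and the field names written above:

    # Sketch: inspect the rollout log (JSON Lines format) written by the producer.
    import json

    with open("./rollout_log.jsonl", "r", encoding="utf8") as f:
        for line in f:
            record = json.loads(line)
            # "rollout" holds the decoded first sample for each prompt in the batch.
            print(record["train_step"], record["rollout"][0][:80])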
@@ -118,6 +118,9 @@ if __name__ == "__main__":
     parser.add_argument(
         "-esd", "--eval-save-dir", type=str, default="./eval", help="Directory for saving evaluation results."
     )
+    parser.add_argument(
+        "-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout loggings."
+    )
     args = parser.parse_args()
 
     if args.train_minibatch_size is None:

@@ -269,4 +272,6 @@ if __name__ == "__main__":
         eval_interval=args.eval_interval,
         eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
         eval_generation_config=eval_generation_config,
+        log_rollout_interval=20,
+        rollout_log_file=os.path.join(args.rollout_save_dir, args.project.replace(" ", "_") + ".jsonl"),
     )
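Note: the rollout log path passed to launch_distributed is derived from the new --rollout-save-dir flag and the project name (args.project), with spaces replaced by underscores. A quick sketch with illustrative values:

    # Sketch of the path composition used above; the values are illustrative.
    import os

    rollout_save_dir = "./rollouts"   # --rollout-save-dir default
    project = "GRPO demo run"         # hypothetical project name (args.project)
    rollout_log_file = os.path.join(rollout_save_dir, project.replace(" ", "_") + ".jsonl")
    print(rollout_log_file)           # ./rollouts/GRPO_demo_run.jsonl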