mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-06-20 04:32:47 +00:00

Merge pull request #6309 from hpcaitech/grpo-eval-dev

[feat] Support evaluation during training

This commit is contained in: commit 3c42c0ce82

.gitignore (vendored)
@@ -165,3 +165,4 @@ applications/ColossalChat/logs
 applications/ColossalChat/tests/logs
 applications/ColossalChat/wandb
 applications/ColossalChat/model
+applications/ColossalChat/eval
@@ -93,7 +93,6 @@ class BaseConsumer:
         cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")

         self.buffer = []

         self.recv_cnt = 0

     def state_dict(self) -> Dict[str, torch.Tensor]:
@@ -209,6 +208,8 @@ class SimpleConsumer(BaseConsumer):
             model_config,
             plugin_config,
             minibatch_size,
+            save_interval,
+            save_dir,
         )
         path = model_config.pop("path")
         self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -33,12 +33,13 @@ class GRPOConsumer(BaseConsumer):
         plugin_config,
         minibatch_size=1,
         num_generations=8,
-        use_wandb=True,
         generate_config=None,
         grpo_config={},
-        project_name=None,
         save_interval: int = 100,
         save_dir="./model",
+        project_name: str = None,
+        run_name: str = None,
+        wandb_group_name: str = None,
     ):
         print(f"Using GRPO config: {grpo_config}")
         if (
@@ -84,6 +85,9 @@ class GRPOConsumer(BaseConsumer):
         self.effective_sample_count = 0
         self.effective_prompt_count = 0
         self.total_sample_count = 0
+        self.project_name = project_name
+        self.run_name = run_name
+        self.wandb_group_name = wandb_group_name

         self.policy_loss_fn = PolicyLoss(
             clip_eps_low=grpo_config.get("clip_eps_low", 0.2),
@@ -134,7 +138,6 @@ class GRPOConsumer(BaseConsumer):
             **reward_model_kwargs,
         )
         self.global_step = 0
-        self.use_wandb = use_wandb

         self.lr_scheduler = CosineAnnealingWarmupLR(
             optimizer=self.optimizer,
@@ -145,13 +148,16 @@ class GRPOConsumer(BaseConsumer):

     def setup(self):
         super().setup()
-        if self.use_wandb and (
-            (not self.plugin.pp_size > 1 and self.rank == 0)
-            or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
+        if (not self.plugin.pp_size > 1 and self.rank == 0) or (
+            self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
         ):
-            # Initialize wandb.
-            name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
-            self.wandb_run = wandb.init(project=self.project_name, sync_tensorboard=True, dir="./wandb", name=name)
+            self.wandb_run = wandb.init(
+                project=self.project_name,
+                sync_tensorboard=False,
+                dir="./wandb",
+                name=self.run_name,
+                group=self.wandb_group_name,
+            )

         self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
             self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
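The rewritten `wandb.init` drops the old auto-generated run name in favor of an explicit `name` and `group` (and turns off `sync_tensorboard`). The shared group is what ties this consumer training run to the producer-side `*_eval` run added later in this diff. Below is a minimal sketch of the grouping pattern, not code from the PR; the project and run names are illustrative, and `mode="offline"` is added so it runs without wandb credentials.

import uuid

import wandb

# One shared group id ties runs from several processes together in the wandb UI.
group_name = str(uuid.uuid4())

# Consumer-side training run.
train_run = wandb.init(
    project="demo-project",
    name="vllm_bs_8_temp_1.0_top_p_0.95",
    group=group_name,
    mode="offline",  # sketch only: avoids needing credentials
)
train_run.log({"train/loss": 0.5})
train_run.finish()

# Producer-side eval run joins the same group and lines up in the UI.
eval_run = wandb.init(
    project="demo-project",
    name="vllm_bs_8_temp_1.0_top_p_0.95_eval",
    group=group_name,
    mode="offline",
)
eval_run.log({"eval/math_500": 0.42})
eval_run.finish()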
@@ -265,7 +271,6 @@ class GRPOConsumer(BaseConsumer):
            total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
            self.effective_sample_count += effective_samples.item()
            self.total_sample_count += total_samples.item()

            pbar.set_postfix(
                {
                    "Global Step": self.global_step,
@@ -506,8 +511,8 @@ class GRPOConsumer(BaseConsumer):
                }
                if self.policy_loss_fn.beta > 0:
                    metrics["train/kl"] = self.accum_kl.item() / self.accum_count
-               self.wandb_run.log(metrics)
+               if self.wandb_run is not None:
+                   self.wandb_run.log(metrics)
                self.accum_loss.zero_()
                self.accum_reward.zero_()
                self.accum_ans_acc.zero_()
@@ -521,7 +526,6 @@ class GRPOConsumer(BaseConsumer):
                # All gather excessive prompts index across DP ranks.
                excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
                excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)

                return loss_scalar, excessive_prompts_idx
            else:
                return None, excessive_prompts_idx
@@ -530,4 +534,5 @@ class GRPOConsumer(BaseConsumer):
        self.policy_model._force_wait_all_gather()
        model = self.policy_model.unwrap()
        state_dict = model.state_dict()
+       state_dict["consumer_global_step"] = torch.tensor([self.global_step], device=self.device)
        return state_dict
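Appending `consumer_global_step` to `state_dict()` piggybacks the trainer's step counter on the weight-sync broadcast that already flows to the producers, so they can schedule evaluation against consumer steps without a separate channel. A self-contained sketch of the pack/pop round trip; the helper names are hypothetical:

import torch

def pack_state(model_state, global_step):
    # Sender side: tuck the current step into the weight dict before broadcast.
    state = dict(model_state)
    state["consumer_global_step"] = torch.tensor([global_step])
    return state

def unpack_state(state):
    # Receiver side: pop the scalar before loading weights, as the producer
    # does further down in this diff.
    step = None
    if "consumer_global_step" in state:
        step = int(state.pop("consumer_global_step").item())
    return state, step

weights = {"w": torch.zeros(2)}
restored, step = unpack_state(pack_state(weights, global_step=7))
assert step == 7 and "consumer_global_step" not in restored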
@@ -205,7 +205,8 @@ class VLLMInferenceBackend(BaseInferenceBackend):
         generate_config = generate_config.copy()
         generate_config.update(self.FORCE_GENERATE_CONFIG)
         generate_config.update({"n": num_generations})
-        self.generate_config = SamplingParams(**generate_config)
+        self.generate_config = generate_config
+        self.sample_params = SamplingParams(**generate_config)
         self.model_config = model_config
         self.tokenizer = tokenizer
         self.num_generations = num_generations
@@ -219,8 +220,9 @@ class VLLMInferenceBackend(BaseInferenceBackend):
         micro_batch_input_ids_no_padding = [
             micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size)
         ]
+        sample_params = kwargs.get("sample_params", self.sample_params)
         outputs = self.llm.generate(
-            prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False
+            prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=sample_params, use_tqdm=False
         )
         out_tokens = []
         out_len = []
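Keeping the raw dict in `self.generate_config` while prebuilding `self.sample_params` is what enables the `kwargs.get("sample_params", ...)` fallback: training rollouts take the default, and the eval loop passes its own parameters per call. A runnable sketch of the pattern with a dataclass standing in for `vllm.SamplingParams`, so vLLM is not required; the field values are illustrative:

from dataclasses import dataclass

@dataclass
class SamplingParams:  # stand-in for vllm.SamplingParams
    temperature: float = 1.0
    n: int = 1
    max_tokens: int = 512

class Backend:
    def __init__(self, generate_config):
        self.generate_config = generate_config                  # raw dict, kept mutable
        self.sample_params = SamplingParams(**generate_config)  # training default

    def generate(self, prompt, **kwargs):
        # Eval passes sample_params explicitly; training falls back to the default.
        params = kwargs.get("sample_params", self.sample_params)
        return f"{prompt!r} decoded with temp={params.temperature}, n={params.n}"

backend = Backend({"temperature": 1.0, "n": 8, "max_tokens": 512})
print(backend.generate("2+2="))                                                      # training path
print(backend.generate("2+2=", sample_params=SamplingParams(temperature=0.6, n=1)))  # eval path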
@@ -236,7 +238,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
                 log_probs.append(p)

         # pad them
-        max_len = self.generate_config.max_tokens
+        max_len = self.sample_params.max_tokens
         action_mask = torch.ones(len(out_tokens), max_len, dtype=attention_mask.dtype)

         for i, new_token_ids in enumerate(out_tokens):
@@ -266,11 +268,11 @@ class VLLMInferenceBackend(BaseInferenceBackend):
             "response_idx": response_idx,
         }

-        data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
+        data = {k: v.view(micro_batch_size, -1, v.size(-1)) for k, v in data.items()}

         if "gt_answer" in kwargs:
             # repeat gt_answer for each prompt.
-            data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(self.num_generations, dim=1)
+            data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(data["input_ids"].size(1), dim=1)
         data = {k: v.to(get_current_device()) for k, v in data.items()}
         return data

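Viewing with `-1` instead of the fixed `num_generations`, and repeating `gt_answer` by `data["input_ids"].size(1)`, lets the same reshaping serve eval rollouts that generate a single completion per prompt. A shape-only illustration with made-up sizes:

import torch

micro_batch_size, num_generations, seq_len = 2, 4, 5
flat = torch.arange(micro_batch_size * num_generations * seq_len).view(-1, seq_len)

# -1 infers the generation dimension, so the same line also handles n == 1.
data = flat.view(micro_batch_size, -1, seq_len)
print(data.shape)  # torch.Size([2, 4, 5])

# One gold answer per prompt, repeated to match whatever the middle dim is.
gt_answer = torch.tensor([[11], [22]]).unsqueeze(-1)           # (2, 1, 1)
print(gt_answer.repeat_interleave(data.size(1), dim=1).shape)  # torch.Size([2, 4, 1])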
@@ -1,4 +1,5 @@
 import copy
+import uuid
 from typing import Any, Dict, Optional

 import ray
@@ -34,7 +35,7 @@ def launch_distributed(
     inference_microbatch_size: int,
     train_batch_size: int,
     train_minibatch_size: int,
-    dataset_config: Dict[str, Any],
+    train_dataset_config: Dict[str, Any],
     dataloaders_config: Dict[str, Any],
     inference_model_config: Dict[str, Any],
     generate_config: Dict[str, Any],
@@ -50,8 +51,11 @@ def launch_distributed(
     project_name: Optional[str] = None,
     save_interval: int = 100,
     save_dir: str = "./model",
+    eval_dataset_config: Optional[Dict[str, Any]] = None,
+    eval_interval: int = 100,
+    eval_save_dir: Optional[str] = None,
+    eval_generation_config: Optional[Dict[str, Any]] = None,
 ):

     if core_algo not in ALGO_MAP:
         raise NotImplementedError(f"{core_algo} is not supported yet.")
     else:
@@ -60,12 +64,15 @@ def launch_distributed(
     train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
     assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0

-    dataset_path = dataset_config["path"]
+    dataset_path = train_dataset_config["path"]
     num_samples = get_jsonl_size_fast(dataset_path)
     global_inference_batch_size = inference_batch_size * num_producers
     num_update_per_episode = num_samples // global_inference_batch_size
     num_recv_per_update = inference_batch_size // inference_microbatch_size

+    run_name = f"{inference_backend}_bs_{train_batch_size * train_dp_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
+    wandb_group_name = str(uuid.uuid4())
+
     procs = []
     for i in range(num_producers):
         producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
@@ -74,7 +81,7 @@ def launch_distributed(
             num_consumer_procs=num_consumer_procs,
             num_episodes=num_episodes,
             batch_size=inference_batch_size,
-            dataset_config=dataset_config,
+            train_dataset_config=train_dataset_config,
             dataloaders_config=dataloaders_config,
             model_config=inference_model_config,
             generate_config=generate_config,
@@ -83,6 +90,14 @@ def launch_distributed(
             backend=inference_backend,
             num_generations=num_generations,
             consumer_plugin_config=plugin_config,
+            eval_dataset_config=eval_dataset_config,
+            eval_interval=eval_interval,
+            evaluation_function_type=grpo_config["reward_fn_type"],
+            eval_save_dir=eval_save_dir,
+            eval_generation_config=eval_generation_config,
+            project_name=project_name,
+            run_name=run_name,
+            wandb_group_name=wandb_group_name,
         )
         procs.append(producer)
     generate_config_consumer = copy.deepcopy(generate_config)
@@ -108,9 +123,11 @@ def launch_distributed(
         generate_config=generate_config_consumer,
         grpo_config=grpo_config,
         num_generations=num_generations,
-        project_name=project_name,
         save_interval=save_interval,
         save_dir=save_dir,
+        project_name=project_name,
+        run_name=run_name,
+        wandb_group_name=wandb_group_name,
     )
     procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])
@@ -1,9 +1,16 @@
+import copy
+import os
 from typing import Any, Dict, Optional

 import ray
 import ray.util.collective as cc
 import torch
+import tqdm
+import wandb
 from coati.dataset.loader import RawConversationDataset
+from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
+from ray.util.collective import allreduce
+from ray.util.collective.types import Backend, ReduceOp
 from torch.utils.data import DataLoader, DistributedSampler
 from transformers import AutoTokenizer

@@ -11,7 +18,12 @@ from colossalai.utils import get_current_device

 from .comm import ray_broadcast_tensor_dict
 from .inference_backend import BACKEND_MAP
-from .utils import pre_send
+from .utils import pre_send, safe_append_to_jsonl_file
+
+try:
+    from vllm import SamplingParams
+except ImportError:
+    LLM = None


 class BaseProducer:
@@ -22,7 +34,7 @@ class BaseProducer:
         num_consumer_procs: int,
         num_episodes: int,
         batch_size: int,
-        dataset_config: Dict[str, Any],
+        train_dataset_config: Dict[str, Any],
         dataloaders_config: Dict[str, Any],
         model_config: Dict[str, Any],
         generate_config: Dict[str, Any],
@@ -30,6 +42,14 @@ class BaseProducer:
         microbatch_size: int = 1,
         backend: str = "transformers",
         consumer_plugin_config: Dict[str, Any] = None,
+        eval_dataset_config=None,
+        eval_interval=-1,  # disable evaluation
+        evaluation_function_type="think_answer_tags",
+        eval_save_dir: str = "./eval",
+        project_name: str = None,
+        run_name: str = None,
+        wandb_group_name: str = None,
+        wandb_log_rollout_interval: int = 20,
     ):
         self.producer_idx = producer_idx
         self.num_producers = num_producers
@@ -39,11 +59,31 @@ class BaseProducer:
         self.microbatch_size = microbatch_size
         assert batch_size % microbatch_size == 0
         self.num_microbatches = batch_size // microbatch_size
+        self.latest_eval_step = -1

-        self.dataset_config = dataset_config
+        self.train_dataset_config = train_dataset_config
         self.model_config = model_config
         self.generate_config = generate_config
         self.tokenizer_config = tokenizer_config
+        self.consumer_plugin_config = consumer_plugin_config
+        self.eval_interval = eval_interval
+        self.eval_save_dir = eval_save_dir
+        self.consumer_global_step = 0
+        self.eval_mode = False
+        self.wandb_rollout_data = []
+        self.wandb_log_rollout_interval = wandb_log_rollout_interval
+        self.latest_rollout_log_step = -1
+        if self.producer_idx == 0:
+            self.wandb_run = wandb.init(
+                project=project_name,
+                sync_tensorboard=False,
+                dir="./wandb",
+                name=run_name + "_eval",
+                group=wandb_group_name,
+            )
+
+        if os.path.exists(self.eval_save_dir) and self.eval_interval > 0:
+            raise ValueError(f"Eval save dir {self.eval_save_dir} already exists. Please delete it or change the name.")

         # init tokenizer
         if tokenizer_config is None:
@@ -55,13 +95,13 @@ class BaseProducer:
         self.tokenizer.padding_side = "left"

         # init dataloader
-        dataset_path = dataset_config.pop("path")
-        self.dataset = RawConversationDataset(self.tokenizer, dataset_path, **dataset_config)
-        self.dataloader = DataLoader(
-            self.dataset,
+        train_dataset_path = train_dataset_config.pop("path")
+        self.train_dataset = RawConversationDataset(self.tokenizer, train_dataset_path, **train_dataset_config)
+        self.train_dataloader = DataLoader(
+            self.train_dataset,
             batch_size=microbatch_size,
             sampler=DistributedSampler(
-                self.dataset,
+                self.train_dataset,
                 num_replicas=num_producers,
                 rank=producer_idx,
                 shuffle=True,
@@ -71,6 +111,36 @@ class BaseProducer:
             num_workers=4,
             drop_last=True,
         )

+        self.eval_dataset_config = eval_dataset_config
+        if self.eval_dataset_config is not None:
+            self.eval_dataloaders = {}
+            for eval_task_name in self.eval_dataset_config:
+                eval_dataset_path = eval_dataset_config[eval_task_name].pop("path")
+                eval_dataset = RawConversationDataset(
+                    self.tokenizer, eval_dataset_path, **eval_dataset_config[eval_task_name]
+                )
+                print(f"[P{self.producer_idx}] eval dataset {eval_task_name} size: {len(eval_dataset)}")
+                self.eval_dataloaders[eval_task_name] = DataLoader(
+                    eval_dataset,
+                    batch_size=microbatch_size,
+                    sampler=DistributedSampler(
+                        eval_dataset,
+                        num_replicas=num_producers,
+                        rank=producer_idx,
+                        shuffle=False,
+                        drop_last=False,
+                        seed=42,
+                    ),
+                )
+            if evaluation_function_type == "think_answer_tags":
+                self.evaluation_function = math_reward_fn
+            elif evaluation_function_type == "boxed":
+                self.evaluation_function = boxed_math_reward_fn
+            else:
+                raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
+        else:
+            raise ValueError("eval_dataset_config is not defined")
         self.device = get_current_device()

         # init backend
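Each producer receives a disjoint shard of every eval set: `shuffle=False` keeps per-rank order deterministic and `drop_last=False` keeps the tail samples. A standalone sketch of how `DistributedSampler` splits a toy dataset across three producer ranks (sizes are illustrative; note the sampler pads the tail so all ranks see the same count):

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

eval_dataset = TensorDataset(torch.arange(10))  # toy stand-in for an eval set

num_producers = 3
for rank in range(num_producers):
    sampler = DistributedSampler(
        eval_dataset,
        num_replicas=num_producers,
        rank=rank,
        shuffle=False,    # deterministic order per rank
        drop_last=False,  # every sample is evaluated; tail ranks get padding
        seed=42,
    )
    loader = DataLoader(eval_dataset, batch_size=4, sampler=sampler)
    print(rank, [batch[0].tolist() for batch in loader])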
@@ -82,6 +152,12 @@ class BaseProducer:
         self.consumer_pp_size = consumer_plugin_config.get("pp_size", 1)  # consumer pp size

     def setup(self) -> None:
+        cc.init_collective_group(
+            world_size=self.num_producers,
+            rank=self.producer_idx,
+            backend=Backend.NCCL,
+            group_name="producer_group",
+        )
         cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
         if self.consumer_pp_size > 1:
             for i in range(self.consumer_pp_size):
@@ -96,29 +172,74 @@ class BaseProducer:
             raise NotImplementedError

     def loop(self) -> None:
-        num_update_per_episode = len(self.dataloader) // self.num_microbatches
+        num_update_per_episode = len(self.train_dataloader) // self.num_microbatches
         num_valid_microbatches = num_update_per_episode * self.num_microbatches

         print(
-            f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.dataloader)}"
+            f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.train_dataloader)}"
         )
         for episode in range(self.num_episodes):
-            self.dataloader.sampler.set_epoch(episode)
-            for i, batch in enumerate(self.dataloader):
+            self.train_dataloader.sampler.set_epoch(episode)
+            for i, batch in enumerate(self.train_dataloader):
                 if i >= num_valid_microbatches:
                     break
+                if self.eval_interval > 0 and self.eval_dataset_config is not None:
+                    if (
+                        self.consumer_global_step - self.latest_eval_step >= self.eval_interval
+                        and self.consumer_global_step > self.latest_eval_step
+                    ) or self.latest_eval_step == -1:
+                        to_log_msg = {}
+                        self.eval_mode = True
+                        for eval_task_name in self.eval_dataloaders:
+                            if self.producer_idx == 0:
+                                print(
+                                    f"[P{self.producer_idx}] Evaluate model at training step {self.consumer_global_step} on task {eval_task_name}"
+                                )
+                            eval_results = []
+                            eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
+                            for eval_batch in tqdm.tqdm(
+                                self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0
+                            ):
+                                eval_outputs = self.rollout(**eval_batch, sample_params=self.eval_sample_params)
+                                eval_results = eval_results + [
+                                    self.evaluation_function(
+                                        eval_outputs["input_ids"][m][n],
+                                        eval_outputs["gt_answer"][m][n],
+                                        eval_outputs["response_idx"][m][n],
+                                        tokenizer=self.tokenizer,
+                                        eval_mode=True,
+                                    )
+                                    for m in range(eval_outputs["input_ids"].size(0))
+                                    for n in range(eval_outputs["input_ids"].size(1))
+                                ]
+                            eval_statistics_tensor[0] += len([res for res in eval_results if res["ans_valid"] == 1])
+                            eval_statistics_tensor[1] += len(eval_results)
+                            allreduce(eval_statistics_tensor, op=ReduceOp.SUM, group_name="producer_group")
+                            to_log_msg[f"eval/{eval_task_name}"] = (
+                                eval_statistics_tensor[0].item() / eval_statistics_tensor[1].item()
+                            )
+                            if self.producer_idx == 0:
+                                print(
+                                    f"[P{self.producer_idx}]: Accuracy on {eval_task_name}: {to_log_msg[f'eval/{eval_task_name}']}"
+                                )
+                            # save eval results
+                            safe_append_to_jsonl_file(
+                                os.path.join(
+                                    self.eval_save_dir,
+                                    f"{eval_task_name}_training_step_{self.consumer_global_step}.jsonl",
+                                ),
+                                eval_results,
+                            )
+
+                        if self.producer_idx == 0:
+                            self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
+                        self.eval_mode = False
+                        self.latest_eval_step = self.consumer_global_step
                 outputs = self.rollout(**batch)

                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
                 outputs["temperature"] = torch.tensor(
-                    [
-                        (
-                            self.model.generate_config["temperature"]
-                            if isinstance(self.model.generate_config.temperature, dict)
-                            else self.model.generate_config.temperature
-                        )
-                    ]
-                    * outputs["input_ids"].size(0)
+                    [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
                 ).to(outputs["input_ids"].device)
                 outputs = pre_send(outputs)
                 ray_broadcast_tensor_dict(
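The per-task accuracy comes from summing a `[num_correct, num_total]` tensor across producers with an allreduce and dividing afterwards, which stays correct even when shards are unevenly sized. The collective itself needs live Ray actors, so this sketch simulates `ReduceOp.SUM` with a plain tensor sum; the counts are invented:

import torch

# Each producer builds [num_correct, num_total] for its shard.
per_rank_stats = [
    torch.tensor([12.0, 20.0]),  # rank 0: 12/20 correct
    torch.tensor([15.0, 20.0]),  # rank 1
    torch.tensor([9.0, 20.0]),   # rank 2
]
reduced = torch.stack(per_rank_stats).sum(dim=0)  # what the SUM allreduce yields
print(f"eval accuracy: {reduced[0].item() / reduced[1].item():.3f}")  # 36/60 = 0.600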
@@ -142,6 +263,8 @@ class BaseProducer:
                     state_dict = ray_broadcast_tensor_dict(
                         None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
                     )
+                    if "consumer_global_step" in state_dict:
+                        self.consumer_global_step = state_dict.pop("consumer_global_step").item()
                     self.load_state_dict(state_dict)
                 else:
                     print(
@@ -150,6 +273,8 @@ class BaseProducer:
                     state_dict = ray_broadcast_tensor_dict(
                         None, self.num_producers, device=self.device, group_name="sync_model"
                     )
+                    if "consumer_global_step" in state_dict:
+                        self.consumer_global_step = state_dict.pop("consumer_global_step").item()
                     self.load_state_dict(state_dict)
                 del state_dict
                 torch.cuda.empty_cache()
@@ -159,13 +284,12 @@ class BaseProducer:
                     self.model.llm.wake_up()
                 # linear annealing for 1 episode, temperature from initial to 0.9
                 if episode <= 0:
-                    ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
-                    if isinstance(self.model.generate_config.temperature, dict):
-                        self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
-                            "temperature"
-                        ] + ratio * 0.9
-                    else:
-                        self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
-                            "temperature"
-                        ] + ratio * 0.9
+                    ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader)
+                    self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
+                        "temperature"
+                    ] + ratio * 0.9
+                    if isinstance(self.model, BACKEND_MAP["vllm"]):
+                        self.model.sample_params.temperature = (1 - ratio) * self.generate_config[
+                            "temperature"
+                        ] + ratio * 0.9

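The annealing reduces to `ratio = i / len(dataloader)`, so the sampling temperature moves linearly from the configured value toward 0.9 over the first episode and stays fixed afterwards. Worked numbers, assuming an initial temperature of 1.2 and four batches per episode:

initial_temperature = 1.2
num_batches = 4
for i in range(num_batches):
    ratio = 1 - (num_batches - i) / num_batches  # == i / num_batches
    temperature = (1 - ratio) * initial_temperature + ratio * 0.9
    print(i, round(temperature, 3))
# 0 1.2, 1 1.125, 2 1.05, 3 0.975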
@@ -179,7 +303,7 @@ class SimpleProducer(BaseProducer):
         num_consumer_procs,
         num_episodes,
         batch_size,
-        dataset_config,
+        train_dataset_config,
         dataloaders_config,
         model_config,
         generate_config,
@@ -188,6 +312,14 @@ class SimpleProducer(BaseProducer):
         backend="transformers",
         num_generations: int = 8,
         consumer_plugin_config=None,
+        eval_dataset_config=None,
+        eval_interval=-1,  # disable evaluation
+        evaluation_function_type="think_answer_tags",
+        eval_save_dir: str = "./eval",
+        eval_generation_config={},
+        project_name: str = None,
+        run_name: str = None,
+        wandb_group_name: str = None,
     ):
         super().__init__(
             producer_idx,
@@ -195,7 +327,7 @@ class SimpleProducer(BaseProducer):
             num_consumer_procs,
             num_episodes,
             batch_size,
-            dataset_config,
+            train_dataset_config,
             dataloaders_config,
             model_config,
             generate_config,
@@ -203,15 +335,43 @@ class SimpleProducer(BaseProducer):
             microbatch_size,
             backend,
             consumer_plugin_config,
+            eval_dataset_config=eval_dataset_config,
+            eval_interval=eval_interval,
+            evaluation_function_type=evaluation_function_type,
+            eval_save_dir=eval_save_dir,
+            project_name=project_name,
+            run_name=run_name,
+            wandb_group_name=wandb_group_name,
         )
         self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
+        self.eval_generation_config = copy.deepcopy(self.model.generate_config)
+        self.eval_generation_config["n"] = 1  # use 1 generation for evaluation
+        self.eval_generation_config.update(eval_generation_config)
+        self.eval_sample_params = SamplingParams(**self.eval_generation_config)

     @torch.no_grad()
     def rollout(self, input_ids, attention_mask, **kwargs):
         rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
-        if self.producer_idx == 1:
-            print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
+        if self.producer_idx == 0 and not self.eval_mode:
+            wandb_rollout_data = self.wandb_rollout_data + [
+                [
+                    str(self.consumer_global_step),
+                    str(self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True)),
+                ]
+            ]
+            if (
+                self.consumer_global_step - self.latest_rollout_log_step >= self.wandb_log_rollout_interval
+                or self.latest_rollout_log_step == -1
+            ):
+                self.wandb_rollout_data = wandb_rollout_data
+                self.latest_rollout_log_step = self.consumer_global_step
+                self.wandb_run.log(
+                    {
+                        "rollout/rollout_examples": wandb.Table(
+                            columns=["train_step", "rollout_examples"], data=wandb_rollout_data
+                        )
+                    }
+                )
         return rollouts

     def load_state_dict(self, state_dict):
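Rollout examples are accumulated in `self.wandb_rollout_data` and re-logged as a whole `wandb.Table` at intervals, because a logged table is a snapshot rather than an appendable object. A minimal sketch of that pattern; the project name and example strings are invented, and `mode="offline"` avoids needing a login:

import wandb

run = wandb.init(project="demo-project", mode="offline")

rollout_rows = []
for step, example in [(0, "<think>...</think><answer>4</answer>"), (20, "<answer>42</answer>")]:
    rollout_rows.append([str(step), example])
    # Rebuild and re-log the full table from the accumulated rows each time.
    run.log(
        {
            "rollout/rollout_examples": wandb.Table(
                columns=["train_step", "rollout_examples"], data=rollout_rows
            )
        }
    )
run.finish()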
@@ -5,6 +5,7 @@ from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure

 def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
+    eval_mode = kwargs.get("eval_mode", False)
     soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
     acc_score = 10.0
     reward = torch.tensor(0.0)
@@ -44,36 +45,23 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):

         reward = reward + length_reward

-    return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
+    if not eval_mode:
+        return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)

-
-def gsm8k_reward_fn(input_ids, **kwargs):
-    gt_answer = kwargs["gt_answer"]
-    tokenizer = kwargs["tokenizer"]
-    s, e = kwargs["response_start"], kwargs["response_end"]
-    reward = torch.tensor(0.0).to(input_ids.device)
-    if gt_answer is None:
-        return reward
-    decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
-    final_answer, processed_str = extract_solution(decoded_final_answer)
-    is_valid = True
-    try:
-        int(final_answer.strip())
-    except Exception:
-        is_valid = False
-
-    format_valid = validate_response_structure(processed_str, kwargs["tags"])
-    if not is_valid or not format_valid:
-        return reward
     else:
-        reward += 1.0
-        if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
-            reward = reward + 9.0
-        return reward
+        prompt = tokenizer.decode(input_ids[:s], skip_special_tokens=True)
+        return {
+            "prompt": prompt,
+            "prediction": decoded_final_answer,
+            "gold": gt_answer,
+            "parsed": final_answer,
+            "format_valid": format_acc.item(),
+            "ans_valid": ans_acc.item(),
+        }


 def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
+    eval_mode = kwargs.get("eval_mode", False)
     soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
     format_score = 0.0
     acc_score = 10.0
@@ -91,7 +79,7 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
             length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score

     if gt_answer is None:
-        return reward
+        return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)

     decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
     gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
@@ -108,5 +96,15 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
             reward += acc_score

     reward = reward + length_reward
-    return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
+    if not eval_mode:
+        return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
+    else:
+        prompt = tokenizer.decode(input_ids[:s], skip_special_tokens=True)
+        return {
+            "prompt": prompt,
+            "prediction": decoded_final_answer,
+            "gold": gt_answer,
+            "parsed": final_answer,
+            "format_valid": format_acc.item(),
+            "ans_valid": ans_acc.item(),
+        }
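With `eval_mode`, one reward function now serves two callers: training gets a `[reward, format_acc, ans_acc]` tensor it can batch and reduce, while evaluation gets a JSON-serializable record for the results file. A condensed sketch of that contract with deliberately simplified scoring (the real functions parse boxed or tagged answers):

import torch

def toy_reward_fn(prediction, gold, **kwargs):
    eval_mode = kwargs.get("eval_mode", False)
    ans_acc = 1.0 if prediction.strip() == gold.strip() else 0.0
    format_acc = 1.0  # real code validates <think>/<answer> or \boxed{} structure
    reward = 10.0 * ans_acc
    if not eval_mode:
        return torch.tensor([reward, format_acc, ans_acc])  # training path
    return {  # eval path: appended to a jsonl results file
        "prediction": prediction,
        "gold": gold,
        "format_valid": format_acc,
        "ans_valid": ans_acc,
    }

print(toy_reward_fn("42", "42"))                  # tensor([10.,  1.,  1.])
print(toy_reward_fn("42", "42", eval_mode=True))  # dict record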
@@ -1,6 +1,9 @@
+import json
+import os
 from typing import Any, Dict, List

 import torch
+from filelock import FileLock

 from colossalai.shardformer.layer.loss import dist_log_prob

@@ -130,3 +133,13 @@ def masked_sum(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
     """
     tensor = tensor * mask
     return tensor.sum(dim=dim)
+
+
+def safe_append_to_jsonl_file(file_path, data):
+    with FileLock(file_path + ".lock"):
+        # Ensure file exists
+        os.makedirs(os.path.dirname(file_path), exist_ok=True)
+        with open(file_path, "a", encoding="utf8") as f:
+            for entry in data:
+                json_line = json.dumps(entry, ensure_ascii=False)
+                f.write(json_line + "\n")
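Several producer actors may append to the same per-task results file, so the helper serializes writers through a `.lock` file next to the target. A usage sketch mirroring the helper above; the path and record are illustrative (and `os.path.dirname` assumes the path has a directory component):

import json
import os

from filelock import FileLock

def safe_append_to_jsonl_file(file_path, data):
    # The lock file serializes concurrent producers appending to one file.
    with FileLock(file_path + ".lock"):
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        with open(file_path, "a", encoding="utf8") as f:
            for entry in data:
                f.write(json.dumps(entry, ensure_ascii=False) + "\n")

records = [{"prompt": "2+2=", "prediction": "4", "ans_valid": 1.0}]
safe_append_to_jsonl_file("./eval_demo/math_500_training_step_0.jsonl", records)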
@@ -1,4 +1,5 @@
 import argparse
+import json
 import os

 import ray
@@ -9,7 +10,16 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
     parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
-    parser.add_argument("-p", "--project", type=str, default="GRPO-V3", help="Project name.")
+    parser.add_argument(
+        "-ed",
+        "--eval-dataset",
+        type=str,
+        default='{"eval task name":"data_eval.jsonl"}',
+        help="Evaluation dataset for each task, please use json format to specify the dataset for each task. \
+        For example: {'task1':'data_eval_task1.jsonl', 'task2':'data_eval_task2.jsonl'}, the jsonl file should be in the same format as the training dataset. \
+        The key is the task name, and the value is the path to the jsonl file",
+    )
+    parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
     parser.add_argument("-e", "--num-episodes", type=int, default=1, help="Number of episodes to train.")

     # Distributed training parameters
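The `--eval-dataset` flag carries a JSON object mapping task names to jsonl paths, which the launcher later expands into per-task dataset configs. A small illustration of that round trip; the task names, paths, and prompt settings are made up:

import json

eval_dataset_arg = '{"math_500": "eval/math_500.jsonl", "aime24": "eval/aime24.jsonl"}'
max_prompt_tokens, system_prompt = 512, "You are a helpful assistant."

# Mirrors the eval_dataset_config construction at the bottom of this diff.
eval_dataset_config = {
    k: {"path": v, "max_length": max_prompt_tokens, "system_prompt": system_prompt}
    for k, v in json.loads(eval_dataset_arg).items()
}
print(json.dumps(eval_dataset_config, indent=2))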
@@ -94,11 +104,20 @@ if __name__ == "__main__":
         choices=["think_answer_tags", "boxed"],
         help="Reward type for GRPO.",
     )
+    parser.add_argument(
+        "-ei",
+        "--eval-interval",
+        type=int,
+        default=100,
+        help="Interval for evaluation. Evaluate every ei training steps.",
+    )

     # Logging/Checkpointing parameters
     parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
     parser.add_argument("-sd", "--save-dir", type=str, default="./model", help="Directory for saving checkpoints.")
+    parser.add_argument(
+        "-esd", "--eval-save-dir", type=str, default="./eval", help="Directory for saving evaluation results."
+    )
     args = parser.parse_args()

     if args.train_minibatch_size is None:
@@ -148,6 +167,7 @@ if __name__ == "__main__":
                 stop_strings=["</answer>"] if args.reward_type == "think_answer_tags" else None,
             )
         )
+        eval_generation_config = {"temperature": 0.6}  # used to update generation config for evaluation
     elif args.backend == "vllm":
         inference_model_config.update(
             dict(
@@ -166,6 +186,7 @@ if __name__ == "__main__":
                 stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
             )
         )
+        eval_generation_config = {"temperature": 0.6}  # used to update generation config for evaluation
     else:
         raise ValueError(f"Unsupported backend: {args.backend}")

@@ -208,7 +229,7 @@ if __name__ == "__main__":
         inference_microbatch_size=args.inference_microbatch_size,
         train_batch_size=args.train_batch_size,
         train_minibatch_size=args.train_minibatch_size,
-        dataset_config={
+        train_dataset_config={
             "path": args.dataset,
             "max_length": args.max_prompt_tokens,
             "system_prompt": args.system_prompt,
@@ -238,4 +259,11 @@ if __name__ == "__main__":
         project_name=args.project,
         save_interval=args.save_interval,
         save_dir=os.path.join(args.save_dir, args.project.replace(" ", "_")),
+        eval_dataset_config={
+            k: {"path": v, "max_length": args.max_prompt_tokens, "system_prompt": args.system_prompt}
+            for k, v in json.loads(args.eval_dataset).items()
+        },
+        eval_interval=args.eval_interval,
+        eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
+        eval_generation_config=eval_generation_config,
     )
|
Loading…
Reference in New Issue
Block a user