fix logprob, add reward filtering, temperature annealing, lr decay

YeAnbang 2025-03-21 10:24:24 +08:00
parent 7ee4452f8c
commit 0472f44163
7 changed files with 74 additions and 27 deletions

View File

@ -57,6 +57,7 @@ class BaseConsumer:
assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"
self.device = get_current_device()
self.lr_scheduler = None
def setup(self) -> None:
for i in range(self.num_producers):
@ -121,6 +122,8 @@ class BaseConsumer:
pbar.set_postfix({"loss": loss})
i += 1
assert len(self.buffer) == 0
if self.lr_scheduler is not None:
self.lr_scheduler.step()
if (step + 1) % self.save_interval == 0:
if self.rank == 0:
print(f"Start saving policy model at step {step + 1}.")

View File

@ -15,6 +15,7 @@ from coati.distributed.utils import calc_action_log_probs
from coati.trainer.utils import all_reduce_mean
from transformers import AutoModelForCausalLM, AutoTokenizer
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
@ -34,10 +35,10 @@ class GRPOConsumer(BaseConsumer):
model_config,
plugin_config,
microbatch_size=1,
num_generations=4,
num_generations=8,
use_wandb=True,
generator_config=None,
filter_range=None,
generate_config=None,
training_config={},
):
super().__init__(
num_producers,
@ -57,7 +58,7 @@ class GRPOConsumer(BaseConsumer):
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.policy_model.train()
self.policy_model.gradient_checkpointing_enable()
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-6)
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=training_config.get("lr", 1e-6))
self.accum_loss = torch.zeros(1, device=self.device)
self.accum_reward = torch.zeros(1, device=self.device)
self.accum_kl = torch.zeros(1, device=self.device)
@ -66,6 +67,7 @@ class GRPOConsumer(BaseConsumer):
self.accum_advantages = torch.zeros(1, device=self.device)
self.accum_response_length = torch.zeros(1, device=self.device)
self.accum_count = 0
self.generate_config = generate_config
# Reference model is initialized from policy model.
self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@ -74,7 +76,7 @@ class GRPOConsumer(BaseConsumer):
self.tokenizer = AutoTokenizer.from_pretrained(path)
self.pad_token_id = self.tokenizer.pad_token_id
self.num_generations = num_generations
self.filter_range = filter_range
self.filter_range = training_config.get("filter_range", None)
if self.filter_range is not None:
assert len(self.filter_range) == 2, "Filter range should have 2 values."
@ -92,15 +94,21 @@ class GRPOConsumer(BaseConsumer):
self.policy_loss_fn = PolicyLoss()
self.global_step = 0
if use_wandb and self.rank == 0:
if "repetition_penalty" in generator_config:
name = f"{generator_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generator_config['temperature']:.01f}_rep_penalty_{generator_config['repetition_penalty']:.01f}"
else:
name = f"{generator_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generator_config['temperature']:.01f}"
name = f"{generate_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True, dir="./wandb", name=name)
self.lr_scheduler = CosineAnnealingWarmupLR(
optimizer=self.optimizer,
total_steps=min(self.num_episodes, 4) * self.num_update_per_episode,
warmup_steps=0,
eta_min=0.1 * training_config.get("lr", 1e-6),
)
def setup(self):
super().setup()
self.policy_model, self.optimizer, *_ = self.booster.boost(self.policy_model, self.optimizer)
self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
)
self.reference_model, *_ = self.booster.boost(self.reference_model)
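
For intuition, a standalone sketch of the schedule configured above, assuming the usual cosine-annealing formula (CosineAnnealingWarmupLR's internals are not reproduced here): with warmup_steps=0 and eta_min at 10% of the base lr, the learning rate decays smoothly from 1e-6 toward 1e-7 over the scheduled steps.

import math

def cosine_annealed_lr(step, total_steps, base_lr=1e-6, eta_min_ratio=0.1):
    # Standard cosine annealing with no warmup; eta_min = 0.1 * base_lr,
    # mirroring the values passed to CosineAnnealingWarmupLR above.
    eta_min = eta_min_ratio * base_lr
    progress = min(step, total_steps) / total_steps
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * progress))

print(cosine_annealed_lr(0, 400), cosine_annealed_lr(200, 400), cosine_annealed_lr(400, 400))  # start, midpoint, end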
def step(self, step_idx: int, **kwargs) -> Optional[float]:
@ -133,7 +141,10 @@ class GRPOConsumer(BaseConsumer):
attention_mask=data["attention_mask"],
)["logits"]
action_log_probs = calc_action_log_probs(
policy_model_logits / generator_config["temperature"], data["input_ids"], num_action, self.plugin.shard_config
policy_model_logits / self.generate_config["temperature"],
data["input_ids"],
num_action,
self.plugin.shard_config,
)
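
calc_action_log_probs is defined elsewhere in the repo; the sketch below only illustrates the assumed shape of the computation, namely that logits are divided by the sampling temperature before the log-softmax, so the log-probs used for training match the distribution the rollout was actually sampled from (previously this call referenced an undefined generator_config).

import torch
import torch.nn.functional as F

def action_log_probs_from_logits(logits, input_ids, num_action, temperature=1.0):
    # Hedged sketch, not the repo's implementation: scale logits by the sampling
    # temperature, log-softmax over the vocab, then gather the log-prob of each
    # generated token (the last num_action positions of the sequence).
    log_probs = F.log_softmax(logits[:, :-1, :] / temperature, dim=-1)  # predict token t+1 from position t
    labels = input_ids[:, 1:]                                           # shift targets by one
    per_token = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
    return per_token[:, -num_action:]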
with torch.no_grad():
@ -142,7 +153,10 @@ class GRPOConsumer(BaseConsumer):
attention_mask=data["attention_mask"],
)["logits"]
reference_action_log_probs = calc_action_log_probs(
reference_model_logits / generator_config["temperature"], data["input_ids"], num_action, self.plugin.shard_config
reference_model_logits / self.generate_config["temperature"],
data["input_ids"],
num_action,
self.plugin.shard_config,
)
per_token_kl = (
@ -161,22 +175,24 @@ class GRPOConsumer(BaseConsumer):
acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
# [batch_size, num_generations]
group_reward = reward.view(-1, self.num_generations)
reward_mean = group_reward.mean(dim=1)
# filter out groups whose mean reward is too high (all samples get full score) or too low (all samples get 0 score)
loss_mask = (
None
if self.filter_range is None
else torch.logical_and(reward > self.filter_range[0], reward < self.filter_range[1])
else torch.logical_and(
reward_mean > self.filter_range[0], reward_mean < self.filter_range[1]
).repeat_interleave(self.num_generations, dim=0)
)
group_reward = reward.view(-1, self.num_generations)
reward_mean = group_reward.mean(dim=1)
# [batch_size x num_generations]
reward_mean = group_reward.mean(dim=1).repeat_interleave(self.num_generations, dim=0)
reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
# [batch_size x num_generations]
advantages = (reward - reward_mean) / (reward_std + 1e-4)
# Calculate Loss
loss, skip_update, _ = self.policy_loss_fn(
action_log_probs,
old_action_log_probs,
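
Pulling the reward-shaping lines above into one place, a self-contained sketch (illustrative helper, not a function in the repo) of the new group-level filtering: prompts whose mean reward over the group falls outside filter_range contribute nothing to the loss, and advantages are normalized per group.

import torch

def grpo_advantages(reward, num_generations, filter_range=(0.05, 9.0), eps=1e-4):
    group = reward.view(-1, num_generations)                 # [batch_size, num_generations]
    group_mean = group.mean(dim=1)
    # mask out whole groups whose mean reward is extreme (all samples right or all wrong)
    loss_mask = torch.logical_and(group_mean > filter_range[0], group_mean < filter_range[1])
    loss_mask = loss_mask.repeat_interleave(num_generations, dim=0)
    mean = group_mean.repeat_interleave(num_generations, dim=0)
    std = group.std(dim=1).repeat_interleave(num_generations, dim=0)
    advantages = (reward - mean) / (std + eps)               # [batch_size x num_generations]
    return advantages, loss_mask

# e.g. two prompts with 4 generations each; the all-zero group is masked out
adv, mask = grpo_advantages(torch.tensor([0., 0., 0., 0., 1., 0., 1., 1.]), num_generations=4)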

View File

@ -53,7 +53,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
)
FORCE_GENERATE_CONFIG = dict(output_logits=True, return_dict_in_generate=True)
def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
def __init__(
self,
model_config: Dict[str, Any],
generate_config: Dict[str, Any],
tokenizer: PreTrainedTokenizer,
num_generations: int = 8,
):
model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
model_config.update(self.FORCE_MODEL_CONFIG)
path = model_config.pop("path")
@ -61,7 +67,7 @@ class TransformersInferenceBackend(BaseInferenceBackend):
self.generate_config = generate_config.copy()
self.generate_config.update(self.FORCE_GENERATE_CONFIG)
self.tokenizer = tokenizer
self.num_generations = 8
self.num_generations = num_generations
@torch.no_grad()
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
@ -120,7 +126,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
class SGLangInferenceBackend(BaseInferenceBackend):
def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
def __init__(
self,
model_config: Dict[str, Any],
generate_config: Dict[str, Any],
tokenizer: PreTrainedTokenizer,
num_generations: int = 8,
):
if sgl is None:
raise ImportError("sglang is not installed")
path = model_config.pop("path")
@ -175,10 +187,15 @@ class VLLMInferenceBackend(BaseInferenceBackend):
)
FORCE_GENERATE_CONFIG = dict(
logprobs=0,
n=8,
)
def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
def __init__(
self,
model_config: Dict[str, Any],
generate_config: Dict[str, Any],
tokenizer: PreTrainedTokenizer,
num_generations: int = 8,
):
if LLM is None:
raise ImportError("vllm is not installed")
model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
@ -186,9 +203,10 @@ class VLLMInferenceBackend(BaseInferenceBackend):
self.llm = LLM(model=path, **model_config)
generate_config = generate_config.copy()
generate_config.update(self.FORCE_GENERATE_CONFIG)
generate_config.update({"n": num_generations})
self.generate_config = SamplingParams(**generate_config)
self.tokenizer = tokenizer
self.num_generations = self.FORCE_GENERATE_CONFIG["n"]
self.num_generations = num_generations
@torch.no_grad()
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
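
The vLLM-side effect of the new constructor argument, sketched with the example script's sampling values (assumes vllm is installed): n in SamplingParams controls how many completions are drawn per prompt, so threading num_generations through instead of hard-coding n=8 keeps producer rollouts and the consumer's group size in sync.

from vllm import SamplingParams

num_generations = 8
generate_config = dict(top_k=50, top_p=0.75, temperature=0.9, logprobs=0)
generate_config.update({"n": num_generations})   # one group of completions per prompt
sampling_params = SamplingParams(**generate_config)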

View File

@ -42,6 +42,7 @@ def launch_distributed(
plugin_config: Dict[str, Any],
tokenizer_config: Optional[Dict[str, Any]] = None,
inference_backend: str = "transformers",
num_generations: int = 8,
master_addr: str = "localhost",
master_port: int = 29500,
core_algo: str = "GRPO",
@ -76,6 +77,7 @@ def launch_distributed(
tokenizer_config=tokenizer_config,
microbatch_size=inference_microbatch_size,
backend=inference_backend,
num_generations=num_generations,
)
procs.append(producer)
generate_config_consumer = copy.deepcopy(generate_config)
@ -99,7 +101,8 @@ def launch_distributed(
plugin_config=plugin_config,
microbatch_size=train_microbatch_size,
generate_config=generate_config_consumer,
filter_range=[0.05, 9.0],
training_config={"filter_range": [0.05, 9.0], "lr": 1e-6},
num_generations=num_generations,
)
procs.append(consumer)
ray.get([p.setup.remote() for p in procs])
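
A small sketch of how the consumer is expected to read the new training_config dict (the .get defaults mirror the GRPOConsumer changes above; the dict literal is the one passed here):

training_config = {"filter_range": [0.05, 9.0], "lr": 1e-6}

lr = training_config.get("lr", 1e-6)                       # optimizer learning rate
filter_range = training_config.get("filter_range", None)   # keep only groups whose mean reward lies in this band
if filter_range is not None:
    assert len(filter_range) == 2, "Filter range should have 2 values."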

View File

@ -117,6 +117,12 @@ class BaseProducer:
None, self.num_producers, device=self.device, group_name="sync_model"
)
self.load_state_dict(state_dict)
# linear annealing for 1 episode, temperature from initial to 0.7
if episode <= 0:
ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
self.model.generate_config.temperature = (
ratio * self.generate_config["temperature"] + (1 - ratio) * 0.7
)
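
Rewriting the annealing arithmetic as a standalone helper (illustrative, not in the repo), assuming i is the batch index within the episode's dataloader loop: ratio grows from 0 to 1 over the episode, so as written the sampling temperature is interpolated linearly from the 0.7 floor toward the configured initial temperature (0.9 in the example script).

def annealed_temperature(batch_idx, num_batches, initial_temperature=0.9, floor=0.7):
    # ratio = 0 at the first batch of the episode, 1 at the last
    ratio = 1 - (num_batches - batch_idx) / num_batches
    return ratio * initial_temperature + (1 - ratio) * floor

print(annealed_temperature(0, 100), annealed_temperature(100, 100))  # 0.7 at the start, 0.9 at the end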
@ray.remote
@ -135,6 +141,7 @@ class SimpleProducer(BaseProducer):
tokenizer_config=None,
microbatch_size=1,
backend="transformers",
num_generations: int = 8,
):
super().__init__(
producer_idx,
@ -150,7 +157,7 @@ class SimpleProducer(BaseProducer):
microbatch_size,
backend,
)
self.model = self.backend_cls(model_config, generate_config, self.tokenizer)
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
@torch.no_grad()
def rollout(self, input_ids, attention_mask, **kwargs):

View File

@ -22,7 +22,7 @@ if __name__ == "__main__":
inference_model_config = dict(path=args.model)
train_model_config = dict(path=args.model)
generate_config = dict(top_k=50, top_p=0.9, temperature=0.7)
generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)
if args.backend == "transformers":
inference_model_config.update(

View File

@ -387,7 +387,7 @@ def dist_log_prob(
dtype=dtype,
)
else:
log_prob = log_softmax(logits)
log_prob = log_softmax(logits, dim=-1)
log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1))
return log_prob
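
To make the fix concrete, a plain-torch sketch of the same gather pattern (the log_softmax called above may be a distributed/sharded variant; this sketch just uses torch.nn.functional): the log-softmax must normalize over the vocabulary axis, i.e. the last dimension, before gathering each label's log-probability.

import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 10)                 # [batch, seq_len, vocab_size]
labels = torch.randint(0, 10, (2, 5))
log_prob = F.log_softmax(logits, dim=-1)       # normalize over the vocab (last) dimension
log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1))   # [batch, seq_len, 1]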