fix logprob, add filtering, temperature annealing, lr descent

YeAnbang 2025-03-21 10:24:24 +08:00
parent 7ee4452f8c
commit 0472f44163
7 changed files with 74 additions and 27 deletions

View File

@@ -57,6 +57,7 @@ class BaseConsumer:
         assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"
         self.device = get_current_device()
+        self.lr_scheduler = None

     def setup(self) -> None:
         for i in range(self.num_producers):
@@ -121,6 +122,8 @@ class BaseConsumer:
                                 pbar.set_postfix({"loss": loss})
                             i += 1
                     assert len(self.buffer) == 0
+                    if self.lr_scheduler is not None:
+                        self.lr_scheduler.step()
                     if (step + 1) % self.save_interval == 0:
                         if self.rank == 0:
                             print(f"Start saving policy model at step {step + 1}.")

View File

@@ -15,6 +15,7 @@ from coati.distributed.utils import calc_action_log_probs
 from coati.trainer.utils import all_reduce_mean
 from transformers import AutoModelForCausalLM, AutoTokenizer

+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
@@ -34,10 +35,10 @@ class GRPOConsumer(BaseConsumer):
         model_config,
         plugin_config,
         microbatch_size=1,
-        num_generations=4,
+        num_generations=8,
         use_wandb=True,
-        generator_config=None,
-        filter_range=None,
+        generate_config=None,
+        training_config={},
     ):
         super().__init__(
             num_producers,
@@ -57,7 +58,7 @@ class GRPOConsumer(BaseConsumer):
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
         self.policy_model.train()
         self.policy_model.gradient_checkpointing_enable()
-        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-6)
+        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=training_config.get("lr", 1e-6))
         self.accum_loss = torch.zeros(1, device=self.device)
         self.accum_reward = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
@@ -66,6 +67,7 @@ class GRPOConsumer(BaseConsumer):
         self.accum_advantages = torch.zeros(1, device=self.device)
         self.accum_response_length = torch.zeros(1, device=self.device)
         self.accum_count = 0
+        self.generate_config = generate_config

         # Reference model is initialized from policy model.
         self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -74,7 +76,7 @@ class GRPOConsumer(BaseConsumer):
         self.tokenizer = AutoTokenizer.from_pretrained(path)
         self.pad_token_id = self.tokenizer.pad_token_id
         self.num_generations = num_generations
-        self.filter_range = filter_range
+        self.filter_range = training_config.get("filter_range", None)
         if self.filter_range is not None:
             assert len(self.filter_range) == 2, "Filter range should have 2 values."
@@ -92,15 +94,21 @@ class GRPOConsumer(BaseConsumer):
         self.policy_loss_fn = PolicyLoss()
         self.global_step = 0
         if use_wandb and self.rank == 0:
-            if "repetition_penalty" in generator_config:
-                name = f"{generator_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generator_config['temperature']:.01f}_rep_penalty_{generator_config['repetition_penalty']:.01f}"
-            else:
-                name = f"{generator_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generator_config['temperature']:.01f}"
+            name = f"{generate_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
             self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True, dir="./wandb", name=name)

+        self.lr_scheduler = CosineAnnealingWarmupLR(
+            optimizer=self.optimizer,
+            total_steps=min(self.num_episodes, 4) * self.num_update_per_episode,
+            warmup_steps=0,
+            eta_min=0.1 * training_config.get("lr", 1e-6),
+        )
+
     def setup(self):
         super().setup()
-        self.policy_model, self.optimizer, *_ = self.booster.boost(self.policy_model, self.optimizer)
+        self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
+            self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
+        )
         self.reference_model, *_ = self.booster.boost(self.reference_model)

     def step(self, step_idx: int, **kwargs) -> Optional[float]:
@@ -133,7 +141,10 @@ class GRPOConsumer(BaseConsumer):
                 attention_mask=data["attention_mask"],
             )["logits"]
             action_log_probs = calc_action_log_probs(
-                policy_model_logits / generator_config["temperature"], data["input_ids"], num_action, self.plugin.shard_config
+                policy_model_logits / self.generate_config["temperature"],
+                data["input_ids"],
+                num_action,
+                self.plugin.shard_config,
             )

             with torch.no_grad():
@@ -142,7 +153,10 @@ class GRPOConsumer(BaseConsumer):
                     attention_mask=data["attention_mask"],
                 )["logits"]
                 reference_action_log_probs = calc_action_log_probs(
-                    reference_model_logits / generator_config["temperature"], data["input_ids"], num_action, self.plugin.shard_config
+                    reference_model_logits / self.generate_config["temperature"],
+                    data["input_ids"],
+                    num_action,
+                    self.plugin.shard_config,
                 )

             per_token_kl = (
@@ -161,22 +175,24 @@ class GRPOConsumer(BaseConsumer):
             acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)

             # [batch_size, num_generations]
+            group_reward = reward.view(-1, self.num_generations)
+            reward_mean = group_reward.mean(dim=1)
             # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
             loss_mask = (
                 None
                 if self.filter_range is None
-                else torch.logical_and(reward > self.filter_range[0], reward < self.filter_range[1])
+                else torch.logical_and(
+                    reward_mean > self.filter_range[0], reward_mean < self.filter_range[1]
+                ).repeat_interleave(self.num_generations, dim=0)
             )
-            group_reward = reward.view(-1, self.num_generations)
-            reward_mean = group_reward.mean(dim=1)
             # [batch_size x num_generations]
-            reward_mean = group_reward.mean(dim=1).repeat_interleave(self.num_generations, dim=0)
+            reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
             reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
             # [batch_size x num_generations]
             advantages = (reward - reward_mean) / (reward_std + 1e-4)
+            # Calculate Loss
             loss, skip_update, _ = self.policy_loss_fn(
                 action_log_probs,
                 old_action_log_probs,
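
As a standalone illustration of the new group-level filtering and the advantage computation above, here is a minimal plain-PyTorch sketch; the batch shape and reward values are made up, while the mask logic, the `repeat_interleave` broadcasting, and the `1e-4` epsilon follow the diff:

```python
import torch

def grpo_advantages(reward: torch.Tensor, num_generations: int, filter_range=None):
    """reward: flat tensor of shape [batch_size * num_generations]."""
    # [batch_size, num_generations]
    group_reward = reward.view(-1, num_generations)
    reward_mean = group_reward.mean(dim=1)

    # Mask out whole prompts whose mean reward falls outside (low, high),
    # i.e. groups where every sample got full score or every sample got 0.
    loss_mask = (
        None
        if filter_range is None
        else torch.logical_and(
            reward_mean > filter_range[0], reward_mean < filter_range[1]
        ).repeat_interleave(num_generations, dim=0)
    )

    # Broadcast group statistics back to [batch_size * num_generations].
    reward_mean = reward_mean.repeat_interleave(num_generations, dim=0)
    reward_std = group_reward.std(dim=1).repeat_interleave(num_generations, dim=0)
    advantages = (reward - reward_mean) / (reward_std + 1e-4)
    return advantages, loss_mask

# Example: 2 prompts x 4 generations, filter_range as in the launch script below.
reward = torch.tensor([0.0, 0.0, 0.0, 0.0, 1.0, 3.0, 5.0, 7.0])
adv, mask = grpo_advantages(reward, num_generations=4, filter_range=[0.05, 9.0])
print(adv)   # zero-mean within each group
print(mask)  # the all-zero group is filtered out; the second group is kept
```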

View File

@@ -53,7 +53,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
     )
     FORCE_GENERATE_CONFIG = dict(output_logits=True, return_dict_in_generate=True)

-    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+    def __init__(
+        self,
+        model_config: Dict[str, Any],
+        generate_config: Dict[str, Any],
+        tokenizer: PreTrainedTokenizer,
+        num_generations: int = 8,
+    ):
         model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
         model_config.update(self.FORCE_MODEL_CONFIG)
         path = model_config.pop("path")
@@ -61,7 +67,7 @@ class TransformersInferenceBackend(BaseInferenceBackend):
         self.generate_config = generate_config.copy()
         self.generate_config.update(self.FORCE_GENERATE_CONFIG)
         self.tokenizer = tokenizer
-        self.num_generations = 8
+        self.num_generations = num_generations

     @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
@@ -120,7 +126,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
 class SGLangInferenceBackend(BaseInferenceBackend):
-    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+    def __init__(
+        self,
+        model_config: Dict[str, Any],
+        generate_config: Dict[str, Any],
+        tokenizer: PreTrainedTokenizer,
+        num_generations: int = 8,
+    ):
         if sgl is None:
             raise ImportError("sglang is not installed")
         path = model_config.pop("path")
@@ -175,10 +187,15 @@ class VLLMInferenceBackend(BaseInferenceBackend):
     )
     FORCE_GENERATE_CONFIG = dict(
         logprobs=0,
-        n=8,
     )

-    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+    def __init__(
+        self,
+        model_config: Dict[str, Any],
+        generate_config: Dict[str, Any],
+        tokenizer: PreTrainedTokenizer,
+        num_generations: int = 8,
+    ):
         if LLM is None:
             raise ImportError("vllm is not installed")
         model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
@@ -186,9 +203,10 @@ class VLLMInferenceBackend(BaseInferenceBackend):
         self.llm = LLM(model=path, **model_config)
         generate_config = generate_config.copy()
         generate_config.update(self.FORCE_GENERATE_CONFIG)
+        generate_config.update({"n": num_generations})
         self.generate_config = SamplingParams(**generate_config)
         self.tokenizer = tokenizer
-        self.num_generations = self.FORCE_GENERATE_CONFIG["n"]
+        self.num_generations = num_generations

     @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
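
For context, a hedged sketch of what the vLLM path now amounts to: the number of completions per prompt comes from `num_generations` instead of the hard-coded `n=8` in `FORCE_GENERATE_CONFIG`. The model path and sampling values below are placeholders, not taken from the repository.

```python
from vllm import LLM, SamplingParams

num_generations = 8  # now threaded in from launch_distributed

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")  # placeholder model path
sampling_params = SamplingParams(
    n=num_generations,  # one prompt -> num_generations sampled completions
    logprobs=0,         # return the log-prob of each sampled token
    temperature=0.9,
    top_p=0.75,
    top_k=50,
)
outputs = llm.generate(["1 + 1 = "], sampling_params)
assert len(outputs[0].outputs) == num_generations
```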

View File

@@ -42,6 +42,7 @@ def launch_distributed(
     plugin_config: Dict[str, Any],
     tokenizer_config: Optional[Dict[str, Any]] = None,
     inference_backend: str = "transformers",
+    num_generations: int = 8,
     master_addr: str = "localhost",
     master_port: int = 29500,
     core_algo: str = "GRPO",
@@ -76,6 +77,7 @@ def launch_distributed(
             tokenizer_config=tokenizer_config,
             microbatch_size=inference_microbatch_size,
             backend=inference_backend,
+            num_generations=num_generations,
         )
         procs.append(producer)
     generate_config_consumer = copy.deepcopy(generate_config)
@@ -99,7 +101,8 @@ def launch_distributed(
             plugin_config=plugin_config,
             microbatch_size=train_microbatch_size,
             generate_config=generate_config_consumer,
-            filter_range=[0.05, 9.0],
+            training_config={"filter_range": [0.05, 9.0], "lr": 1e-6},
+            num_generations=num_generations,
         )
         procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])
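
The consumer-side hyperparameters now travel in a single `training_config` dict instead of separate keyword arguments. A minimal sketch of the contract as used in this commit (only `filter_range` and `lr` are defined here; other keys are not implied):

```python
# What launch_distributed passes to GRPOConsumer in this commit:
training_config = {"filter_range": [0.05, 9.0], "lr": 1e-6}

# How the consumer reads it back, falling back to defaults for missing keys:
lr = training_config.get("lr", 1e-6)
filter_range = training_config.get("filter_range", None)  # [low, high] bounds on group mean reward
if filter_range is not None:
    assert len(filter_range) == 2, "Filter range should have 2 values."
```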

View File

@@ -117,6 +117,12 @@ class BaseProducer:
                         None, self.num_producers, device=self.device, group_name="sync_model"
                     )
                     self.load_state_dict(state_dict)
+                    # linear annealing for 1 episode, temperature from initial to 0.7
+                    if episode <= 0:
+                        ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
+                        self.model.generate_config.temperature = (
+                            ratio * self.generate_config["temperature"] + (1 - ratio) * 0.7
+                        )


 @ray.remote
@@ -135,6 +141,7 @@ class SimpleProducer(BaseProducer):
         tokenizer_config=None,
         microbatch_size=1,
         backend="transformers",
+        num_generations: int = 8,
     ):
         super().__init__(
             producer_idx,
@@ -150,7 +157,7 @@ class SimpleProducer(BaseProducer):
             microbatch_size,
             backend,
         )
-        self.model = self.backend_cls(model_config, generate_config, self.tokenizer)
+        self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)

     @torch.no_grad()
     def rollout(self, input_ids, attention_mask, **kwargs):
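
A small sketch of the producer-side temperature schedule added above, assuming `i` indexes batches within the episode; the batch count and the 0.9 starting temperature (from the example script) are illustrative values:

```python
def annealed_temperature(i: int, num_batches: int, initial_temperature: float) -> float:
    """Linear interpolation applied during the first episode only (episode <= 0)."""
    ratio = 1 - (num_batches - i) / num_batches  # equals i / num_batches
    return ratio * initial_temperature + (1 - ratio) * 0.7

# With temperature=0.9 and 100 batches per episode:
for i in (0, 50, 99):
    print(i, round(annealed_temperature(i, 100, 0.9), 3))
# 0 -> 0.7, 50 -> 0.8, 99 -> 0.898; later episodes keep the configured temperature.
```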

View File

@@ -22,7 +22,7 @@ if __name__ == "__main__":
     inference_model_config = dict(path=args.model)
     train_model_config = dict(path=args.model)
-    generate_config = dict(top_k=50, top_p=0.9, temperature=0.7)
+    generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)
     if args.backend == "transformers":
         inference_model_config.update(

View File

@@ -387,7 +387,7 @@ def dist_log_prob(
             dtype=dtype,
         )
     else:
-        log_prob = log_softmax(logits)
+        log_prob = log_softmax(logits, dim=-1)
         log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1))

     return log_prob
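
The `dim=-1` fix matters because the log-softmax must normalize over the vocabulary axis before the label token's log-probability is gathered. A minimal sketch with dummy shapes (using `torch.nn.functional.log_softmax` in place of the module's `log_softmax` helper):

```python
import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

# Normalize over the vocab dimension, then pick out each label's log-prob.
log_prob = F.log_softmax(logits, dim=-1)
log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1))  # [batch, seq_len, 1]

# Sanity check: probabilities along the vocab axis sum to 1 for every position.
assert torch.allclose(F.log_softmax(logits, dim=-1).exp().sum(-1), torch.ones(batch, seq_len))
```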