mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-11 21:01:54 +00:00

fix logprob, add filtering, temperature annealing, lr descent

This commit is contained in:
parent 7ee4452f8c
commit 0472f44163

@@ -57,6 +57,7 @@ class BaseConsumer:
         assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"
 
         self.device = get_current_device()
+        self.lr_scheduler = None
 
     def setup(self) -> None:
         for i in range(self.num_producers):

@@ -121,6 +122,8 @@ class BaseConsumer:
                                 pbar.set_postfix({"loss": loss})
                             i += 1
                     assert len(self.buffer) == 0
+                    if self.lr_scheduler is not None:
+                        self.lr_scheduler.step()
                     if (step + 1) % self.save_interval == 0:
                         if self.rank == 0:
                             print(f"Start saving policy model at step {step + 1}.")

@@ -15,6 +15,7 @@ from coati.distributed.utils import calc_action_log_probs
 from coati.trainer.utils import all_reduce_mean
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
 
 

@@ -34,10 +35,10 @@ class GRPOConsumer(BaseConsumer):
         model_config,
         plugin_config,
         microbatch_size=1,
-        num_generations=4,
+        num_generations=8,
         use_wandb=True,
-        generator_config=None,
-        filter_range=None,
+        generate_config=None,
+        training_config={},
     ):
         super().__init__(
             num_producers,

@@ -57,7 +58,7 @@ class GRPOConsumer(BaseConsumer):
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
         self.policy_model.train()
         self.policy_model.gradient_checkpointing_enable()
-        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-6)
+        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=training_config.get("lr", 1e-6))
         self.accum_loss = torch.zeros(1, device=self.device)
         self.accum_reward = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)

@@ -66,6 +67,7 @@ class GRPOConsumer(BaseConsumer):
         self.accum_advantages = torch.zeros(1, device=self.device)
         self.accum_response_length = torch.zeros(1, device=self.device)
         self.accum_count = 0
+        self.generate_config = generate_config
 
         # Reference model is initialized from policy model.
         self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)

@@ -74,7 +76,7 @@ class GRPOConsumer(BaseConsumer):
         self.tokenizer = AutoTokenizer.from_pretrained(path)
         self.pad_token_id = self.tokenizer.pad_token_id
         self.num_generations = num_generations
-        self.filter_range = filter_range
+        self.filter_range = training_config.get("filter_range", None)
         if self.filter_range is not None:
             assert len(self.filter_range) == 2, "Filter range should have 2 values."
 

@@ -92,15 +94,21 @@ class GRPOConsumer(BaseConsumer):
         self.policy_loss_fn = PolicyLoss()
         self.global_step = 0
         if use_wandb and self.rank == 0:
-            if "repetition_penalty" in generator_config:
-                name = f"{generator_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generator_config['temperature']:.01f}_rep_penalty_{generator_config['repetition_penalty']:.01f}"
-            else:
-                name = f"{generator_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generator_config['temperature']:.01f}"
+            name = f"{generate_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
             self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True, dir="./wandb", name=name)
 
+        self.lr_scheduler = CosineAnnealingWarmupLR(
+            optimizer=self.optimizer,
+            total_steps=min(self.num_episodes, 4) * self.num_update_per_episode,
+            warmup_steps=0,
+            eta_min=0.1 * training_config.get("lr", 1e-6),
+        )
+
     def setup(self):
         super().setup()
-        self.policy_model, self.optimizer, *_ = self.booster.boost(self.policy_model, self.optimizer)
+        self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
+            self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
+        )
         self.reference_model, *_ = self.booster.boost(self.reference_model)
 
     def step(self, step_idx: int, **kwargs) -> Optional[float]:

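The new scheduler decays the learning rate from the configured lr down to eta_min = 0.1 * lr over min(num_episodes, 4) * num_update_per_episode optimizer steps with no warmup; setup() passes it through booster.boost(), and BaseConsumer steps it once per update step. A rough standalone sketch of the intended curve, assuming that with warmup_steps=0 CosineAnnealingWarmupLR reduces to plain cosine annealing (the step counts below are made up for illustration):

import math

def cosine_lr(step: int, total_steps: int, lr: float = 1e-6, eta_min_ratio: float = 0.1) -> float:
    """Approximate the lr-descent curve configured in this commit:
    cosine decay from `lr` to `0.1 * lr` over `total_steps`, no warmup."""
    eta_min = eta_min_ratio * lr
    t = min(step, total_steps)
    return eta_min + 0.5 * (lr - eta_min) * (1 + math.cos(math.pi * t / total_steps))

# e.g. 4 episodes x 25 updates per episode -> 100 scheduler steps (illustrative numbers)
total = 4 * 25
for s in (0, 25, 50, 75, 100):
    print(s, f"{cosine_lr(s, total):.2e}")
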
@@ -133,7 +141,10 @@ class GRPOConsumer(BaseConsumer):
                 attention_mask=data["attention_mask"],
             )["logits"]
             action_log_probs = calc_action_log_probs(
-                policy_model_logits / generator_config["temperature"], data["input_ids"], num_action, self.plugin.shard_config
+                policy_model_logits / self.generate_config["temperature"],
+                data["input_ids"],
+                num_action,
+                self.plugin.shard_config,
             )
 
             with torch.no_grad():

@@ -142,7 +153,10 @@ class GRPOConsumer(BaseConsumer):
                     attention_mask=data["attention_mask"],
                 )["logits"]
                 reference_action_log_probs = calc_action_log_probs(
-                    reference_model_logits / generator_config["temperature"], data["input_ids"], num_action, self.plugin.shard_config
+                    reference_model_logits / self.generate_config["temperature"],
+                    data["input_ids"],
+                    num_action,
+                    self.plugin.shard_config,
                 )
 
             per_token_kl = (

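Both the policy and reference forward passes now divide their logits by the sampling temperature before calc_action_log_probs, so the log-probs used for the ratio and KL terms are taken under the same temperature-scaled distribution the producer sampled from. A minimal sketch of that scaling, with dummy tensor names and shapes rather than the repository's own code:

import torch
import torch.nn.functional as F

temperature = 0.9                       # matches generate_config["temperature"] in this commit
logits = torch.randn(2, 5, 32)          # (batch, seq_len, vocab) -- dummy values
labels = torch.randint(0, 32, (2, 5))   # sampled token ids

# log-prob of each sampled token under the temperature-scaled distribution
log_probs = F.log_softmax(logits / temperature, dim=-1)
action_log_probs = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
print(action_log_probs.shape)  # torch.Size([2, 5])
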
@@ -161,22 +175,24 @@ class GRPOConsumer(BaseConsumer):
             acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
 
             # [batch_size, num_generations]
+            group_reward = reward.view(-1, self.num_generations)
+            reward_mean = group_reward.mean(dim=1)
+
             # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
             loss_mask = (
                 None
                 if self.filter_range is None
-                else torch.logical_and(reward > self.filter_range[0], reward < self.filter_range[1])
+                else torch.logical_and(
+                    reward_mean > self.filter_range[0], reward_mean < self.filter_range[1]
+                ).repeat_interleave(self.num_generations, dim=0)
             )
-            group_reward = reward.view(-1, self.num_generations)
-            reward_mean = group_reward.mean(dim=1)
 
             # [batch_size x num_generations]
-            reward_mean = group_reward.mean(dim=1).repeat_interleave(self.num_generations, dim=0)
+            reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
             reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
             # [batch_size x num_generations]
             advantages = (reward - reward_mean) / (reward_std + 1e-4)
 
-            # Calculate Loss
             loss, skip_update, _ = self.policy_loss_fn(
                 action_log_probs,
                 old_action_log_probs,

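Filtering now looks at the per-prompt group mean rather than each sample's reward: a prompt is kept only when the mean reward of its num_generations completions falls inside filter_range, and the boolean mask is repeated back to per-sample shape before it reaches the loss. A toy sketch with made-up rewards (two prompts, four generations each):

import torch

num_generations = 4
filter_range = [0.05, 9.0]
reward = torch.tensor([0.0, 0.0, 0.0, 0.0, 2.0, 6.0, 10.0, 10.0])  # dummy rewards for 2 prompts

group_reward = reward.view(-1, num_generations)            # [num_prompts, num_generations]
reward_mean = group_reward.mean(dim=1)                      # [num_prompts]

# drop prompts whose group mean is outside the range (e.g. an all-zero group carries no signal)
loss_mask = torch.logical_and(
    reward_mean > filter_range[0], reward_mean < filter_range[1]
).repeat_interleave(num_generations, dim=0)                 # back to [num_prompts * num_generations]

mean = reward_mean.repeat_interleave(num_generations, dim=0)
std = group_reward.std(dim=1).repeat_interleave(num_generations, dim=0)
advantages = (reward - mean) / (std + 1e-4)

print(loss_mask)     # first prompt (mean 0.0) is filtered out, second prompt is kept
print(advantages)
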
@@ -53,7 +53,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
     )
     FORCE_GENERATE_CONFIG = dict(output_logits=True, return_dict_in_generate=True)
 
-    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+    def __init__(
+        self,
+        model_config: Dict[str, Any],
+        generate_config: Dict[str, Any],
+        tokenizer: PreTrainedTokenizer,
+        num_generations: int = 8,
+    ):
         model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
         model_config.update(self.FORCE_MODEL_CONFIG)
         path = model_config.pop("path")

@@ -61,7 +67,7 @@ class TransformersInferenceBackend(BaseInferenceBackend):
         self.generate_config = generate_config.copy()
         self.generate_config.update(self.FORCE_GENERATE_CONFIG)
         self.tokenizer = tokenizer
-        self.num_generations = 8
+        self.num_generations = num_generations
 
     @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:

@@ -120,7 +126,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
 
 
 class SGLangInferenceBackend(BaseInferenceBackend):
-    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+    def __init__(
+        self,
+        model_config: Dict[str, Any],
+        generate_config: Dict[str, Any],
+        tokenizer: PreTrainedTokenizer,
+        num_generations: int = 8,
+    ):
         if sgl is None:
             raise ImportError("sglang is not installed")
         path = model_config.pop("path")

@@ -175,10 +187,15 @@ class VLLMInferenceBackend(BaseInferenceBackend):
     )
     FORCE_GENERATE_CONFIG = dict(
         logprobs=0,
-        n=8,
     )
 
-    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+    def __init__(
+        self,
+        model_config: Dict[str, Any],
+        generate_config: Dict[str, Any],
+        tokenizer: PreTrainedTokenizer,
+        num_generations: int = 8,
+    ):
         if LLM is None:
             raise ImportError("vllm is not installed")
         model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)

@@ -186,9 +203,10 @@ class VLLMInferenceBackend(BaseInferenceBackend):
         self.llm = LLM(model=path, **model_config)
         generate_config = generate_config.copy()
         generate_config.update(self.FORCE_GENERATE_CONFIG)
+        generate_config.update({"n": num_generations})
         self.generate_config = SamplingParams(**generate_config)
         self.tokenizer = tokenizer
-        self.num_generations = self.FORCE_GENERATE_CONFIG["n"]
+        self.num_generations = num_generations
 
     @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:

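With n removed from FORCE_GENERATE_CONFIG, the vLLM backend now injects whatever num_generations value launch_distributed passes down, keeping the sampling group size and the GRPO group size in sync. A hedged sketch of how the merged SamplingParams would be built (assumes vllm is installed; the concrete values mirror the example script's generate_config after this commit and are otherwise illustrative):

from vllm import SamplingParams

num_generations = 8
generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)  # example script defaults after this commit
generate_config.update(dict(logprobs=0))                        # FORCE_GENERATE_CONFIG
generate_config.update({"n": num_generations})                  # one completion group per prompt

sampling_params = SamplingParams(**generate_config)
print(sampling_params.n, sampling_params.temperature)
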
@@ -42,6 +42,7 @@ def launch_distributed(
     plugin_config: Dict[str, Any],
     tokenizer_config: Optional[Dict[str, Any]] = None,
     inference_backend: str = "transformers",
+    num_generations: int = 8,
     master_addr: str = "localhost",
     master_port: int = 29500,
     core_algo: str = "GRPO",

@@ -76,6 +77,7 @@ def launch_distributed(
             tokenizer_config=tokenizer_config,
             microbatch_size=inference_microbatch_size,
             backend=inference_backend,
+            num_generations=num_generations,
         )
         procs.append(producer)
     generate_config_consumer = copy.deepcopy(generate_config)

@@ -99,7 +101,8 @@ def launch_distributed(
             plugin_config=plugin_config,
             microbatch_size=train_microbatch_size,
             generate_config=generate_config_consumer,
-            filter_range=[0.05, 9.0],
+            training_config={"filter_range": [0.05, 9.0], "lr": 1e-6},
+            num_generations=num_generations,
         )
         procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])

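The consumer's tuning knobs are now bundled into a single training_config dict instead of separate keyword arguments; GRPOConsumer reads them with .get() and falls back to defaults, so an empty dict remains a valid value. A tiny sketch of that lookup pattern, mirroring the lines in this diff:

training_config = {"filter_range": [0.05, 9.0], "lr": 1e-6}

lr = training_config.get("lr", 1e-6)
filter_range = training_config.get("filter_range", None)
if filter_range is not None:
    assert len(filter_range) == 2, "Filter range should have 2 values."
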
@@ -117,6 +117,12 @@ class BaseProducer:
                     None, self.num_producers, device=self.device, group_name="sync_model"
                 )
                 self.load_state_dict(state_dict)
+                # linear annealing for 1 episode, temperature from initial to 0.7
+                if episode <= 0:
+                    ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
+                    self.model.generate_config.temperature = (
+                        ratio * self.generate_config["temperature"] + (1 - ratio) * 0.7
+                    )
 
 
 @ray.remote

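During the first episode the producer linearly blends the configured sampling temperature with 0.7, with the mixing ratio tied to the position in the dataloader; later episodes leave the configured temperature untouched. A standalone sketch of the ramp as written, using a made-up dataloader length:

initial_temperature = 0.9   # generate_config["temperature"] in this commit
dataloader_len = 10         # illustrative

for i in range(dataloader_len):
    ratio = 1 - (dataloader_len - i) / dataloader_len          # equals i / dataloader_len
    temperature = ratio * initial_temperature + (1 - ratio) * 0.7
    print(i, round(temperature, 3))                             # 0.7 at i=0, approaching 0.9 by episode end
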
@@ -135,6 +141,7 @@ class SimpleProducer(BaseProducer):
         tokenizer_config=None,
         microbatch_size=1,
         backend="transformers",
+        num_generations: int = 8,
     ):
         super().__init__(
             producer_idx,

@@ -150,7 +157,7 @@ class SimpleProducer(BaseProducer):
             microbatch_size,
             backend,
         )
-        self.model = self.backend_cls(model_config, generate_config, self.tokenizer)
+        self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
 
     @torch.no_grad()
     def rollout(self, input_ids, attention_mask, **kwargs):

@@ -22,7 +22,7 @@ if __name__ == "__main__":
 
     inference_model_config = dict(path=args.model)
     train_model_config = dict(path=args.model)
-    generate_config = dict(top_k=50, top_p=0.9, temperature=0.7)
+    generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)
 
     if args.backend == "transformers":
         inference_model_config.update(

@@ -387,7 +387,7 @@ def dist_log_prob(
             dtype=dtype,
         )
     else:
-        log_prob = log_softmax(logits)
+        log_prob = log_softmax(logits, dim=-1)
         log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1))
 
     return log_prob

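dist_log_prob now pins log_softmax to the last (vocabulary) axis; relying on the implicit dimension is deprecated in PyTorch and, for multi-dimensional logits, can normalize over the wrong axis before the gather. A small sketch of the intended computation with dummy shapes:

import torch
from torch.nn.functional import log_softmax

logits = torch.randn(2, 4, 6)          # (batch, seq_len, vocab) -- dummy shape for illustration
labels = torch.randint(0, 6, (2, 4))

# normalize over the vocabulary axis, then pick out each label's log-prob
log_prob = log_softmax(logits, dim=-1)
log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1))

# sanity check: probabilities over the vocab axis sum to 1
assert torch.allclose(log_softmax(logits, dim=-1).exp().sum(-1), torch.ones(2, 4))
print(log_prob.shape)  # torch.Size([2, 4, 1])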