[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pre-commit-ci[bot] 2025-03-07 10:43:01 +00:00
parent c8e13a9403
commit 6e096362ef
8 changed files with 27 additions and 36 deletions

View File

@@ -13,8 +13,8 @@ from colossalai.booster import Booster
 from colossalai.booster.plugin import HybridParallelPlugin
 from colossalai.initialize import launch
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
 from colossalai.shardformer.policies.auto_policy import get_autopolicy
+from colossalai.utils import get_current_device
 
 from .comm import ray_broadcast_tensor_dict
 from .utils import bind_batch, post_recv, unbind_batch

View File

@@ -90,13 +90,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
         attention_mask = attention_mask.repeat_interleave(action_mask.size(0) // attention_mask.size(0), dim=0)
         attention_mask = torch.cat((attention_mask, action_mask), dim=1)
         data = {
             "input_ids": out.sequences,
             "attention_mask": attention_mask,
             "action_log_probs": action_log_probs,
             "action_mask": action_mask,
-            "response_idx": response_idx
+            "response_idx": response_idx,
         }
         return data
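
The only change in this hunk is a trailing comma, but the surrounding logic (expanding a per-prompt attention_mask to match action_mask when several responses are sampled per prompt, then concatenating the two along the sequence dimension) is easy to check with toy tensors. A minimal sketch; the shapes (2 prompts, 2 generations each, prompt length 3, response length 4) are made up for illustration:

import torch

# 2 prompts, prompt length 3; 1 = real token, 0 = padding
attention_mask = torch.tensor([[1, 1, 1],
                               [0, 1, 1]])
# 2 generations per prompt -> 4 responses of length 4
action_mask = torch.ones(4, 4, dtype=torch.long)

# Repeat each prompt mask once per generation so the batch sizes line up (2 -> 4)
attention_mask = attention_mask.repeat_interleave(action_mask.size(0) // attention_mask.size(0), dim=0)
# Full mask over prompt + response tokens, shape (4, 3 + 4)
full_mask = torch.cat((attention_mask, action_mask), dim=1)
print(full_mask.shape)  # torch.Size([4, 7])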

View File

@@ -7,11 +7,7 @@ from .grpo_consumer import GRPOConsumer
 from .ppo_consumer import PPOConsumer
 from .producer import SimpleProducer
 
-ALGO_MAP = {
-    "Simple": SimpleConsumer,
-    "GRPO": GRPOConsumer,
-    "PPO": PPOConsumer
-}
+ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "PPO": PPOConsumer}
 
 
 def get_jsonl_size_fast(path: str) -> int:
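
The collapsed ALGO_MAP above is a plain name-to-class registry. A tiny sketch of how such a registry is usually consumed; the helper name, the stand-in classes, and the error message are illustrative, not part of this repository:

# Stand-in registry mirroring the keys shown in the hunk above
ALGO_MAP = {"Simple": object, "GRPO": object, "PPO": object}

def get_algo_cls(name: str):
    # Fail early with the list of supported algorithms
    if name not in ALGO_MAP:
        raise ValueError(f"Unknown algorithm {name!r}; expected one of {sorted(ALGO_MAP)}")
    return ALGO_MAP[name]

print(get_algo_cls("GRPO"))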

View File

@@ -72,4 +72,4 @@ class ValueLoss(nn.Module):
             loss = torch.mean(masked_mean(torch.max(surr1, surr2), action_mask))
         else:
             loss = torch.mean(torch.max(surr1, surr2))
-        return 0.5 * loss
+        return 0.5 * loss
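
Only the last line of ValueLoss changes here, but the 0.5 * mean(max(surr1, surr2)) shape is the standard clipped value loss from PPO. A minimal sketch of what surr1/surr2 usually are in this formulation; the clip range and the unmasked reduction are assumptions, and the masked branch from the diff is omitted for brevity:

import torch

def clipped_value_loss(values, old_values, returns, clip_eps=0.2):
    # surr2 clips the value update to stay within clip_eps of the old values
    values_clipped = old_values + (values - old_values).clamp(-clip_eps, clip_eps)
    surr1 = (values - returns) ** 2
    surr2 = (values_clipped - returns) ** 2
    # Pessimistic element-wise max, halved as in the diff's `return 0.5 * loss`
    return 0.5 * torch.max(surr1, surr2).mean()

# Toy usage
v, v_old, ret = torch.randn(4, 8), torch.randn(4, 8), torch.randn(4, 8)
print(clipped_value_loss(v, v_old, ret))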

View File

@@ -9,8 +9,8 @@ from coati.distributed.loss import PolicyLoss, ValueLoss
 from coati.distributed.reward.reward_fn import math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs, compute_reward_ppo
-from coati.trainer.utils import all_reduce_mean
 from coati.models import Critic, disable_dropout
+from coati.trainer.utils import all_reduce_mean
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from colossalai.nn.optimizer import HybridAdam
@@ -33,9 +33,9 @@ class PPOConsumer(BaseConsumer):
         plugin_config,
         microbatch_size=1,
         num_generations=1,
-        gamma:float=1.0,
-        lam:float=0.95,
-        kl_coef:float=0.05,
+        gamma: float = 1.0,
+        lam: float = 0.95,
+        kl_coef: float = 0.05,
         use_wandb=True,
     ):
         super().__init__(
@@ -55,7 +55,7 @@ class PPOConsumer(BaseConsumer):
         self.gamma = gamma
         self.lam = lam
         self.kl_coef = kl_coef
-
+
         path = model_config.pop("path")
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
         self.policy_model.train()
@@ -63,14 +63,14 @@ class PPOConsumer(BaseConsumer):
         self.critic_model = Critic(path, **model_config)
         self.critic_model.model.gradient_checkpointing_enable()
         self.critic_model.train()
-
+
         # Disable dropout
         disable_dropout(self.policy_model)
         disable_dropout(self.critic_model)
-
+
         self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-6)
         self.critic_optimizer = HybridAdam(self.critic_model.parameters(), lr=1e-6)
-
+
         self.accum_loss = torch.zeros(1, device=self.device)
         self.accum_reward = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
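
This hunk only touches blank lines, but the disable_dropout(...) calls reflect the common practice of turning dropout off for RL fine-tuning so log-probabilities stay deterministic between rollout and update. An illustrative stand-in for what such a helper typically does (a sketch, not coati's actual implementation):

import torch.nn as nn

def disable_dropout_sketch(model: nn.Module) -> None:
    # Zero the dropout probability of every Dropout module in the model
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.p = 0.0
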
@@ -152,15 +152,13 @@ class PPOConsumer(BaseConsumer):
                 input_ids=data["input_ids"],
                 attention_mask=data["attention_mask"],
             )
-            value = value[:, -num_action -1: -1] * action_mask
+            value = value[:, -num_action - 1 : -1] * action_mask
 
-            r = self.reward_model(
-                data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
-            )
+            r = self.reward_model(data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"])
             reward, kl = compute_reward_ppo(
                 r, self.kl_coef, old_action_log_probs, reference_action_log_probs, action_mask=action_mask
             )
 
             # Calculate advantages
             # reference: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/ppo_trainer.py#L514C17-L523C46
             lastgaelam = 0
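
The reformatted slice value[:, -num_action - 1 : -1] keeps, for each of the num_action generated tokens, the critic value computed from the prefix that ends just before that token, which is the usual state-value alignment in token-level PPO. A toy sketch with made-up lengths:

import torch

seq_len, num_action = 7, 4          # 3 prompt tokens + 4 generated tokens (illustrative)
value = torch.arange(seq_len, dtype=torch.float).unsqueeze(0)  # fake per-position values 0..6

# Indices -num_action-1 .. -2 are the states *before* each of the 4 generated tokens
sliced = value[:, -num_action - 1 : -1]
print(sliced)  # tensor([[2., 3., 4., 5.]]) -> one value per action
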
@@ -172,7 +170,7 @@ class PPOConsumer(BaseConsumer):
                 advantage_reversed.append(lastgaelam)
             advantage = torch.stack(advantage_reversed[::-1], axis=1) * action_mask
             advantage = advantage.detach()
-
+
             # KL divergence for logging
             per_token_kl = (
                 torch.exp(reference_action_log_probs - action_log_probs)
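
The fragments above (lastgaelam, advantage_reversed, torch.stack(advantage_reversed[::-1], axis=1)) are the tail of the standard GAE recursion, per the TRL reference cited a few lines up. A minimal sketch of that loop under the usual conventions; the function name, tensor shapes, and the zero bootstrap value after the last token are assumptions, not read from the diff:

import torch

def gae_advantages(reward, value, action_mask, gamma=1.0, lam=0.95):
    # reward, value, action_mask: (batch, num_action); advantages are computed right-to-left
    num_action = reward.size(1)
    lastgaelam = 0
    advantage_reversed = []
    for t in reversed(range(num_action)):
        nextvalue = value[:, t + 1] if t < num_action - 1 else 0.0  # bootstrap 0 after the last token
        delta = reward[:, t] + gamma * nextvalue - value[:, t]
        lastgaelam = delta + gamma * lam * lastgaelam
        advantage_reversed.append(lastgaelam)
    advantage = torch.stack(advantage_reversed[::-1], dim=1) * action_mask
    return advantage.detach()

# Toy usage: terminal reward only, zero values
r = torch.zeros(2, 4); r[:, -1] = 1.0
print(gae_advantages(r, torch.zeros(2, 4), torch.ones(2, 4)))

Critic targets are then commonly formed as returns = advantages + values, as in the TRL code the comment points to; that step is outside this hunk.
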
@@ -189,7 +187,7 @@ class PPOConsumer(BaseConsumer):
                 0,  # kl is already included in the advantage
                 action_mask,
             )
-
+
             # Critic Loss
             # Hack: use the current value to approximate the old value, should be old value mathematically
             critic_loss = self.critic_loss_fn(
@@ -236,7 +234,9 @@ class PPOConsumer(BaseConsumer):
                        data["input_ids"][i], skip_special_tokens=True
                    )
                    response_reward_for_logging = r[i]
-                   print(f"###### Generation Sample {i} ######\nResponse: {response_decoded_for_logging}\nReward: {response_reward_for_logging}")
+                   print(
+                       f"###### Generation Sample {i} ######\nResponse: {response_decoded_for_logging}\nReward: {response_reward_for_logging}"
+                   )
             self.wandb_run.log(
                 {
                     "train/loss": self.accum_loss.item() / self.accum_count,

View File

@@ -94,11 +94,11 @@ class BaseProducer:
         print(
             f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.dataloader)}"
         )
         for episode in range(self.num_episodes):
             self.dataloader.sampler.set_epoch(episode)
             for i, batch in enumerate(self.dataloader):
                 if i >= num_valid_microbatches:
                     break
                 outputs = self.rollout(**batch)
                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
                 outputs = pre_send(outputs)
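
The loop above calls self.dataloader.sampler.set_epoch(episode) before each pass, which is what makes a distributed sampler reshuffle the data differently per episode while staying consistent across ranks. A minimal single-process sketch of that behaviour; the dataset contents and sizes are made up, and coati's actual sampler construction is not shown in this hunk:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8))
# num_replicas/rank given explicitly so no process group is needed for this sketch
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

for episode in range(2):
    sampler.set_epoch(episode)  # a different epoch seeds a different shuffle order
    order = [batch[0].tolist() for batch in loader]
    print(f"episode {episode}: {order}")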

View File

@@ -100,6 +100,7 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
     mean = tensor / (mask_sum + 1e-8)
     return mean
 
+
 def compute_reward_ppo(
     r: Union[torch.Tensor, float],
     kl_coef: float,
@@ -125,4 +126,4 @@ def compute_reward_ppo(
         assert action_mask[i].sum() > 0
         reward[i, : action_mask[i].sum()] += r_clip[i]
         reward[i, action_mask[i].sum() :] *= 0
-    return reward, ((log_ratio * (log_ratio < 10)).exp() - 1 - log_ratio) * action_mask
+    return reward, ((log_ratio * (log_ratio < 10)).exp() - 1 - log_ratio) * action_mask
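
The second return value, ((log_ratio * (log_ratio < 10)).exp() - 1 - log_ratio) * action_mask, is the "k3" KL estimator exp(d) - 1 - d applied per token; the (log_ratio < 10) factor zeroes extreme log-ratios before exponentiating so the exp cannot overflow, at the cost of a negative estimate on those clipped entries. A small numeric sketch with illustrative values:

import torch

log_ratio = torch.tensor([0.0, 0.5, -0.5, 20.0])
guarded = log_ratio * (log_ratio < 10)   # the 20.0 entry becomes 0.0 before exp
k3 = guarded.exp() - 1 - log_ratio       # exp(d) - 1 - d >= 0 for the unclipped entries
print(k3)  # ~[0.0000, 0.1487, 0.1065, -20.0000]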

View File

@@ -28,7 +28,7 @@ if __name__ == "__main__":
         top_p=0.8,
     )
-
+
     if args.backend == "transformers":
         inference_model_config.update(
             dict(
                 use_flash_attention_2=True,
@@ -43,13 +43,7 @@ if __name__ == "__main__":
             )
         )
         generate_config.update(
-            dict(
-                max_length=768,
-                do_sample=True,
-                max_new_tokens=None,
-                early_stopping=False,
-                stop_strings=["</answer>"]
-            )
+            dict(max_length=768, do_sample=True, max_new_tokens=None, early_stopping=False, stop_strings=["</answer>"])
         )
     elif args.backend == "vllm":
         inference_model_config.update(
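
The collapsed generate_config above caps generation with max_length (max_new_tokens is explicitly None) and stops on the literal string "</answer>". A hedged usage sketch for the transformers backend; the model path and prompt are placeholders, and in recent transformers releases stop_strings generally requires the tokenizer to be passed to generate as well:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "your/model-path"  # placeholder, not taken from the diff
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

generate_config = dict(
    max_length=768, do_sample=True, max_new_tokens=None, early_stopping=False, stop_strings=["</answer>"]
)

inputs = tokenizer("What is 2 + 2? <answer>", return_tensors="pt")
# stop_strings needs a tokenizer at generation time to detect the stop text in decoded output
out = model.generate(**inputs, **generate_config, tokenizer=tokenizer)
print(tokenizer.decode(out[0], skip_special_tokens=True))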