diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 646853829..01b360cf1 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -13,8 +13,8 @@ from colossalai.booster import Booster
 from colossalai.booster.plugin import HybridParallelPlugin
 from colossalai.initialize import launch
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
 from colossalai.shardformer.policies.auto_policy import get_autopolicy
+from colossalai.utils import get_current_device
 
 from .comm import ray_broadcast_tensor_dict
 from .utils import bind_batch, post_recv, unbind_batch
diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py
index 1ca40bf65..eaf43cb8c 100644
--- a/applications/ColossalChat/coati/distributed/inference_backend.py
+++ b/applications/ColossalChat/coati/distributed/inference_backend.py
@@ -90,13 +90,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
 
             attention_mask = attention_mask.repeat_interleave(action_mask.size(0) // attention_mask.size(0), dim=0)
             attention_mask = torch.cat((attention_mask, action_mask), dim=1)
-            
+
         data = {
             "input_ids": out.sequences,
             "attention_mask": attention_mask,
             "action_log_probs": action_log_probs,
             "action_mask": action_mask,
-            "response_idx": response_idx
+            "response_idx": response_idx,
         }
         return data
 
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index bb492a047..5fe7c916f 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -7,11 +7,7 @@ from .grpo_consumer import GRPOConsumer
 from .ppo_consumer import PPOConsumer
 from .producer import SimpleProducer
 
-ALGO_MAP = {
-    "Simple": SimpleConsumer,
-    "GRPO": GRPOConsumer,
-    "PPO": PPOConsumer
-}
+ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "PPO": PPOConsumer}
 
 
 def get_jsonl_size_fast(path: str) -> int:
diff --git a/applications/ColossalChat/coati/distributed/loss.py b/applications/ColossalChat/coati/distributed/loss.py
index e63209f34..400549a0e 100644
--- a/applications/ColossalChat/coati/distributed/loss.py
+++ b/applications/ColossalChat/coati/distributed/loss.py
@@ -72,4 +72,4 @@ class ValueLoss(nn.Module):
             loss = torch.mean(masked_mean(torch.max(surr1, surr2), action_mask))
         else:
             loss = torch.mean(torch.max(surr1, surr2))
-        return 0.5 * loss
\ No newline at end of file
+        return 0.5 * loss
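For readers unfamiliar with the ValueLoss hunk above: returning 0.5 * max(surr1, surr2) is the standard PPO clipped value loss, where the updated value estimate is clipped towards the old one and the worse of the two squared errors is kept. A minimal sketch of that computation, assuming (batch, num_actions) tensors and a masked_mean helper like the one in coati.distributed.utils; the tensor names and the clip_eps default here are illustrative, not taken from loss.py:

import torch


def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    # Mean over `dim`, counting only positions where mask == 1.
    return (tensor * mask).sum(dim=dim) / (mask.sum(dim=dim) + 1e-8)


def clipped_value_loss(values, old_values, returns, action_mask=None, clip_eps=0.2):
    # Clip the new value estimate towards the old one, then keep the larger squared error.
    values_clipped = old_values + (values - old_values).clamp(-clip_eps, clip_eps)
    surr1 = (values - returns) ** 2
    surr2 = (values_clipped - returns) ** 2
    if action_mask is not None:
        loss = torch.mean(masked_mean(torch.max(surr1, surr2), action_mask))
    else:
        loss = torch.mean(torch.max(surr1, surr2))
    return 0.5 * loss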
diff --git a/applications/ColossalChat/coati/distributed/ppo_consumer.py b/applications/ColossalChat/coati/distributed/ppo_consumer.py
index d7d80149c..7f865ce23 100644
--- a/applications/ColossalChat/coati/distributed/ppo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/ppo_consumer.py
@@ -9,8 +9,8 @@ from coati.distributed.loss import PolicyLoss, ValueLoss
 from coati.distributed.reward.reward_fn import math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs, compute_reward_ppo
-from coati.trainer.utils import all_reduce_mean
 from coati.models import Critic, disable_dropout
+from coati.trainer.utils import all_reduce_mean
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from colossalai.nn.optimizer import HybridAdam
@@ -33,9 +33,9 @@ class PPOConsumer(BaseConsumer):
         plugin_config,
         microbatch_size=1,
         num_generations=1,
-        gamma:float=1.0,
-        lam:float=0.95,
-        kl_coef:float=0.05,
+        gamma: float = 1.0,
+        lam: float = 0.95,
+        kl_coef: float = 0.05,
         use_wandb=True,
     ):
         super().__init__(
@@ -55,7 +55,7 @@ class PPOConsumer(BaseConsumer):
         self.gamma = gamma
         self.lam = lam
         self.kl_coef = kl_coef
-        
+
         path = model_config.pop("path")
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
         self.policy_model.train()
@@ -63,14 +63,14 @@ class PPOConsumer(BaseConsumer):
         self.critic_model = Critic(path, **model_config)
         self.critic_model.model.gradient_checkpointing_enable()
         self.critic_model.train()
-        
+
         # Disable dropout
         disable_dropout(self.policy_model)
         disable_dropout(self.critic_model)
-        
+
         self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-6)
         self.critic_optimizer = HybridAdam(self.critic_model.parameters(), lr=1e-6)
-        
+
         self.accum_loss = torch.zeros(1, device=self.device)
         self.accum_reward = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
@@ -152,15 +152,13 @@ class PPOConsumer(BaseConsumer):
                 input_ids=data["input_ids"],
                 attention_mask=data["attention_mask"],
             )
-            value = value[:, -num_action -1: -1] * action_mask
+            value = value[:, -num_action - 1 : -1] * action_mask
 
-            r = self.reward_model(
-                data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
-            )
+            r = self.reward_model(data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"])
             reward, kl = compute_reward_ppo(
                 r, self.kl_coef, old_action_log_probs, reference_action_log_probs, action_mask=action_mask
             )
-            
+
             # Calculate advantages
             # reference: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/ppo_trainer.py#L514C17-L523C46lastgaelam = 0
             lastgaelam = 0
@@ -172,7 +170,7 @@
                 advantage_reversed.append(lastgaelam)
             advantage = torch.stack(advantage_reversed[::-1], axis=1) * action_mask
             advantage = advantage.detach()
-            
+
             # KL divergence for logging
             per_token_kl = (
                 torch.exp(reference_action_log_probs - action_log_probs)
@@ -189,7 +187,7 @@
                 0,  # kl is already included in the advantage
                 action_mask,
             )
-            
+
             # Critic Loss
             # Hack: use the current value to approximate the old value, should be old value mathematically
             critic_loss = self.critic_loss_fn(
@@ -236,7 +234,9 @@
                         data["input_ids"][i], skip_special_tokens=True
                     )
                     response_reward_for_logging = r[i]
-                    print(f"###### Generation Sample {i} ######\nResponse: {response_decoded_for_logging}\nReward: {response_reward_for_logging}")
+                    print(
+                        f"###### Generation Sample {i} ######\nResponse: {response_decoded_for_logging}\nReward: {response_reward_for_logging}"
+                    )
                 self.wandb_run.log(
                     {
                         "train/loss": self.accum_loss.item() / self.accum_count,
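The advantage calculation touched in the PPOConsumer hunks above follows the GAE-style reversed scan referenced from TRL's ppo_trainer. A self-contained sketch of that recurrence, assuming reward and value are (batch, num_actions) tensors and that gamma/lam match the defaults added in __init__; the function name and loop details are illustrative rather than copied from ppo_consumer.py:

import torch


def gae_advantages(reward, value, action_mask, gamma: float = 1.0, lam: float = 0.95):
    # reward, value, action_mask: (batch, num_actions); scan the response from the last token backwards.
    num_actions = reward.size(1)
    lastgaelam = 0.0
    advantage_reversed = []
    for t in reversed(range(num_actions)):
        nextvalue = value[:, t + 1] if t < num_actions - 1 else 0.0
        delta = reward[:, t] + gamma * nextvalue - value[:, t]
        lastgaelam = delta + gamma * lam * lastgaelam
        advantage_reversed.append(lastgaelam)
    advantage = torch.stack(advantage_reversed[::-1], dim=1) * action_mask
    return advantage.detach()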
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index d1d407a31..eb0621180 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -94,11 +94,11 @@ class BaseProducer:
         print(
             f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.dataloader)}"
         )
-        for episode in range(self.num_episodes): 
+        for episode in range(self.num_episodes):
             self.dataloader.sampler.set_epoch(episode)
             for i, batch in enumerate(self.dataloader):
                 if i >= num_valid_microbatches:
-                    break 
+                    break
                 outputs = self.rollout(**batch)
                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
                 outputs = pre_send(outputs)
diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py
index 837f48218..7cbcd66ca 100644
--- a/applications/ColossalChat/coati/distributed/utils.py
+++ b/applications/ColossalChat/coati/distributed/utils.py
@@ -100,6 +100,7 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch
     mean = tensor / (mask_sum + 1e-8)
     return mean
 
+
 def compute_reward_ppo(
     r: Union[torch.Tensor, float],
     kl_coef: float,
@@ -125,4 +126,4 @@
         assert action_mask[i].sum() > 0
         reward[i, : action_mask[i].sum()] += r_clip[i]
         reward[i, action_mask[i].sum() :] *= 0
-    return reward, ((log_ratio * (log_ratio < 10)).exp() - 1 - log_ratio) * action_mask
\ No newline at end of file
+    return reward, ((log_ratio * (log_ratio < 10)).exp() - 1 - log_ratio) * action_mask
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index b39ddc4d0..4b9bd23c4 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -28,7 +28,7 @@ if __name__ == "__main__":
         top_p=0.8,
     )
 
-    if args.backend == "transformers": 
+    if args.backend == "transformers":
         inference_model_config.update(
             dict(
                 use_flash_attention_2=True,
@@ -43,13 +43,7 @@ if __name__ == "__main__":
             )
         )
         generate_config.update(
-            dict(
-                max_length=768,
-                do_sample=True,
-                max_new_tokens=None,
-                early_stopping=False,
-                stop_strings=[""]
-            )
+            dict(max_length=768, do_sample=True, max_new_tokens=None, early_stopping=False, stop_strings=[""])
        )
     elif args.backend == "vllm":
         inference_model_config.update(
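On the compute_reward_ppo hunk in utils.py above: the KL term it returns is the exp(x) - 1 - x estimator with a guard that zeroes the exponent for large log-ratios, and the clipped scalar reward is added onto the valid response tokens. A rough sketch under those assumptions; the sign convention of log_ratio, the clip value, and the argument order are assumptions, not verbatim from utils.py:

import torch


def kl_shaped_reward(r, kl_coef, log_probs, ref_log_probs, action_mask, reward_clip=5.0):
    # Assumed convention: log_ratio = reference log-probs minus current policy log-probs.
    log_ratio = ref_log_probs - log_probs
    # exp(x) - 1 - x KL estimate; (log_ratio < 10) suppresses the exponential for huge ratios.
    kl = ((log_ratio * (log_ratio < 10)).exp() - 1 - log_ratio) * action_mask
    reward = -kl_coef * kl
    r_clip = torch.clamp(r, -reward_clip, reward_clip)
    for i in range(reward.size(0)):
        n_resp = int(action_mask[i].sum())
        assert n_resp > 0
        reward[i, :n_resp] += r_clip[i]  # spread the task reward over the response tokens
        reward[i, n_resp:] = 0.0         # zero everything past the response
    return reward, kl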