[Distributed RLHF] Integration of PP (#6257)

* update help information

* update style

* fix

* minor fix

* support PP training

* add pp support

* remove unused code

* address conversation

---------

Co-authored-by: Tong Li <tong.li35271158@gmail.com>
This commit is contained in:
YeAnbang 2025-04-09 13:23:24 +08:00 committed by GitHub
parent 50153005b4
commit ed43a4be04
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 263 additions and 116 deletions

1
.gitignore vendored
View File

@ -164,3 +164,4 @@ coverage.xml
applications/ColossalChat/logs applications/ColossalChat/logs
applications/ColossalChat/tests/logs applications/ColossalChat/tests/logs
applications/ColossalChat/wandb applications/ColossalChat/wandb
applications/ColossalChat/model

View File

@ -54,7 +54,6 @@ class BaseConsumer:
self.model_config = model_config self.model_config = model_config
self.plugin_config = plugin_config self.plugin_config = plugin_config
assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"
self.device = get_current_device() self.device = get_current_device()
self.lr_scheduler = None self.lr_scheduler = None
@ -95,7 +94,6 @@ class BaseConsumer:
i = 0 i = 0
for _ in range(self.num_recv_per_update): for _ in range(self.num_recv_per_update):
# receive data from producers # receive data from producers
for r in range(self.num_producers): for r in range(self.num_producers):
print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}") print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
self.buffer.extend( self.buffer.extend(

View File

@ -39,6 +39,7 @@ class GRPOConsumer(BaseConsumer):
use_wandb=True, use_wandb=True,
generate_config=None, generate_config=None,
training_config={}, training_config={},
project_name=None,
): ):
super().__init__( super().__init__(
num_producers, num_producers,
@ -69,6 +70,7 @@ class GRPOConsumer(BaseConsumer):
self.accum_count = 0 self.accum_count = 0
self.generate_config = generate_config self.generate_config = generate_config
self.training_config = training_config self.training_config = training_config
self.project_name = project_name
# Reference model is initialized from policy model. # Reference model is initialized from policy model.
self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config) self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@ -94,9 +96,7 @@ class GRPOConsumer(BaseConsumer):
self.policy_loss_fn = PolicyLoss() self.policy_loss_fn = PolicyLoss()
self.global_step = 0 self.global_step = 0
if use_wandb and self.rank == 0: self.use_wandb = use_wandb
name = f"{generate_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True, dir="./wandb", name=name)
self.lr_scheduler = CosineAnnealingWarmupLR( self.lr_scheduler = CosineAnnealingWarmupLR(
optimizer=self.optimizer, optimizer=self.optimizer,
@ -107,10 +107,19 @@ class GRPOConsumer(BaseConsumer):
def setup(self): def setup(self):
super().setup() super().setup()
if self.use_wandb and (
(not self.plugin.pp_size > 1 and self.rank == 0)
or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage())
):
# Initialize wandb.
name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
self.wandb_run = wandb.init(project=self.project_name, sync_tensorboard=True, dir="./wandb", name=name)
self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost( self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
) )
self.reference_model, *_ = self.booster.boost(self.reference_model) self.reference_model, *_ = self.booster.boost(self.reference_model)
self.plugin.logger.set_level("ERROR")
def step(self, step_idx: int, **kwargs) -> Optional[float]: def step(self, step_idx: int, **kwargs) -> Optional[float]:
""" """
@ -168,6 +177,7 @@ class GRPOConsumer(BaseConsumer):
).repeat_interleave(self.num_generations, dim=0) ).repeat_interleave(self.num_generations, dim=0)
) )
mean_kl, mean_loss = [], [] mean_kl, mean_loss = [], []
for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size): for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
input_ids_forward_micro_batch = data["input_ids"][ input_ids_forward_micro_batch = data["input_ids"][
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
@ -186,112 +196,210 @@ class GRPOConsumer(BaseConsumer):
advantages_forward_micro_batch = advantages[ advantages_forward_micro_batch = advantages[
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
] ]
policy_model_logits = self.policy_model(
input_ids=input_ids_forward_micro_batch,
attention_mask=attention_mask_forward_micro_batch,
).logits
action_log_probs = calc_action_log_probs(
policy_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
)
with torch.no_grad(): if self.plugin.pp_size > 1:
reference_model_logits = self.reference_model( # Support training with PP.
with torch.no_grad():
reference_model_outputs = self.booster.execute_pipeline(
iter(
[
{
"input_ids": input_ids_forward_micro_batch,
"attention_mask": attention_mask_forward_micro_batch,
}
]
),
self.reference_model,
criterion=lambda outputs, inputs: torch.tensor(
[0.0], device=action_mask.device
), # dummy criterion
optimizer=None,
return_loss=False,
return_outputs=True,
)
if self.booster.plugin.stage_manager.is_last_stage():
reference_model_logits = reference_model_outputs["outputs"]["logits"]
reference_action_log_probs = calc_action_log_probs(
reference_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
)
else:
# Dummy reference logprobs for data iterator.
reference_action_log_probs = None
data_policy_forward = {
"input_ids": input_ids_forward_micro_batch,
"attention_mask": attention_mask_forward_micro_batch,
"action_mask": action_mask_forward_micro_batch,
"reference_action_log_probs": reference_action_log_probs,
"advantages": advantages_forward_micro_batch,
"loss_mask": loss_mask_forward_micro_batch,
"source": self.rank,
}
kl = []
def _criterion(outputs, inputs):
action_logits = outputs.logits
action_log_probs = calc_action_log_probs(
action_logits / self.generate_config["temperature"],
inputs["input_ids"],
num_action,
self.plugin.shard_config,
)
per_token_kl = (
torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
- (inputs["reference_action_log_probs"] - action_log_probs)
- 1
)
appox_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum(
inputs["action_mask"], dim=-1
)
kl.append(appox_kl.mean())
loss, skip_update, _ = self.policy_loss_fn(
action_log_probs,
action_log_probs,
inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
per_token_kl,
inputs["action_mask"],
loss_mask=inputs["loss_mask"],
)
return loss
policy_model_outputs = self.booster.execute_pipeline(
iter([data_policy_forward]),
self.policy_model,
criterion=_criterion,
optimizer=self.optimizer,
return_loss=True,
return_outputs=True,
)
loss = policy_model_outputs["loss"]
if self.booster.plugin.stage_manager.is_last_stage():
if len(kl) > 0:
kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin)
mean_kl.append(kl)
loss = all_reduce_mean(loss, self.plugin)
mean_loss.append(loss.data)
else:
policy_model_logits = self.policy_model(
input_ids=input_ids_forward_micro_batch, input_ids=input_ids_forward_micro_batch,
attention_mask=attention_mask_forward_micro_batch, attention_mask=attention_mask_forward_micro_batch,
).logits ).logits
reference_action_log_probs = calc_action_log_probs( action_log_probs = calc_action_log_probs(
reference_model_logits / self.generate_config["temperature"], policy_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch, input_ids_forward_micro_batch,
num_action, num_action,
self.plugin.shard_config, self.plugin.shard_config,
) )
per_token_kl = ( with torch.no_grad():
torch.exp(reference_action_log_probs - action_log_probs) reference_model_logits = self.reference_model(
- (reference_action_log_probs - action_log_probs) input_ids=input_ids_forward_micro_batch,
- 1 attention_mask=attention_mask_forward_micro_batch,
) ).logits
kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum( reference_action_log_probs = calc_action_log_probs(
action_mask_forward_micro_batch, dim=-1 reference_model_logits / self.generate_config["temperature"],
) input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
)
per_token_kl = (
torch.exp(reference_action_log_probs - action_log_probs)
- (reference_action_log_probs - action_log_probs)
- 1
)
kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
action_mask_forward_micro_batch, dim=-1
)
loss, skip_update, _ = self.policy_loss_fn( loss, skip_update, _ = self.policy_loss_fn(
action_log_probs, action_log_probs,
old_action_log_probs, old_action_log_probs,
advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1), advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
per_token_kl, per_token_kl,
action_mask_forward_micro_batch, action_mask_forward_micro_batch,
loss_mask=loss_mask_forward_micro_batch, loss_mask=loss_mask_forward_micro_batch,
) )
if not skip_update: if not skip_update:
self.booster.backward(loss, self.optimizer) self.booster.backward(loss, self.optimizer)
loss = all_reduce_mean(loss, self.plugin) loss = all_reduce_mean(loss, self.plugin)
kl = all_reduce_mean(kl.mean(), self.plugin) kl = all_reduce_mean(kl.mean(), self.plugin)
# Calculate accumulate value. # Calculate accumulate value.
mean_kl.append(kl.data) mean_kl.append(kl.data)
mean_loss.append(loss.data) mean_loss.append(loss.data)
if not self.plugin.pp_size > 1 or (
reward = all_reduce_mean(reward.mean(), self.plugin) self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
format_reward = all_reduce_mean(format_reward.mean(), self.plugin) ):
acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin) reward = all_reduce_mean(reward.mean(), self.plugin)
advantages = all_reduce_mean(advantages.mean(), self.plugin) format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
response_length = all_reduce_mean(response_length.mean(), self.plugin) acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
self.accum_loss.add_(sum(mean_loss) / len(mean_loss)) advantages = all_reduce_mean(advantages.mean(), self.plugin)
self.accum_kl.add_(sum(mean_kl) / len(mean_kl)) response_length = all_reduce_mean(response_length.mean(), self.plugin)
self.accum_reward.add_(reward.data) self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
self.accum_format_reward.add_(format_reward.data) self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
self.accum_acc_reward.add_(acc_reward.data) self.accum_reward.add_(reward.data)
self.accum_advantages.add_(advantages.data) self.accum_format_reward.add_(format_reward.data)
self.accum_response_length.add_(response_length.data) self.accum_acc_reward.add_(acc_reward.data)
self.accum_count += 1 self.accum_advantages.add_(advantages.data)
self.accum_response_length.add_(response_length.data)
self.accum_count += 1
if need_update: if need_update:
self.optimizer.step() self.optimizer.step()
self.optimizer.zero_grad() self.optimizer.zero_grad()
loss_scalar = self.accum_loss.item() if not self.plugin.pp_size > 1 or (
if self.rank == 0: self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
print( ):
"Loss:", loss_scalar = self.accum_loss.item()
self.accum_loss.item() / self.accum_count, if (not self.plugin.pp_size > 1 and self.rank == 0) or (
"\nReward:", self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
self.accum_reward.item() / self.accum_count, ):
"\nFormat Reward:", print(
self.accum_format_reward.item() / self.accum_count, "Loss:",
"\nAcc Reward:", self.accum_loss.item() / self.accum_count,
self.accum_acc_reward.item() / self.accum_count, "\nReward:",
"\nKL:", self.accum_reward.item() / self.accum_count,
self.accum_kl.item() / self.accum_count, "\nFormat Reward:",
"\nAdvantages:", self.accum_format_reward.item() / self.accum_count,
self.accum_advantages.item() / self.accum_count, "\nAcc Reward:",
"\nResponse Length:", self.accum_acc_reward.item() / self.accum_count,
self.accum_response_length.item() / self.accum_count, "\nKL:",
) self.accum_kl.item() / self.accum_count,
self.wandb_run.log( "\nAdvantages:",
{ self.accum_advantages.item() / self.accum_count,
"metrics/reward": self.accum_reward.item() / self.accum_count, "\nResponse Length:",
"metrics/format_reward": self.accum_format_reward.item() / self.accum_count, self.accum_response_length.item() / self.accum_count,
"metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count, )
"metrics/response_length": self.accum_response_length.item() / self.accum_count, self.wandb_run.log(
"train/loss": self.accum_loss.item() / self.accum_count, {
"train/kl": self.accum_kl.item() / self.accum_count, "metrics/reward": self.accum_reward.item() / self.accum_count,
"train/advantages": self.accum_advantages.item() / self.accum_count, "metrics/format_reward": self.accum_format_reward.item() / self.accum_count,
"train/learning_rate": self.lr_scheduler.get_last_lr()[0], "metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count,
"rollout/temperature": data["temperature"].cpu().numpy()[0][0], "metrics/response_length": self.accum_response_length.item() / self.accum_count,
} "train/loss": self.accum_loss.item() / self.accum_count,
) "train/kl": self.accum_kl.item() / self.accum_count,
self.accum_loss.zero_() "train/advantages": self.accum_advantages.item() / self.accum_count,
self.accum_reward.zero_() "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
self.accum_acc_reward.zero_() "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
self.accum_format_reward.zero_() }
self.accum_kl.zero_() )
self.accum_advantages.zero_() self.accum_loss.zero_()
self.accum_response_length.zero_() self.accum_reward.zero_()
self.accum_acc_reward.zero_()
self.accum_format_reward.zero_()
self.accum_kl.zero_()
self.accum_advantages.zero_()
self.accum_response_length.zero_()
self.accum_count = 0 self.accum_count = 0
return loss_scalar return loss_scalar
def state_dict(self): def state_dict(self):
self.policy_model._force_wait_all_gather() self.policy_model._force_wait_all_gather()

View File

@ -47,6 +47,7 @@ def launch_distributed(
master_addr: str = "localhost", master_addr: str = "localhost",
master_port: int = 29500, master_port: int = 29500,
core_algo: str = "GRPO", core_algo: str = "GRPO",
project_name: Optional[str] = None,
): ):
if core_algo not in ALGO_MAP: if core_algo not in ALGO_MAP:
@ -108,6 +109,7 @@ def launch_distributed(
"train_microbatch_size": train_microbatch_size, "train_microbatch_size": train_microbatch_size,
}, },
num_generations=num_generations, num_generations=num_generations,
project_name=project_name,
) )
procs.append(consumer) procs.append(consumer)
ray.get([p.setup.remote() for p in procs]) ray.get([p.setup.remote() for p in procs])

View File

@ -10,13 +10,44 @@ if __name__ == "__main__":
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl") parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
parser.add_argument("-t", "--num-trainers", type=int, default=2) parser.add_argument("-t", "--num-trainers", type=int, default=2)
parser.add_argument("-i", "--num-inferencer", type=int, default=2) parser.add_argument("-i", "--num-inferencer", type=int, default=2)
parser.add_argument("-g", "--num-generations", type=int, default=8) parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.")
parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64) parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=8) parser.add_argument(
parser.add_argument("-tbs", "--train-batch-size", type=int, default=32) "-ibs",
parser.add_argument("-tMbs", "--train-minibatch-size", type=int, default=1) "--inference-batch-size",
parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2) type=int,
parser.add_argument("-b", "--backend", type=str, default="transformers") default=64,
help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.",
)
parser.add_argument(
"-imbs",
"--inference-microbatch-size",
type=int,
default=8,
help="Effective batch size for the inference backend to run generation. Please select based on memory constraint.",
)
parser.add_argument(
"-tbs",
"--train-batch-size",
type=int,
default=32,
help="Number of unique prompts to update policy model per step per dp group. Gradient is accumulated across tbs * dp_size unique prompts, equivalently tbs * g * dp_size samples",
)
parser.add_argument(
"-tMbs",
"--train-minibatch-size",
type=int,
default=1,
help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs",
)
parser.add_argument(
"-tmbs",
"--train-microbatch-size",
type=int,
default=2,
help="Effective batch size per dp group for forwarding and backwarding. Please select based on the availiable memory.",
)
parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"]) parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
args = parser.parse_args() args = parser.parse_args()
@ -29,11 +60,7 @@ if __name__ == "__main__":
ray.init(address="local", namespace="ray-example") ray.init(address="local", namespace="ray-example")
inference_model_config = dict(path=args.model) inference_model_config = dict(path=args.model)
train_model_config = dict( train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
path=args.model,
# use_flash_attention_2=True,
# use_cache=False
)
generate_config = dict(top_k=50, top_p=0.75, temperature=0.9) generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)
if args.backend == "transformers": if args.backend == "transformers":
@ -91,9 +118,17 @@ if __name__ == "__main__":
generate_config=generate_config, generate_config=generate_config,
num_generations=args.num_generations, num_generations=args.num_generations,
train_model_config=train_model_config, train_model_config=train_model_config,
plugin_config={}, # plugin_config={}, # for zero
plugin_config={
"pp_size": 2,
"tp_size": 1,
"microbatch_size": args.train_microbatch_size // 2,
"zero_stage": 0,
"max_norm": 1.0,
}, # for pp
inference_backend=args.backend, inference_backend=args.backend,
master_addr="localhost", master_addr="localhost",
master_port=29505, master_port=29506,
core_algo=args.algo, core_algo=args.algo,
project_name=args.project,
) )

View File

@ -1411,8 +1411,10 @@ class HybridParallelPlugin(PipelinePluginBase):
) )
# run with gradients accumulation # run with gradients accumulation
if model.require_grad_sync == False or ( if (
isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False not torch.is_grad_enabled()
or model.require_grad_sync == False
or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False)
): ):
return outputs return outputs

View File

@ -284,6 +284,7 @@ class Qwen2PipelineForwards:
hidden_states: Optional[torch.FloatTensor] = None, hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None, stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None, shard_config: ShardConfig = None,
**kwargs,
): ):
r""" r"""
Args: Args: