mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-12 23:00:16 +00:00
[Distributed RLHF] Integration of PP (#6257)
* update help information * update style * fix * minor fix * support PP training * add pp support * remove unused code * address conversation --------- Co-authored-by: Tong Li <tong.li35271158@gmail.com>
This commit is contained in:
parent
50153005b4
commit
ed43a4be04
1
.gitignore
vendored
1
.gitignore
vendored
@ -164,3 +164,4 @@ coverage.xml
|
|||||||
applications/ColossalChat/logs
|
applications/ColossalChat/logs
|
||||||
applications/ColossalChat/tests/logs
|
applications/ColossalChat/tests/logs
|
||||||
applications/ColossalChat/wandb
|
applications/ColossalChat/wandb
|
||||||
|
applications/ColossalChat/model
|
||||||
|
@ -54,7 +54,6 @@ class BaseConsumer:
|
|||||||
|
|
||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
self.plugin_config = plugin_config
|
self.plugin_config = plugin_config
|
||||||
assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"
|
|
||||||
|
|
||||||
self.device = get_current_device()
|
self.device = get_current_device()
|
||||||
self.lr_scheduler = None
|
self.lr_scheduler = None
|
||||||
@ -95,7 +94,6 @@ class BaseConsumer:
|
|||||||
i = 0
|
i = 0
|
||||||
for _ in range(self.num_recv_per_update):
|
for _ in range(self.num_recv_per_update):
|
||||||
# receive data from producers
|
# receive data from producers
|
||||||
|
|
||||||
for r in range(self.num_producers):
|
for r in range(self.num_producers):
|
||||||
print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
|
print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
|
||||||
self.buffer.extend(
|
self.buffer.extend(
|
||||||
|
@ -39,6 +39,7 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
use_wandb=True,
|
use_wandb=True,
|
||||||
generate_config=None,
|
generate_config=None,
|
||||||
training_config={},
|
training_config={},
|
||||||
|
project_name=None,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_producers,
|
num_producers,
|
||||||
@ -69,6 +70,7 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
self.accum_count = 0
|
self.accum_count = 0
|
||||||
self.generate_config = generate_config
|
self.generate_config = generate_config
|
||||||
self.training_config = training_config
|
self.training_config = training_config
|
||||||
|
self.project_name = project_name
|
||||||
|
|
||||||
# Reference model is initialized from policy model.
|
# Reference model is initialized from policy model.
|
||||||
self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
||||||
@ -94,9 +96,7 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
|
|
||||||
self.policy_loss_fn = PolicyLoss()
|
self.policy_loss_fn = PolicyLoss()
|
||||||
self.global_step = 0
|
self.global_step = 0
|
||||||
if use_wandb and self.rank == 0:
|
self.use_wandb = use_wandb
|
||||||
name = f"{generate_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
|
|
||||||
self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True, dir="./wandb", name=name)
|
|
||||||
|
|
||||||
self.lr_scheduler = CosineAnnealingWarmupLR(
|
self.lr_scheduler = CosineAnnealingWarmupLR(
|
||||||
optimizer=self.optimizer,
|
optimizer=self.optimizer,
|
||||||
@ -107,10 +107,19 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
super().setup()
|
super().setup()
|
||||||
|
if self.use_wandb and (
|
||||||
|
(not self.plugin.pp_size > 1 and self.rank == 0)
|
||||||
|
or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage())
|
||||||
|
):
|
||||||
|
# Initialize wandb.
|
||||||
|
name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
|
||||||
|
self.wandb_run = wandb.init(project=self.project_name, sync_tensorboard=True, dir="./wandb", name=name)
|
||||||
|
|
||||||
self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
|
self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
|
||||||
self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
|
self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
|
||||||
)
|
)
|
||||||
self.reference_model, *_ = self.booster.boost(self.reference_model)
|
self.reference_model, *_ = self.booster.boost(self.reference_model)
|
||||||
|
self.plugin.logger.set_level("ERROR")
|
||||||
|
|
||||||
def step(self, step_idx: int, **kwargs) -> Optional[float]:
|
def step(self, step_idx: int, **kwargs) -> Optional[float]:
|
||||||
"""
|
"""
|
||||||
@ -168,6 +177,7 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
).repeat_interleave(self.num_generations, dim=0)
|
).repeat_interleave(self.num_generations, dim=0)
|
||||||
)
|
)
|
||||||
mean_kl, mean_loss = [], []
|
mean_kl, mean_loss = [], []
|
||||||
|
|
||||||
for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
|
for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
|
||||||
input_ids_forward_micro_batch = data["input_ids"][
|
input_ids_forward_micro_batch = data["input_ids"][
|
||||||
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
|
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
|
||||||
@ -186,112 +196,210 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
advantages_forward_micro_batch = advantages[
|
advantages_forward_micro_batch = advantages[
|
||||||
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
|
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
|
||||||
]
|
]
|
||||||
policy_model_logits = self.policy_model(
|
|
||||||
input_ids=input_ids_forward_micro_batch,
|
|
||||||
attention_mask=attention_mask_forward_micro_batch,
|
|
||||||
).logits
|
|
||||||
action_log_probs = calc_action_log_probs(
|
|
||||||
policy_model_logits / self.generate_config["temperature"],
|
|
||||||
input_ids_forward_micro_batch,
|
|
||||||
num_action,
|
|
||||||
self.plugin.shard_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
if self.plugin.pp_size > 1:
|
||||||
reference_model_logits = self.reference_model(
|
# Support training with PP.
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
reference_model_outputs = self.booster.execute_pipeline(
|
||||||
|
iter(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"input_ids": input_ids_forward_micro_batch,
|
||||||
|
"attention_mask": attention_mask_forward_micro_batch,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
),
|
||||||
|
self.reference_model,
|
||||||
|
criterion=lambda outputs, inputs: torch.tensor(
|
||||||
|
[0.0], device=action_mask.device
|
||||||
|
), # dummy criterion
|
||||||
|
optimizer=None,
|
||||||
|
return_loss=False,
|
||||||
|
return_outputs=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.booster.plugin.stage_manager.is_last_stage():
|
||||||
|
reference_model_logits = reference_model_outputs["outputs"]["logits"]
|
||||||
|
reference_action_log_probs = calc_action_log_probs(
|
||||||
|
reference_model_logits / self.generate_config["temperature"],
|
||||||
|
input_ids_forward_micro_batch,
|
||||||
|
num_action,
|
||||||
|
self.plugin.shard_config,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Dummy reference logprobs for data iterator.
|
||||||
|
reference_action_log_probs = None
|
||||||
|
|
||||||
|
data_policy_forward = {
|
||||||
|
"input_ids": input_ids_forward_micro_batch,
|
||||||
|
"attention_mask": attention_mask_forward_micro_batch,
|
||||||
|
"action_mask": action_mask_forward_micro_batch,
|
||||||
|
"reference_action_log_probs": reference_action_log_probs,
|
||||||
|
"advantages": advantages_forward_micro_batch,
|
||||||
|
"loss_mask": loss_mask_forward_micro_batch,
|
||||||
|
"source": self.rank,
|
||||||
|
}
|
||||||
|
|
||||||
|
kl = []
|
||||||
|
|
||||||
|
def _criterion(outputs, inputs):
|
||||||
|
action_logits = outputs.logits
|
||||||
|
action_log_probs = calc_action_log_probs(
|
||||||
|
action_logits / self.generate_config["temperature"],
|
||||||
|
inputs["input_ids"],
|
||||||
|
num_action,
|
||||||
|
self.plugin.shard_config,
|
||||||
|
)
|
||||||
|
per_token_kl = (
|
||||||
|
torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
|
||||||
|
- (inputs["reference_action_log_probs"] - action_log_probs)
|
||||||
|
- 1
|
||||||
|
)
|
||||||
|
appox_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum(
|
||||||
|
inputs["action_mask"], dim=-1
|
||||||
|
)
|
||||||
|
kl.append(appox_kl.mean())
|
||||||
|
loss, skip_update, _ = self.policy_loss_fn(
|
||||||
|
action_log_probs,
|
||||||
|
action_log_probs,
|
||||||
|
inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
|
||||||
|
per_token_kl,
|
||||||
|
inputs["action_mask"],
|
||||||
|
loss_mask=inputs["loss_mask"],
|
||||||
|
)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
policy_model_outputs = self.booster.execute_pipeline(
|
||||||
|
iter([data_policy_forward]),
|
||||||
|
self.policy_model,
|
||||||
|
criterion=_criterion,
|
||||||
|
optimizer=self.optimizer,
|
||||||
|
return_loss=True,
|
||||||
|
return_outputs=True,
|
||||||
|
)
|
||||||
|
loss = policy_model_outputs["loss"]
|
||||||
|
|
||||||
|
if self.booster.plugin.stage_manager.is_last_stage():
|
||||||
|
if len(kl) > 0:
|
||||||
|
kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin)
|
||||||
|
mean_kl.append(kl)
|
||||||
|
loss = all_reduce_mean(loss, self.plugin)
|
||||||
|
mean_loss.append(loss.data)
|
||||||
|
else:
|
||||||
|
|
||||||
|
policy_model_logits = self.policy_model(
|
||||||
input_ids=input_ids_forward_micro_batch,
|
input_ids=input_ids_forward_micro_batch,
|
||||||
attention_mask=attention_mask_forward_micro_batch,
|
attention_mask=attention_mask_forward_micro_batch,
|
||||||
).logits
|
).logits
|
||||||
reference_action_log_probs = calc_action_log_probs(
|
action_log_probs = calc_action_log_probs(
|
||||||
reference_model_logits / self.generate_config["temperature"],
|
policy_model_logits / self.generate_config["temperature"],
|
||||||
input_ids_forward_micro_batch,
|
input_ids_forward_micro_batch,
|
||||||
num_action,
|
num_action,
|
||||||
self.plugin.shard_config,
|
self.plugin.shard_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
per_token_kl = (
|
with torch.no_grad():
|
||||||
torch.exp(reference_action_log_probs - action_log_probs)
|
reference_model_logits = self.reference_model(
|
||||||
- (reference_action_log_probs - action_log_probs)
|
input_ids=input_ids_forward_micro_batch,
|
||||||
- 1
|
attention_mask=attention_mask_forward_micro_batch,
|
||||||
)
|
).logits
|
||||||
kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
|
reference_action_log_probs = calc_action_log_probs(
|
||||||
action_mask_forward_micro_batch, dim=-1
|
reference_model_logits / self.generate_config["temperature"],
|
||||||
)
|
input_ids_forward_micro_batch,
|
||||||
|
num_action,
|
||||||
|
self.plugin.shard_config,
|
||||||
|
)
|
||||||
|
per_token_kl = (
|
||||||
|
torch.exp(reference_action_log_probs - action_log_probs)
|
||||||
|
- (reference_action_log_probs - action_log_probs)
|
||||||
|
- 1
|
||||||
|
)
|
||||||
|
kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
|
||||||
|
action_mask_forward_micro_batch, dim=-1
|
||||||
|
)
|
||||||
|
|
||||||
loss, skip_update, _ = self.policy_loss_fn(
|
loss, skip_update, _ = self.policy_loss_fn(
|
||||||
action_log_probs,
|
action_log_probs,
|
||||||
old_action_log_probs,
|
old_action_log_probs,
|
||||||
advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
|
advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
|
||||||
per_token_kl,
|
per_token_kl,
|
||||||
action_mask_forward_micro_batch,
|
action_mask_forward_micro_batch,
|
||||||
loss_mask=loss_mask_forward_micro_batch,
|
loss_mask=loss_mask_forward_micro_batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not skip_update:
|
if not skip_update:
|
||||||
self.booster.backward(loss, self.optimizer)
|
self.booster.backward(loss, self.optimizer)
|
||||||
loss = all_reduce_mean(loss, self.plugin)
|
loss = all_reduce_mean(loss, self.plugin)
|
||||||
kl = all_reduce_mean(kl.mean(), self.plugin)
|
kl = all_reduce_mean(kl.mean(), self.plugin)
|
||||||
# Calculate accumulate value.
|
# Calculate accumulate value.
|
||||||
mean_kl.append(kl.data)
|
mean_kl.append(kl.data)
|
||||||
mean_loss.append(loss.data)
|
mean_loss.append(loss.data)
|
||||||
|
if not self.plugin.pp_size > 1 or (
|
||||||
reward = all_reduce_mean(reward.mean(), self.plugin)
|
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
|
||||||
format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
|
):
|
||||||
acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
|
reward = all_reduce_mean(reward.mean(), self.plugin)
|
||||||
advantages = all_reduce_mean(advantages.mean(), self.plugin)
|
format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
|
||||||
response_length = all_reduce_mean(response_length.mean(), self.plugin)
|
acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
|
||||||
self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
|
advantages = all_reduce_mean(advantages.mean(), self.plugin)
|
||||||
self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
|
response_length = all_reduce_mean(response_length.mean(), self.plugin)
|
||||||
self.accum_reward.add_(reward.data)
|
self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
|
||||||
self.accum_format_reward.add_(format_reward.data)
|
self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
|
||||||
self.accum_acc_reward.add_(acc_reward.data)
|
self.accum_reward.add_(reward.data)
|
||||||
self.accum_advantages.add_(advantages.data)
|
self.accum_format_reward.add_(format_reward.data)
|
||||||
self.accum_response_length.add_(response_length.data)
|
self.accum_acc_reward.add_(acc_reward.data)
|
||||||
self.accum_count += 1
|
self.accum_advantages.add_(advantages.data)
|
||||||
|
self.accum_response_length.add_(response_length.data)
|
||||||
|
self.accum_count += 1
|
||||||
if need_update:
|
if need_update:
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
loss_scalar = self.accum_loss.item()
|
if not self.plugin.pp_size > 1 or (
|
||||||
if self.rank == 0:
|
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
|
||||||
print(
|
):
|
||||||
"Loss:",
|
loss_scalar = self.accum_loss.item()
|
||||||
self.accum_loss.item() / self.accum_count,
|
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
|
||||||
"\nReward:",
|
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
|
||||||
self.accum_reward.item() / self.accum_count,
|
):
|
||||||
"\nFormat Reward:",
|
print(
|
||||||
self.accum_format_reward.item() / self.accum_count,
|
"Loss:",
|
||||||
"\nAcc Reward:",
|
self.accum_loss.item() / self.accum_count,
|
||||||
self.accum_acc_reward.item() / self.accum_count,
|
"\nReward:",
|
||||||
"\nKL:",
|
self.accum_reward.item() / self.accum_count,
|
||||||
self.accum_kl.item() / self.accum_count,
|
"\nFormat Reward:",
|
||||||
"\nAdvantages:",
|
self.accum_format_reward.item() / self.accum_count,
|
||||||
self.accum_advantages.item() / self.accum_count,
|
"\nAcc Reward:",
|
||||||
"\nResponse Length:",
|
self.accum_acc_reward.item() / self.accum_count,
|
||||||
self.accum_response_length.item() / self.accum_count,
|
"\nKL:",
|
||||||
)
|
self.accum_kl.item() / self.accum_count,
|
||||||
self.wandb_run.log(
|
"\nAdvantages:",
|
||||||
{
|
self.accum_advantages.item() / self.accum_count,
|
||||||
"metrics/reward": self.accum_reward.item() / self.accum_count,
|
"\nResponse Length:",
|
||||||
"metrics/format_reward": self.accum_format_reward.item() / self.accum_count,
|
self.accum_response_length.item() / self.accum_count,
|
||||||
"metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count,
|
)
|
||||||
"metrics/response_length": self.accum_response_length.item() / self.accum_count,
|
self.wandb_run.log(
|
||||||
"train/loss": self.accum_loss.item() / self.accum_count,
|
{
|
||||||
"train/kl": self.accum_kl.item() / self.accum_count,
|
"metrics/reward": self.accum_reward.item() / self.accum_count,
|
||||||
"train/advantages": self.accum_advantages.item() / self.accum_count,
|
"metrics/format_reward": self.accum_format_reward.item() / self.accum_count,
|
||||||
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
|
"metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count,
|
||||||
"rollout/temperature": data["temperature"].cpu().numpy()[0][0],
|
"metrics/response_length": self.accum_response_length.item() / self.accum_count,
|
||||||
}
|
"train/loss": self.accum_loss.item() / self.accum_count,
|
||||||
)
|
"train/kl": self.accum_kl.item() / self.accum_count,
|
||||||
self.accum_loss.zero_()
|
"train/advantages": self.accum_advantages.item() / self.accum_count,
|
||||||
self.accum_reward.zero_()
|
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
|
||||||
self.accum_acc_reward.zero_()
|
"rollout/temperature": data["temperature"].cpu().numpy()[0][0],
|
||||||
self.accum_format_reward.zero_()
|
}
|
||||||
self.accum_kl.zero_()
|
)
|
||||||
self.accum_advantages.zero_()
|
self.accum_loss.zero_()
|
||||||
self.accum_response_length.zero_()
|
self.accum_reward.zero_()
|
||||||
|
self.accum_acc_reward.zero_()
|
||||||
|
self.accum_format_reward.zero_()
|
||||||
|
self.accum_kl.zero_()
|
||||||
|
self.accum_advantages.zero_()
|
||||||
|
self.accum_response_length.zero_()
|
||||||
|
|
||||||
self.accum_count = 0
|
self.accum_count = 0
|
||||||
return loss_scalar
|
return loss_scalar
|
||||||
|
|
||||||
def state_dict(self):
|
def state_dict(self):
|
||||||
self.policy_model._force_wait_all_gather()
|
self.policy_model._force_wait_all_gather()
|
||||||
|
@ -47,6 +47,7 @@ def launch_distributed(
|
|||||||
master_addr: str = "localhost",
|
master_addr: str = "localhost",
|
||||||
master_port: int = 29500,
|
master_port: int = 29500,
|
||||||
core_algo: str = "GRPO",
|
core_algo: str = "GRPO",
|
||||||
|
project_name: Optional[str] = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
if core_algo not in ALGO_MAP:
|
if core_algo not in ALGO_MAP:
|
||||||
@ -108,6 +109,7 @@ def launch_distributed(
|
|||||||
"train_microbatch_size": train_microbatch_size,
|
"train_microbatch_size": train_microbatch_size,
|
||||||
},
|
},
|
||||||
num_generations=num_generations,
|
num_generations=num_generations,
|
||||||
|
project_name=project_name,
|
||||||
)
|
)
|
||||||
procs.append(consumer)
|
procs.append(consumer)
|
||||||
ray.get([p.setup.remote() for p in procs])
|
ray.get([p.setup.remote() for p in procs])
|
||||||
|
@ -10,13 +10,44 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
|
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
|
||||||
parser.add_argument("-t", "--num-trainers", type=int, default=2)
|
parser.add_argument("-t", "--num-trainers", type=int, default=2)
|
||||||
parser.add_argument("-i", "--num-inferencer", type=int, default=2)
|
parser.add_argument("-i", "--num-inferencer", type=int, default=2)
|
||||||
parser.add_argument("-g", "--num-generations", type=int, default=8)
|
parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.")
|
||||||
parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64)
|
parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
|
||||||
parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=8)
|
parser.add_argument(
|
||||||
parser.add_argument("-tbs", "--train-batch-size", type=int, default=32)
|
"-ibs",
|
||||||
parser.add_argument("-tMbs", "--train-minibatch-size", type=int, default=1)
|
"--inference-batch-size",
|
||||||
parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2)
|
type=int,
|
||||||
parser.add_argument("-b", "--backend", type=str, default="transformers")
|
default=64,
|
||||||
|
help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-imbs",
|
||||||
|
"--inference-microbatch-size",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="Effective batch size for the inference backend to run generation. Please select based on memory constraint.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-tbs",
|
||||||
|
"--train-batch-size",
|
||||||
|
type=int,
|
||||||
|
default=32,
|
||||||
|
help="Number of unique prompts to update policy model per step per dp group. Gradient is accumulated across tbs * dp_size unique prompts, equivalently tbs * g * dp_size samples",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-tMbs",
|
||||||
|
"--train-minibatch-size",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-tmbs",
|
||||||
|
"--train-microbatch-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="Effective batch size per dp group for forwarding and backwarding. Please select based on the availiable memory.",
|
||||||
|
)
|
||||||
|
parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
|
||||||
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
|
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -29,11 +60,7 @@ if __name__ == "__main__":
|
|||||||
ray.init(address="local", namespace="ray-example")
|
ray.init(address="local", namespace="ray-example")
|
||||||
|
|
||||||
inference_model_config = dict(path=args.model)
|
inference_model_config = dict(path=args.model)
|
||||||
train_model_config = dict(
|
train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
|
||||||
path=args.model,
|
|
||||||
# use_flash_attention_2=True,
|
|
||||||
# use_cache=False
|
|
||||||
)
|
|
||||||
generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)
|
generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)
|
||||||
|
|
||||||
if args.backend == "transformers":
|
if args.backend == "transformers":
|
||||||
@ -91,9 +118,17 @@ if __name__ == "__main__":
|
|||||||
generate_config=generate_config,
|
generate_config=generate_config,
|
||||||
num_generations=args.num_generations,
|
num_generations=args.num_generations,
|
||||||
train_model_config=train_model_config,
|
train_model_config=train_model_config,
|
||||||
plugin_config={},
|
# plugin_config={}, # for zero
|
||||||
|
plugin_config={
|
||||||
|
"pp_size": 2,
|
||||||
|
"tp_size": 1,
|
||||||
|
"microbatch_size": args.train_microbatch_size // 2,
|
||||||
|
"zero_stage": 0,
|
||||||
|
"max_norm": 1.0,
|
||||||
|
}, # for pp
|
||||||
inference_backend=args.backend,
|
inference_backend=args.backend,
|
||||||
master_addr="localhost",
|
master_addr="localhost",
|
||||||
master_port=29505,
|
master_port=29506,
|
||||||
core_algo=args.algo,
|
core_algo=args.algo,
|
||||||
|
project_name=args.project,
|
||||||
)
|
)
|
||||||
|
@ -1411,8 +1411,10 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# run with gradients accumulation
|
# run with gradients accumulation
|
||||||
if model.require_grad_sync == False or (
|
if (
|
||||||
isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
|
not torch.is_grad_enabled()
|
||||||
|
or model.require_grad_sync == False
|
||||||
|
or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False)
|
||||||
):
|
):
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
@ -284,6 +284,7 @@ class Qwen2PipelineForwards:
|
|||||||
hidden_states: Optional[torch.FloatTensor] = None,
|
hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
stage_index: Optional[List[int]] = None,
|
stage_index: Optional[List[int]] = None,
|
||||||
shard_config: ShardConfig = None,
|
shard_config: ShardConfig = None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
|
Loading…
Reference in New Issue
Block a user