address conversation

YeAnbang 2025-04-09 12:53:40 +08:00
parent 061d8cb3b6
commit a40d82f629
3 changed files with 30 additions and 25 deletions


@@ -39,6 +39,7 @@ class GRPOConsumer(BaseConsumer):
use_wandb=True,
generate_config=None,
training_config={},
project_name=None,
):
super().__init__(
num_producers,
@@ -69,6 +70,7 @@ class GRPOConsumer(BaseConsumer):
self.accum_count = 0
self.generate_config = generate_config
self.training_config = training_config
self.project_name = project_name
# Reference model is initialized from policy model.
self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -111,7 +113,7 @@ class GRPOConsumer(BaseConsumer):
):
# Initialize wandb.
name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
self.wandb_run = wandb.init(project="GRPO-V1-PP", sync_tensorboard=True, dir="./wandb", name=name)
self.wandb_run = wandb.init(project=self.project_name, sync_tensorboard=True, dir="./wandb", name=name)
self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
@@ -239,6 +241,8 @@ class GRPOConsumer(BaseConsumer):
"source": self.rank,
}
kl = []
def _criterion(outputs, inputs):
action_logits = outputs.logits
action_log_probs = calc_action_log_probs(
@@ -252,6 +256,10 @@ class GRPOConsumer(BaseConsumer):
- (inputs["reference_action_log_probs"] - action_log_probs)
- 1
)
approx_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum(
inputs["action_mask"], dim=-1
)
kl.append(approx_kl.mean())
loss, skip_update, _ = self.policy_loss_fn(
action_log_probs,
action_log_probs,
@@ -273,26 +281,11 @@ class GRPOConsumer(BaseConsumer):
loss = policy_model_outputs["loss"]
if self.booster.plugin.stage_manager.is_last_stage():
# calculate kl; since we cannot do this inside the callback, kl needs to be calculated again
action_logits = policy_model_outputs["outputs"]["logits"]
action_log_probs = calc_action_log_probs(
action_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
)
per_token_kl = (
torch.exp(reference_action_log_probs - action_log_probs)
- (reference_action_log_probs - action_log_probs)
- 1
)
kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
action_mask_forward_micro_batch, dim=-1
)
kl = all_reduce_mean(kl.mean(), self.plugin)
if len(kl) > 0:
kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin)
mean_kl.append(kl)
loss = all_reduce_mean(loss, self.plugin)
mean_loss.append(loss.data)
mean_kl.append(kl)
else:
policy_model_logits = self.policy_model(
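
For reference, the kl list populated inside _criterion above collects a per-micro-batch estimate of KL(policy || reference) using the estimator exp(r) - r - 1 with r = reference_action_log_probs - action_log_probs, masked to the action tokens; the stacked means are then all-reduced once in the pipeline-parallel branch instead of recomputing log probabilities on the last stage. A minimal, self-contained sketch of that masked KL computation follows; the helper name and example tensors are illustrative, not the repository's API.

import torch

def masked_mean_kl(action_log_probs, reference_action_log_probs, action_mask):
    # Low-variance KL estimator per action token: exp(r) - r - 1, with r = ref_logp - logp.
    log_ratio = reference_action_log_probs - action_log_probs
    per_token_kl = torch.exp(log_ratio) - log_ratio - 1
    # Mask out non-action tokens, then average over the valid tokens of each sample.
    per_sample_kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)
    return per_sample_kl.mean()

# Example: batch of 2 sequences with 4 action positions each.
logp = torch.log(torch.rand(2, 4))
ref_logp = torch.log(torch.rand(2, 4))
mask = torch.tensor([[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 1.0, 1.0]])
print(masked_mean_kl(logp, ref_logp, mask))

Appending these per-micro-batch means to a plain Python list is what lets the last pipeline stage reuse them via the if len(kl) > 0 check above, replacing the removed recomputation.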


@@ -47,6 +47,7 @@ def launch_distributed(
master_addr: str = "localhost",
master_port: int = 29500,
core_algo: str = "GRPO",
project_name: Optional[str] = None,
):
if core_algo not in ALGO_MAP:
@@ -108,6 +109,7 @@ def launch_distributed(
"train_microbatch_size": train_microbatch_size,
},
num_generations=num_generations,
project_name=project_name,
)
procs.append(consumer)
ray.get([p.setup.remote() for p in procs])


@@ -11,32 +11,41 @@ if __name__ == "__main__":
parser.add_argument("-t", "--num-trainers", type=int, default=2)
parser.add_argument("-i", "--num-inferencer", type=int, default=2)
parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.")
parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
parser.add_argument(
"-ibs", "--inference-batch-size", type=int, default=64, help="Number of prompts to generate per step."
"-ibs",
"--inference-batch-size",
type=int,
default=64,
help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.",
)
parser.add_argument(
"-imbs",
"--inference-microbatch-size",
type=int,
default=8,
help="Number of prompts to send from the producer to the consumer.",
help="Effective batch size for the inference backend to run generation. Please select based on memory constraint.",
)
parser.add_argument(
"-tbs", "--train-batch-size", type=int, default=32, help="Number of prompts to update policy model."
"-tbs",
"--train-batch-size",
type=int,
default=32,
help="Number of unique prompts to update policy model per step per dp group. Gradient is accumulated across tbs * dp_size unique prompts, equivalently tbs * g * dp_size samples",
)
parser.add_argument(
"-tMbs",
"--train-minibatch-size",
type=int,
default=1,
help="Number of prompts per device. Number of samples = tMbs * num of generation per prompt.",
help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs",
)
parser.add_argument(
"-tmbs",
"--train-microbatch-size",
type=int,
default=2,
help="Number of samples per device. PP micro batchsize when PP is activated.",
help="Effective batch size per dp group for forwarding and backwarding. Please select based on the availiable memory.",
)
parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
@@ -119,6 +128,7 @@ if __name__ == "__main__":
}, # for pp
inference_backend=args.backend,
master_addr="localhost",
master_port=29505,
master_port=29506,
core_algo=args.algo,
project_name=args.project,
)
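
End to end, the new option threads the project name from the command line down to wandb: the example script passes args.project to launch_distributed, which forwards project_name to GRPOConsumer, which finally uses it in wandb.init in place of the previously hard-coded "GRPO-V1-PP". A stripped-down sketch of that flow, with everything except the touched pieces omitted (wandb.init is the real API; the reduced classes and launcher below are illustrative only):

import wandb

class GRPOConsumer:
    # Reduced sketch: only the pieces touched by this commit.
    def __init__(self, project_name=None, generate_config=None, training_config=None):
        self.project_name = project_name  # new in this commit
        self.generate_config = generate_config or {}
        self.training_config = training_config or {}

    def setup(self):
        # Previously hard-coded as project="GRPO-V1-PP"; now configurable.
        self.wandb_run = wandb.init(project=self.project_name, sync_tensorboard=True, dir="./wandb")

def launch_distributed(project_name=None, **consumer_kwargs):
    # The launcher simply forwards the new argument to the consumer.
    return GRPOConsumer(project_name=project_name, **consumer_kwargs)

consumer = launch_distributed(project_name="GRPO")
print(consumer.project_name)

Running the example script with, e.g., -p my-grpo-experiments would then group the resulting wandb runs under that project instead of the default "GRPO".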