address conversation
parent 061d8cb3b6
commit a40d82f629
@@ -39,6 +39,7 @@ class GRPOConsumer(BaseConsumer):
         use_wandb=True,
         generate_config=None,
         training_config={},
+        project_name=None,
     ):
         super().__init__(
             num_producers,
@@ -69,6 +70,7 @@ class GRPOConsumer(BaseConsumer):
         self.accum_count = 0
         self.generate_config = generate_config
         self.training_config = training_config
+        self.project_name = project_name

         # Reference model is initialized from policy model.
         self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -111,7 +113,7 @@ class GRPOConsumer(BaseConsumer):
         ):
             # Initialize wandb.
             name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
-            self.wandb_run = wandb.init(project="GRPO-V1-PP", sync_tensorboard=True, dir="./wandb", name=name)
+            self.wandb_run = wandb.init(project=self.project_name, sync_tensorboard=True, dir="./wandb", name=name)

         self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
             self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
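For reference, the run name built above encodes the inference backend, the global batch size (per-consumer batch size times the data-parallel size), and the sampling parameters. A minimal sketch of the resulting string, with illustrative values that are assumptions rather than values taken from this commit:

# Hypothetical values, chosen only to show the run-name format used in wandb.init above.
generate_config = {"backend": "vllm", "temperature": 1.0, "top_p": 0.95}
batch_size, dp_size = 32, 4

name = (
    f"{generate_config['backend']}_bs_{batch_size * dp_size}"
    f"_temp_{generate_config['temperature']:.01f}"
    f"_top_p_{generate_config['top_p']:.02f}"
)
print(name)  # vllm_bs_128_temp_1.0_top_p_0.95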
@@ -239,6 +241,8 @@ class GRPOConsumer(BaseConsumer):
                     "source": self.rank,
                 }

+                kl = []
+
                 def _criterion(outputs, inputs):
                     action_logits = outputs.logits
                     action_log_probs = calc_action_log_probs(
@@ -252,6 +256,10 @@ class GRPOConsumer(BaseConsumer):
                         - (inputs["reference_action_log_probs"] - action_log_probs)
                         - 1
                     )
+                    appox_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum(
+                        inputs["action_mask"], dim=-1
+                    )
+                    kl.append(appox_kl.mean())
                     loss, skip_update, _ = self.policy_loss_fn(
                         action_log_probs,
                         action_log_probs,
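The lines added to `_criterion` accumulate the KL term while the loss is computed, one scalar per micro-batch, instead of recomputing it later from the logits. The quantity is the low-variance k3 estimator exp(delta) - delta - 1 with delta = reference_log_prob - log_prob, averaged over the valid action tokens of each sequence. A self-contained sketch of the same computation; tensor names and shapes here are assumptions used only for illustration:

import torch

def masked_mean_kl(action_log_probs, reference_action_log_probs, action_mask):
    """k3 KL estimate per token, averaged over the valid action tokens of each sequence."""
    delta = reference_action_log_probs - action_log_probs
    per_token_kl = torch.exp(delta) - delta - 1  # non-negative, low-variance KL estimator
    return torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)

# Mirrors the diff: collect one scalar per micro-batch inside the criterion.
kl = []
log_probs = torch.randn(4, 16)                         # assumed (batch, num_actions) shape
ref_log_probs = log_probs + 0.01 * torch.randn(4, 16)  # reference model stays close to the policy
action_mask = torch.ones(4, 16)
kl.append(masked_mean_kl(log_probs, ref_log_probs, action_mask).mean())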
@@ -273,26 +281,11 @@ class GRPOConsumer(BaseConsumer):
                 loss = policy_model_outputs["loss"]

                 if self.booster.plugin.stage_manager.is_last_stage():
-                    # calculate kl, as we cannot do this inside callback, kl needs be calculate again
-                    action_logits = policy_model_outputs["outputs"]["logits"]
-                    action_log_probs = calc_action_log_probs(
-                        action_logits / self.generate_config["temperature"],
-                        input_ids_forward_micro_batch,
-                        num_action,
-                        self.plugin.shard_config,
-                    )
-                    per_token_kl = (
-                        torch.exp(reference_action_log_probs - action_log_probs)
-                        - (reference_action_log_probs - action_log_probs)
-                        - 1
-                    )
-                    kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
-                        action_mask_forward_micro_batch, dim=-1
-                    )
-                    kl = all_reduce_mean(kl.mean(), self.plugin)
+                    if len(kl) > 0:
+                        kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin)
+                        mean_kl.append(kl)
                     loss = all_reduce_mean(loss, self.plugin)
                     mean_loss.append(loss.data)
-                    mean_kl.append(kl)
             else:

                 policy_model_logits = self.policy_model(
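With the list-based accumulation in place, the last pipeline stage no longer recomputes the KL from the output logits; it only averages the per-micro-batch scalars collected in `_criterion` and all-reduces the result. A short runnable sketch of that aggregation, with a stub standing in for ColossalAI's `all_reduce_mean` (which in the real code averages the scalar across data-parallel ranks):

import torch

def all_reduce_mean_stub(t, plugin=None):
    # Stand-in for colossalai's all_reduce_mean; assumed to average the value across DP ranks.
    return t

kl = [torch.tensor(0.01), torch.tensor(0.02)]  # per-micro-batch KL scalars appended in _criterion
mean_kl = []
if len(kl) > 0:
    local_kl = torch.mean(torch.stack(kl))          # average over the micro-batches of this step
    mean_kl.append(all_reduce_mean_stub(local_kl))  # then average across data-parallel ranks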
@@ -47,6 +47,7 @@ def launch_distributed(
     master_addr: str = "localhost",
     master_port: int = 29500,
     core_algo: str = "GRPO",
+    project_name: Optional[str] = None,
 ):

     if core_algo not in ALGO_MAP:
@@ -108,6 +109,7 @@ def launch_distributed(
                 "train_microbatch_size": train_microbatch_size,
             },
             num_generations=num_generations,
+            project_name=project_name,
         )
         procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])
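The `project_name` keyword is plain plumbing: the launcher forwards it unchanged to each consumer, which later hands it to `wandb.init(project=...)`. A minimal stub showing only that parameter flow; the class and function below are simplified stand-ins, not the real `GRPOConsumer` or `launch_distributed`:

from typing import Optional

class ConsumerStub:
    """Simplified stand-in for GRPOConsumer, keeping only the project_name plumbing."""
    def __init__(self, project_name: Optional[str] = None):
        # Stored here and later passed as wandb.init(project=self.project_name, ...).
        self.project_name = project_name

def launch_stub(project_name: Optional[str] = None) -> ConsumerStub:
    # Mirrors launch_distributed forwarding project_name to the consumer unchanged.
    return ConsumerStub(project_name=project_name)

assert launch_stub(project_name="GRPO").project_name == "GRPO"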
@@ -11,32 +11,41 @@ if __name__ == "__main__":
     parser.add_argument("-t", "--num-trainers", type=int, default=2)
     parser.add_argument("-i", "--num-inferencer", type=int, default=2)
     parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.")
+    parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
     parser.add_argument(
-        "-ibs", "--inference-batch-size", type=int, default=64, help="Number of prompts to generate per step."
+        "-ibs",
+        "--inference-batch-size",
+        type=int,
+        default=64,
+        help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.",
     )
     parser.add_argument(
         "-imbs",
         "--inference-microbatch-size",
         type=int,
         default=8,
-        help="Number of prompts to send from the producer to the consumer.",
+        help="Effective batch size for the inference backend to run generation. Please select based on memory constraint.",
     )
     parser.add_argument(
-        "-tbs", "--train-batch-size", type=int, default=32, help="Number of prompts to update policy model."
+        "-tbs",
+        "--train-batch-size",
+        type=int,
+        default=32,
+        help="Number of unique prompts to update policy model per step per dp group. Gradient is accumulated across tbs * dp_size unique prompts, equivalently tbs * g * dp_size samples",
     )
     parser.add_argument(
         "-tMbs",
         "--train-minibatch-size",
         type=int,
         default=1,
-        help="Number of prompts per device. Number of samples = tMbs * num of generation per prompt.",
+        help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs",
     )
     parser.add_argument(
         "-tmbs",
         "--train-microbatch-size",
         type=int,
         default=2,
-        help="Number of samples per device. PP micro batchsize when PP is activated.",
+        help="Effective batch size per dp group for forwarding and backwarding. Please select based on the availiable memory.",
     )
     parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
     parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
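The rewritten help strings encode several relationships between the batch-size flags: ibs should be divisible by tbs and the inference weights are synced every ibs/tbs policy-model steps; each optimizer step consumes tbs unique prompts per DP group, i.e. tbs * g * dp_size samples overall; and the minibatch and microbatch settings should satisfy tMbs * g >= tmbs. A small sanity-check sketch of those stated constraints; the dp_size value is an assumption used only for illustration:

def check_batch_config(ibs, tbs, tMbs, tmbs, g, dp_size):
    """Validate the relationships described in the argument help strings above."""
    assert ibs % tbs == 0, "ibs must be divisible by tbs"
    sync_interval = ibs // tbs             # inference weights synced every ibs/tbs training steps
    samples_per_step = tbs * g * dp_size   # samples accumulated per optimizer step
    assert tMbs * g >= tmbs, "a minibatch must supply at least one full microbatch of samples"
    print(f"sync every {sync_interval} steps, {samples_per_step} samples per optimizer step")

# Defaults from the parser above, with an assumed dp_size of 2.
check_batch_config(ibs=64, tbs=32, tMbs=1, tmbs=2, g=8, dp_size=2)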
@@ -119,6 +128,7 @@ if __name__ == "__main__":
         }, # for pp
         inference_backend=args.backend,
         master_addr="localhost",
-        master_port=29505,
+        master_port=29506,
         core_algo=args.algo,
+        project_name=args.project,
     )