mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 03:20:52 +00:00
[Distributed RLHF] Integration of PP (#6257)
* update help information
* update style
* fix
* minor fix
* support PP training
* add pp support
* remove unused code
* address conversation

Co-authored-by: Tong Li <tong.li35271158@gmail.com>
This commit is contained in:
@@ -10,13 +10,44 @@ if __name__ == "__main__":
|
||||
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
|
||||
parser.add_argument("-t", "--num-trainers", type=int, default=2)
|
||||
parser.add_argument("-i", "--num-inferencer", type=int, default=2)
|
||||
parser.add_argument("-g", "--num-generations", type=int, default=8)
|
||||
parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64)
|
||||
parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=8)
|
||||
parser.add_argument("-tbs", "--train-batch-size", type=int, default=32)
|
||||
parser.add_argument("-tMbs", "--train-minibatch-size", type=int, default=1)
|
||||
parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2)
|
||||
parser.add_argument("-b", "--backend", type=str, default="transformers")
|
||||
parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.")
|
||||
parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
|
||||
parser.add_argument(
|
||||
"-ibs",
|
||||
"--inference-batch-size",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-imbs",
|
||||
"--inference-microbatch-size",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Effective batch size for the inference backend to run generation. Please select based on memory constraint.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-tbs",
|
||||
"--train-batch-size",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Number of unique prompts to update policy model per step per dp group. Gradient is accumulated across tbs * dp_size unique prompts, equivalently tbs * g * dp_size samples",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-tMbs",
|
||||
"--train-minibatch-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-tmbs",
|
||||
"--train-microbatch-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Effective batch size per dp group for forwarding and backwarding. Please select based on the available memory.",
|
||||
)
|
||||
parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
|
||||
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -29,11 +60,7 @@ if __name__ == "__main__":
|
||||
ray.init(address="local", namespace="ray-example")
|
||||
|
||||
inference_model_config = dict(path=args.model)
|
||||
train_model_config = dict(
|
||||
path=args.model,
|
||||
# use_flash_attention_2=True,
|
||||
# use_cache=False
|
||||
)
|
||||
train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
|
||||
generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)
|
||||
|
||||
if args.backend == "transformers":
|
||||
@@ -91,9 +118,17 @@ if __name__ == "__main__":
|
||||
generate_config=generate_config,
|
||||
num_generations=args.num_generations,
|
||||
train_model_config=train_model_config,
|
||||
plugin_config={},
|
||||
# plugin_config={}, # for zero
|
||||
plugin_config={
|
||||
"pp_size": 2,
|
||||
"tp_size": 1,
|
||||
"microbatch_size": args.train_microbatch_size // 2,
|
||||
"zero_stage": 0,
|
||||
"max_norm": 1.0,
|
||||
}, # for pp
|
||||
inference_backend=args.backend,
|
||||
master_addr="localhost",
|
||||
master_port=29505,
|
||||
master_port=29506,
|
||||
core_algo=args.algo,
|
||||
project_name=args.project,
|
||||
)
|
||||
|
Reference in New Issue
Block a user