[feat] add microbatch forwarding (#6251)

* add microbatch forwarding

* fix forward microbatch

* fix producer OOM

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* change project name

* fix temperature annealing

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address conversation

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
YeAnbang
2025-03-28 10:24:58 +08:00
committed by GitHub
parent 489f215ad9
commit 50153005b4
5 changed files with 112 additions and 72 deletions

View File

@@ -10,18 +10,30 @@ if __name__ == "__main__":
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
parser.add_argument("-t", "--num-trainers", type=int, default=2)
parser.add_argument("-i", "--num-inferencer", type=int, default=2)
parser.add_argument("-g", "--num-generations", type=int, default=8)
parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64)
parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=8)
parser.add_argument("-tbs", "--train-batch-size", type=int, default=32)
parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=1)
parser.add_argument("-tMbs", "--train-minibatch-size", type=int, default=1)
parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2)
parser.add_argument("-b", "--backend", type=str, default="transformers")
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
args = parser.parse_args()
assert args.train_minibatch_size > 0, "Train mini batch size must be greater than 0"
assert (
args.train_minibatch_size * args.num_generations >= args.train_microbatch_size
and args.train_microbatch_size > 0
), "Train micro batch size must be greater than 0 less than train mini batch size * num generations"
ray.init(address="local", namespace="ray-example")
inference_model_config = dict(path=args.model)
train_model_config = dict(path=args.model)
train_model_config = dict(
path=args.model,
# use_flash_attention_2=True,
# use_cache=False
)
generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)
if args.backend == "transformers":
@@ -31,13 +43,6 @@ if __name__ == "__main__":
torch_dtype=torch.bfloat16,
)
)
train_model_config.update(
dict(
use_flash_attention_2=True,
torch_dtype=torch.bfloat16,
use_cache=False,
)
)
generate_config.update(
dict(
max_length=1024 + 512,
@@ -78,15 +83,17 @@ if __name__ == "__main__":
inference_batch_size=args.inference_batch_size,
inference_microbatch_size=args.inference_microbatch_size,
train_batch_size=args.train_batch_size,
train_minibatch_size=args.train_minibatch_size,
train_microbatch_size=args.train_microbatch_size,
dataset_config={"path": args.dataset, "max_length": 300},
dataloaders_config={},
inference_model_config=inference_model_config,
generate_config=generate_config,
num_generations=args.num_generations,
train_model_config=train_model_config,
plugin_config={},
inference_backend=args.backend,
master_addr="localhost",
master_port=29503,
master_port=29505,
core_algo=args.algo,
)