Fix logprob computation; add reward filtering, temperature annealing, and learning-rate descent

This commit is contained in:
YeAnbang
2025-03-21 10:24:24 +08:00
parent 7ee4452f8c
commit 0472f44163
7 changed files with 74 additions and 27 deletions

View File

@@ -42,6 +42,7 @@ def launch_distributed(
plugin_config: Dict[str, Any],
tokenizer_config: Optional[Dict[str, Any]] = None,
inference_backend: str = "transformers",
num_generations: int = 8,
master_addr: str = "localhost",
master_port: int = 29500,
core_algo: str = "GRPO",
@@ -76,6 +77,7 @@ def launch_distributed(
tokenizer_config=tokenizer_config,
microbatch_size=inference_microbatch_size,
backend=inference_backend,
num_generations=num_generations,
)
procs.append(producer)
generate_config_consumer = copy.deepcopy(generate_config)
@@ -99,7 +101,8 @@ def launch_distributed(
plugin_config=plugin_config,
microbatch_size=train_microbatch_size,
generate_config=generate_config_consumer,
filter_range=[0.05, 9.0],
training_config={"filter_range": [0.05, 9.0], "lr": 1e-6},
num_generations=num_generations,
)
procs.append(consumer)
ray.get([p.setup.remote() for p in procs])