mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-21 09:29:47 +00:00
fix logprob, add filtering, temperature annealing, lr descent
This commit is contained in:
@@ -42,6 +42,7 @@ def launch_distributed(
|
||||
plugin_config: Dict[str, Any],
|
||||
tokenizer_config: Optional[Dict[str, Any]] = None,
|
||||
inference_backend: str = "transformers",
|
||||
+ num_generations: int = 8,
|
||||
master_addr: str = "localhost",
|
||||
master_port: int = 29500,
|
||||
core_algo: str = "GRPO",
|
||||
@@ -76,6 +77,7 @@ def launch_distributed(
|
||||
tokenizer_config=tokenizer_config,
|
||||
microbatch_size=inference_microbatch_size,
|
||||
backend=inference_backend,
|
||||
+ num_generations=num_generations,
|
||||
)
|
||||
procs.append(producer)
|
||||
generate_config_consumer = copy.deepcopy(generate_config)
|
||||
@@ -99,7 +101,8 @@ def launch_distributed(
|
||||
plugin_config=plugin_config,
|
||||
microbatch_size=train_microbatch_size,
|
||||
generate_config=generate_config_consumer,
|
||||
- filter_range=[0.05, 9.0],
|
||||
+ training_config={"filter_range": [0.05, 9.0], "lr": 1e-6},
|
||||
+ num_generations=num_generations,
|
||||
)
|
||||
procs.append(consumer)
|
||||
ray.get([p.setup.remote() for p in procs])
|
||||
|
Reference in New Issue
Block a user