fix logprob, add filtering, temperature annealing, lr descent

YeAnbang
2025-03-21 10:24:24 +08:00
parent 7ee4452f8c
commit 0472f44163
7 changed files with 74 additions and 27 deletions


@@ -117,6 +117,12 @@ class BaseProducer:
                         None, self.num_producers, device=self.device, group_name="sync_model"
                     )
                     self.load_state_dict(state_dict)
+                # linear annealing for 1 episode, temperature from its initial value to 0.7
+                if episode <= 0:
+                    ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
+                    self.model.generate_config.temperature = (
+                        (1 - ratio) * self.generate_config["temperature"] + ratio * 0.7
+                    )
 
 
 @ray.remote
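The hunk above is a per-batch linear schedule: with ratio = i / len(self.dataloader) rising from 0 to 1 over the first episode, the sampling temperature is interpolated from its configured starting value down to 0.7, trading early exploration in the rollouts for more deterministic sampling later on. A minimal standalone sketch of the schedule; the names (annealed_temperature, num_batches, start_temp, end_temp) are illustrative, not identifiers from this repository:

def annealed_temperature(step: int, num_batches: int,
                         start_temp: float, end_temp: float = 0.7) -> float:
    # Same arithmetic as the diff: ratio = 1 - (num_batches - step) / num_batches.
    ratio = step / num_batches
    return (1 - ratio) * start_temp + ratio * end_temp

# Over 4 batches starting at temperature 1.0:
# [annealed_temperature(i, 4, 1.0) for i in range(4)] -> [1.0, 0.925, 0.85, 0.775]
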
@@ -135,6 +141,7 @@ class SimpleProducer(BaseProducer):
         tokenizer_config=None,
         microbatch_size=1,
         backend="transformers",
+        num_generations: int = 8,
     ):
         super().__init__(
             producer_idx,
@@ -150,7 +157,7 @@ class SimpleProducer(BaseProducer):
             microbatch_size,
             backend,
         )
-        self.model = self.backend_cls(model_config, generate_config, self.tokenizer)
+        self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
 
     @torch.no_grad()
     def rollout(self, input_ids, attention_mask, **kwargs):
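
The last two hunks thread a num_generations parameter from SimpleProducer into the inference backend's constructor, so each prompt can be rolled out several times (multiple completions per prompt are what group-relative advantage estimation in GRPO-style training consumes). A minimal sketch of how a backend might accept it, assuming a Hugging Face transformers-style generation config; the class name and signature here are assumptions, not the repository's actual backend:

class TransformersBackend:
    # Illustrative sketch only; not the repository's real backend class.
    def __init__(self, model_config, generate_config, tokenizer, num_generations: int = 8):
        self.model_config = model_config
        self.tokenizer = tokenizer
        self.generate_config = dict(generate_config)
        # Ask for `num_generations` completions per prompt; with Hugging Face
        # transformers this maps onto the `num_return_sequences` generate kwarg.
        self.generate_config["num_return_sequences"] = num_generations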