mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-13 05:01:44 +00:00
fix logprob, add filtering, temperature annealing, lr descent
This commit is contained in:
@@ -117,6 +117,12 @@ class BaseProducer:
|
||||
None, self.num_producers, device=self.device, group_name="sync_model"
|
||||
)
|
||||
self.load_state_dict(state_dict)
|
||||
# linear annealing for 1 episode, temperature from initial to 0.7
|
||||
if episode <= 0:
|
||||
ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
|
||||
self.model.generate_config.temperature = (
|
||||
ratio * self.generate_config["temperature"] + (1 - ratio) * 0.7
|
||||
)
|
||||
|
||||
|
||||
@ray.remote
|
||||
@@ -135,6 +141,7 @@ class SimpleProducer(BaseProducer):
|
||||
tokenizer_config=None,
|
||||
microbatch_size=1,
|
||||
backend="transformers",
|
||||
num_generations: int = 8,
|
||||
):
|
||||
super().__init__(
|
||||
producer_idx,
|
||||
@@ -150,7 +157,7 @@ class SimpleProducer(BaseProducer):
|
||||
microbatch_size,
|
||||
backend,
|
||||
)
|
||||
self.model = self.backend_cls(model_config, generate_config, self.tokenizer)
|
||||
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
|
||||
|
||||
@torch.no_grad()
|
||||
def rollout(self, input_ids, attention_mask, **kwargs):
|
||||
|
Reference in New Issue
Block a user