diff --git a/pilot/model/inference.py b/pilot/model/inference.py index b15e1d749..60d443f95 100644 --- a/pilot/model/inference.py +++ b/pilot/model/inference.py @@ -9,11 +9,8 @@ def generate_output(model, tokenizer, params, device, context_len=2048): temperature = float(params.get("temperature", 1.0)) max_new_tokens = int(params.get("max_new_tokens", 256)) stop_parameter = params.get("stop", None) - - print(tokenizer.__dir__()) if stop_parameter == tokenizer.eos_token: stop_parameter = None - stop_strings = [] if isinstance(stop_parameter, str): stop_strings.append(stop_parameter) @@ -43,33 +40,34 @@ def generate_output(model, tokenizer, params, device, context_len=2048): past_key_values=past_key_values, ) logits = out.logits - past_key_values = out.past_key_value + past_key_values = out.past_key_values + last_token_logits = logits[0][-1] if temperature < 1e-4: token = int(torch.argmax(last_token_logits)) else: - probs = torch.softmax(last_token_logits / temperature, dim=1) + probs = torch.softmax(last_token_logits / temperature, dim=-1) token = int(torch.multinomial(probs, num_samples=1)) - + output_ids.append(token) if token == tokenizer.eos_token_id: stopped = True else: stopped = False - + output = tokenizer.decode(output_ids, skip_special_tokens=True) for stop_str in stop_strings: pos = output.rfind(stop_str) if pos != -1: output = output[:pos] - stoppped = True + stopped = True break else: pass - - if stoppped: + + if stopped: break del past_key_values