This commit is contained in:
csunny 2023-04-30 14:52:03 +08:00
parent e493f54804
commit 909045c0d6

View File

@@ -9,11 +9,8 @@ def generate_output(model, tokenizer, params, device, context_len=2048):
temperature = float(params.get("temperature", 1.0))
max_new_tokens = int(params.get("max_new_tokens", 256))
stop_parameter = params.get("stop", None)
print(tokenizer.__dir__())
if stop_parameter == tokenizer.eos_token:
stop_parameter = None
stop_strings = []
if isinstance(stop_parameter, str):
stop_strings.append(stop_parameter)
@@ -43,33 +40,34 @@ def generate_output(model, tokenizer, params, device, context_len=2048):
past_key_values=past_key_values,
)
logits = out.logits
past_key_values = out.past_key_value
past_key_values = out.past_key_values
last_token_logits = logits[0][-1]
if temperature < 1e-4:
token = int(torch.argmax(last_token_logits))
else:
probs = torch.softmax(last_token_logits / temperature, dim=1)
probs = torch.softmax(last_token_logits / temperature, dim=-1)
token = int(torch.multinomial(probs, num_samples=1))
output_ids.append(token)
if token == tokenizer.eos_token_id:
stopped = True
else:
stopped = False
output = tokenizer.decode(output_ids, skip_special_tokens=True)
for stop_str in stop_strings:
pos = output.rfind(stop_str)
if pos != -1:
output = output[:pos]
stoppped = True
stopped = True
break
else:
pass
if stoppped:
if stopped:
break
del past_key_values