checkout load_model to fastchat

This commit is contained in:
csunny
2023-04-30 14:21:58 +08:00
parent 4381e21d0c
commit 7c69dc248a
2 changed files with 13 additions and 4 deletions

View File

@@ -84,4 +84,5 @@ def get_embeddings(model, tokenizer, prompt):
input_embeddings = model.get_input_embeddings()
embeddings = input_embeddings(torch.LongTensor([input_ids]))
mean = torch.mean(embeddings[0], 0).cpu().detach()
return mean
return mean