mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-24 12:45:45 +00:00
fix multi device error
This commit is contained in:
parent
bb6c1865e1
commit
e493f54804
@ -81,8 +81,9 @@ def generate_output(model, tokenizer, params, device, context_len=2048):
|
||||
@torch.inference_mode()
|
||||
def get_embeddings(model, tokenizer, prompt):
|
||||
input_ids = tokenizer(prompt).input_ids
|
||||
input_embeddings = model.get_input_embeddings()
|
||||
embeddings = input_embeddings(torch.LongTensor([input_ids]))
|
||||
mean = torch.mean(embeddings[0], 0).cpu().detach()
|
||||
return mean
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
input_embeddings = model.get_input_embeddings().to(device)
|
||||
|
||||
embeddings = input_embeddings(torch.LongTensor([input_ids]).to(device))
|
||||
mean = torch.mean(embeddings[0], 0).cpu().detach()
|
||||
return mean.to(device)
|
||||
|
Loading…
Reference in New Issue
Block a user