mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-29 14:57:35 +00:00
fix multi device error
This commit is contained in:
parent
bb6c1865e1
commit
e493f54804
@ -81,8 +81,9 @@ def generate_output(model, tokenizer, params, device, context_len=2048):
|
|||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def get_embeddings(model, tokenizer, prompt):
|
def get_embeddings(model, tokenizer, prompt):
|
||||||
input_ids = tokenizer(prompt).input_ids
|
input_ids = tokenizer(prompt).input_ids
|
||||||
input_embeddings = model.get_input_embeddings()
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
embeddings = input_embeddings(torch.LongTensor([input_ids]))
|
input_embeddings = model.get_input_embeddings().to(device)
|
||||||
mean = torch.mean(embeddings[0], 0).cpu().detach()
|
|
||||||
return mean
|
|
||||||
|
|
||||||
|
embeddings = input_embeddings(torch.LongTensor([input_ids]).to(device))
|
||||||
|
mean = torch.mean(embeddings[0], 0).cpu().detach()
|
||||||
|
return mean.to(device)
|
||||||
|
Loading…
Reference in New Issue
Block a user