Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-18 00:07:45 +00:00)
checkout load_model to fastchat

commit 7c69dc248a
parent 4381e21d0c
@@ -85,3 +85,4 @@ def get_embeddings(model, tokenizer, prompt):
     embeddings = input_embeddings(torch.LongTensor([input_ids]))
     mean = torch.mean(embeddings[0], 0).cpu().detach()
     return mean
+
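For context, the three lines in this hunk mean-pool the model's input embeddings over the prompt tokens. A minimal sketch of how the surrounding function plausibly fits together, assuming a Hugging Face-style model and tokenizer; only the three lines from the hunk are taken from the source, the rest is reconstruction:

# Sketch only: the function body outside the hunk is assumed, not from the diff.
import torch

def get_embeddings(model, tokenizer, prompt):
    # Tokenize the prompt; assumes a Hugging Face-style tokenizer.
    input_ids = tokenizer(prompt).input_ids
    # Look up the model's input embedding layer (an nn.Embedding).
    input_embeddings = model.get_input_embeddings()
    embeddings = input_embeddings(torch.LongTensor([input_ids]))
    # Mean-pool over the token dimension to get one fixed-size vector.
    mean = torch.mean(embeddings[0], 0).cpu().detach()
    return mean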
@@ -1,17 +1,20 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
 import uvicorn
 from typing import Optional, List
 from fastapi import FastAPI
 from pydantic import BaseModel
 from pilot.model.inference import generate_output, get_embeddings
+from fastchat.serve.inference import load_model
 from pilot.model.loader import ModerLoader
 from pilot.configs.model_config import *
 
 model_path = llm_model_config[LLM_MODEL]
-ml = ModerLoader(model_path=model_path)
-model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)
+# ml = ModerLoader(model_path=model_path)
+# model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)
+
+model, tokenizer = load_model(model_path=model_path, device=DEVICE, num_gpus=1, load_8bit=True, debug=False)
 app = FastAPI()
 
 class PromptRequest(BaseModel):
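This hunk swaps the project's ModerLoader for fastchat's load_model. A standalone smoke test of the new loading path, as a sketch: the keyword arguments mirror the call in the hunk, while the model path and prompt are placeholders.

import torch
from fastchat.serve.inference import load_model

# Same call shape as in the hunk; the path is a hypothetical local checkpoint.
model, tokenizer = load_model(
    model_path="/path/to/vicuna-13b",
    device="cuda",       # corresponds to DEVICE in the server script
    num_gpus=1,
    load_8bit=True,      # 8-bit quantization to cut GPU memory use
    debug=False,
)

# Quick generation check to confirm the loaded pair actually runs.
inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(out[0], skip_special_tokens=True))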
@@ -46,3 +49,8 @@ def embeddings(prompt_request: EmbeddingRequest):
     print("Received prompt: ", params["prompt"])
     output = get_embeddings(model, tokenizer, params["prompt"])
     return {"response": [float(x) for x in output]}
+
+
+
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", log_level="info")
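With the __main__ block added, the server can be started directly and queried over HTTP. A hypothetical client call against the embeddings handler; the route path and request field name are assumptions (only the handler name, params["prompt"], and the response shape appear in the diff), and 8000 is uvicorn's default port since the run call sets none.

import requests

# Route path and "prompt" field are assumed from the handler signature above.
resp = requests.post(
    "http://localhost:8000/embeddings",
    json={"prompt": "What is DB-GPT?"},
)
vector = resp.json()["response"]   # list of floats, per the handler's return value
print(len(vector))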