mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-15 22:19:28 +00:00
fetch top3 similar answer
This commit is contained in:
@@ -10,33 +10,29 @@ from typing import Any, Mapping, Optional, List
|
||||
from langchain.llms.base import LLM
|
||||
from pilot.configs.model_config import *
|
||||
|
||||
class VicunaRequestLLM(LLM):
|
||||
class VicunaLLM(LLM):
|
||||
|
||||
vicuna_generate_path = "generate_stream"
|
||||
def _call(self, prompt: str, temperature: float, max_new_tokens: int, stop: Optional[List[str]] = None) -> str:
|
||||
|
||||
vicuna_generate_path = "generate"
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
if isinstance(stop, list):
|
||||
stop = stop + ["Observation:"]
|
||||
|
||||
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
|
||||
params = {
|
||||
"prompt": prompt,
|
||||
"temperature": 0.7,
|
||||
"max_new_tokens": 1024,
|
||||
"temperature": temperature,
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"stop": stop
|
||||
}
|
||||
response = requests.post(
|
||||
url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_generate_path),
|
||||
data=json.dumps(params),
|
||||
)
|
||||
response.raise_for_status()
|
||||
# for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
||||
# if chunk:
|
||||
# data = json.loads(chunk.decode())
|
||||
# if data["error_code"] == 0:
|
||||
# output = data["text"][skip_echo_len:].strip()
|
||||
# output = self.post_process_code(output)
|
||||
# yield output
|
||||
return response.json()["response"]
|
||||
|
||||
skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
|
||||
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
||||
if chunk:
|
||||
data = json.loads(chunk.decode())
|
||||
if data["error_code"] == 0:
|
||||
output = data["text"][skip_echo_len:].strip()
|
||||
yield output
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
|
Reference in New Issue
Block a user