mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-29 14:57:35 +00:00
94 lines
2.6 KiB
Python
94 lines
2.6 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding:utf-8 -*-
|
|
|
|
import json
import logging
from typing import Any, Iterator, List, Mapping, Optional
from urllib.parse import urljoin

import requests
from langchain.embeddings.base import Embeddings
from langchain.llms.base import LLM

from dbgpt._private.config import Config
from dbgpt._private.pydantic import BaseModel
|
|
|
|
# Module-level Config instance shared by the classes in this module; its
# MODEL_SERVER attribute is used as the base URL (via urljoin) for every
# HTTP request they make.
CFG = Config()
|
|
|
|
|
|
class VicunaLLM(LLM):
    """LangChain ``LLM`` adapter that streams completions from a Vicuna model server.

    Generation requests are POSTed as JSON to
    ``urljoin(CFG.MODEL_SERVER, vicuna_generate_path)``.
    """

    # Path of the streaming-generation endpoint, joined onto CFG.MODEL_SERVER.
    # NOTE(review): urljoin replaces the last path segment of the base URL unless
    # it ends with "/", so MODEL_SERVER is presumably a bare host or ends in "/".
    vicuna_generate_path = "generate_stream"

    def _call(
        self,
        prompt: str,
        temperature: float,
        max_new_tokens: int,
        stop: Optional[List[str]] = None,
    ) -> Iterator[str]:
        """Stream generated text for ``prompt`` from the model server.

        This is a generator: no request is sent until it is first iterated.

        NOTE(review): ``temperature`` and ``max_new_tokens`` have no defaults, so
        LangChain's standard ``_call(prompt, stop=..., run_manager=...)``
        invocation would not supply them; confirm how this class is driven.

        Args:
            prompt: Prompt text sent to the server.
            temperature: Sampling temperature, forwarded to the server as-is.
            max_new_tokens: Generation length limit, forwarded to the server as-is.
            stop: Optional stop strings, forwarded to the server as-is.

        Yields:
            For every NUL-delimited JSON chunk whose ``error_code`` is 0, that
            chunk's ``text`` with the first ``skip_echo_len`` characters removed
            and surrounding whitespace stripped.

        Raises:
            requests.HTTPError: If the server answers with an HTTP error status.
        """
        params = {
            "prompt": prompt,
            "temperature": temperature,
            "max_new_tokens": max_new_tokens,
            "stop": stop,
        }

        # ``text`` is sliced past this many characters, presumably because the
        # server echoes the prompt before the completion. The "- count * 3" term
        # assumes each 4-character "</s>" appears as a single character in the
        # echoed text. NOTE(review): confirm against the server's output format.
        skip_echo_len = len(prompt) + 1 - prompt.count("</s>") * 3

        # stream=True delivers chunks as the server produces them instead of
        # buffering the whole body first; the ``with`` block releases the
        # connection once the generator is exhausted or closed.
        with requests.post(
            url=urljoin(CFG.MODEL_SERVER, self.vicuna_generate_path),
            json=params,
            stream=True,
        ) as response:
            response.raise_for_status()
            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
                if not chunk:
                    continue
                data = json.loads(chunk.decode())
                # NOTE(review): chunks with a non-zero error_code are silently
                # skipped (original behavior); consider surfacing them.
                if data["error_code"] == 0:
                    yield data["text"][skip_echo_len:].strip()

    @property
    def _llm_type(self) -> str:
        """Return the identifier LangChain reports for this LLM type."""
        # NOTE(review): "custome" looks like a typo for "custom"; it is kept
        # verbatim because it is a runtime identifier other code may match on.
        return "custome"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Return the parameters identifying this LLM (none are tracked)."""
        # Declared as a property, like ``_llm_type``: LangChain's ``LLM`` base
        # class defines it as a property and reads it as a mapping, so a plain
        # method would yield a bound method instead of a dict.
        return {}
|
|
|
|
|
|
class VicunaEmbeddingLLM(BaseModel, Embeddings):
    """LangChain ``Embeddings`` backed by a Vicuna model server's embedding endpoint.

    Each text is embedded with its own POST request to
    ``urljoin(CFG.MODEL_SERVER, vicuna_embedding_path)``.
    """

    # Path of the embedding endpoint, joined onto CFG.MODEL_SERVER.
    vicuna_embedding_path = "embedding"

    def _call(self, prompt: str) -> List[float]:
        """Request the embedding of ``prompt`` from the model server.

        Args:
            prompt: Text to embed. Leading and trailing whitespace is stripped
                before sending.

        Returns:
            The ``"response"`` field of the server's JSON reply. This is the
            value ``embed_query`` returns as the embedding vector.

        Raises:
            requests.HTTPError: If the server answers with an HTTP error status.
        """
        text = prompt.strip()
        # Debug-level logging instead of print(): a print here writes every
        # document's text to stdout on each embedding request.
        logging.getLogger(__name__).debug("Sending prompt %s", text)

        response = requests.post(
            url=urljoin(CFG.MODEL_SERVER, self.vicuna_embedding_path),
            json={"prompt": text},
        )
        response.raise_for_status()
        return response.json()["response"]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to Vicuna's server embedding endpoint for embedding search docs.

        One request is made per text, in order.

        Args:
            texts: The list of text to embed.

        Returns:
            List of embeddings, one for each text.
        """
        return [self.embed_query(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        """Call out to Vicuna's server embedding endpoint for embedding query text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        return self._call(text)
|