Unified configuration

This commit is contained in:
yhjun1026
2023-05-18 16:30:57 +08:00
parent a68e164a5f
commit ba7e23d37f
13 changed files with 97 additions and 53 deletions

View File

@@ -8,8 +8,9 @@ from langchain.embeddings.base import Embeddings
from pydantic import BaseModel
from typing import Any, Mapping, Optional, List
from langchain.llms.base import LLM
from pilot.configs.model_config import *
from pilot.configs.config import Config
CFG = Config()
class VicunaLLM(LLM):
vicuna_generate_path = "generate_stream"
@@ -22,7 +23,7 @@ class VicunaLLM(LLM):
"stop": stop
}
response = requests.post(
url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_generate_path),
url=urljoin(CFG.MODEL_SERVER, self.vicuna_generate_path),
data=json.dumps(params),
)
@@ -51,7 +52,7 @@ class VicunaEmbeddingLLM(BaseModel, Embeddings):
print("Sending prompt ", p)
response = requests.post(
url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_embedding_path),
url=urljoin(CFG.MODEL_SERVER, self.vicuna_embedding_path),
json={
"prompt": p
}