Mirror of https://github.com/csunny/DB-GPT.git, synced 2025-08-11 21:22:28 +00:00
chore(model): Modify siliconflow name (#2194)
Commit dab55493f5 (parent 4da1809b31)
@@ -76,6 +76,8 @@ KNOWLEDGE_GRAPH_SEARCH_TOP_SIZE=200
 ## you can set this value to a higher value for better performance.
 ## if out of memory when load large document, you can set this value to a lower value.
 # KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD=10
 ## Maximum number of threads to use when loading chunks, please make sure your vector db can support multi-threading.
 # KNOWLEDGE_MAX_THREADS=1
 #KNOWLEDGE_CHUNK_OVERLAP=50
 # Control whether to display the source document of knowledge on the front end.
 KNOWLEDGE_CHAT_SHOW_RELATIONS=False
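Not part of the diff: a minimal sketch of overriding these knowledge-loading limits from Python instead of the template. The values are illustrative, and KNOWLEDGE_MAX_THREADS should stay within what your vector DB supports.

import os

# Same variables as the template entries above; set them before DB-GPT starts.
os.environ["KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD"] = "10"
os.environ["KNOWLEDGE_MAX_THREADS"] = "2"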
@@ -140,16 +140,16 @@ class Config(metaclass=Singleton):
             os.environ["claude_proxyllm_api_base"] = os.getenv(
                 "ANTHROPIC_BASE_URL", "https://api.anthropic.com"
             )
-        self.silicon_flow_proxy_api_key = os.getenv("SILICON_FLOW_API_KEY")
-        if self.silicon_flow_proxy_api_key:
+        self.siliconflow_proxy_api_key = os.getenv("SILICONFLOW_API_KEY")
+        if self.siliconflow_proxy_api_key:
             os.environ[
-                "silicon_flow_proxyllm_proxy_api_key"
-            ] = self.silicon_flow_proxy_api_key
-            os.environ["silicon_flow_proxyllm_proxyllm_backend"] = os.getenv(
-                "SILICON_FLOW_MODEL_VERSION", "Qwen/Qwen2.5-Coder-32B-Instruct"
+                "siliconflow_proxyllm_proxy_api_key"
+            ] = self.siliconflow_proxy_api_key
+            os.environ["siliconflow_proxyllm_proxyllm_backend"] = os.getenv(
+                "SILICONFLOW_MODEL_VERSION", "Qwen/Qwen2.5-Coder-32B-Instruct"
             )
-            os.environ["silicon_flow_proxyllm_api_base"] = os.getenv(
-                "SILICON_FLOW_API_BASE", "https://api.siliconflow.cn/v1"
+            os.environ["siliconflow_proxyllm_api_base"] = os.getenv(
+                "SILICONFLOW_API_BASE", "https://api.siliconflow.cn/v1"
             )

         self.proxy_server_url = os.getenv("PROXY_SERVER_URL")
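For existing deployments, a minimal sketch of the renamed environment variables read by Config above. Values are placeholders, and LLM_MODEL as the model selector is an assumption that is not shown in this diff.

import os

os.environ["SILICONFLOW_API_KEY"] = "sk-your-key"          # placeholder
os.environ["SILICONFLOW_API_BASE"] = "https://api.siliconflow.cn/v1"
os.environ["SILICONFLOW_MODEL_VERSION"] = "Qwen/Qwen2.5-Coder-32B-Instruct"
os.environ["LLM_MODEL"] = "siliconflow_proxyllm"           # assumed selector variable, not part of this diff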
@@ -277,6 +277,7 @@ class Config(metaclass=Singleton):
         self.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD = int(
             os.getenv("KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD", 10)
         )
         self.KNOWLEDGE_MAX_THREADS = int(os.getenv("KNOWLEDGE_MAX_THREADS", 1))
         # default recall similarity score, between 0 and 1
         self.KNOWLEDGE_SEARCH_RECALL_SCORE = float(
             os.getenv("KNOWLEDGE_SEARCH_RECALL_SCORE", 0.3)
@@ -80,7 +80,7 @@ LLM_MODEL_CONFIG = {
     # https://platform.deepseek.com/api-docs/
     "deepseek_proxyllm": "deepseek_proxyllm",
     # https://docs.siliconflow.cn/quickstart
-    "silicon_flow_proxyllm": "silicon_flow_proxyllm",
+    "siliconflow_proxyllm": "siliconflow_proxyllm",
     "llama-2-7b": os.path.join(MODEL_PATH, "Llama-2-7b-chat-hf"),
     "llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"),
     "llama-2-70b": os.path.join(MODEL_PATH, "Llama-2-70b-chat-hf"),
@@ -322,7 +322,7 @@ EMBEDDING_MODEL_CONFIG = {
     "bge-reranker-large": os.path.join(MODEL_PATH, "bge-reranker-large"),
     # Proxy rerank model
     "rerank_proxy_http_openapi": "rerank_proxy_http_openapi",
-    "rerank_proxy_silicon_flow": "rerank_proxy_silicon_flow",
+    "rerank_proxy_siliconflow": "rerank_proxy_siliconflow",
 }


@@ -109,7 +109,7 @@ class EmbeddingLoader:
             if proxy_param.proxy_backend:
                 openapi_param["model_name"] = proxy_param.proxy_backend
             return OpenAPIRerankEmbeddings(**openapi_param)
-        elif model_name in ["rerank_proxy_silicon_flow"]:
+        elif model_name in ["rerank_proxy_siliconflow"]:
             from dbgpt.rag.embedding.rerank import SiliconFlowRerankEmbeddings

             proxy_param = cast(ProxyEmbeddingParameters, param)
@@ -349,7 +349,7 @@ class SiliconFlowProxyLLMModelAdapter(ProxyLLMModelAdapter):
         return True

     def do_match(self, lower_model_name_or_path: Optional[str] = None):
-        return lower_model_name_or_path == "silicon_flow_proxyllm"
+        return lower_model_name_or_path == "siliconflow_proxyllm"

     def get_llm_client_class(
         self, params: ProxyModelParameters
@@ -359,9 +359,9 @@ class SiliconFlowProxyLLMModelAdapter(ProxyLLMModelAdapter):
         return SiliconFlowLLMClient

     def get_async_generate_stream_function(self, model, model_path: str):
-        from dbgpt.model.proxy.llms.siliconflow import silicon_flow_generate_stream
+        from dbgpt.model.proxy.llms.siliconflow import siliconflow_generate_stream

-        return silicon_flow_generate_stream
+        return siliconflow_generate_stream


 register_model_adapter(OpenAIProxyLLMModelAdapter)
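A tiny sketch of the behavioural change in do_match, assuming SiliconFlowProxyLLMModelAdapter is imported from its adapter module (import path omitted here):

adapter = SiliconFlowProxyLLMModelAdapter()

assert adapter.do_match("siliconflow_proxyllm")        # new name matches
assert not adapter.do_match("silicon_flow_proxyllm")   # old name no longer matches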
@@ -621,7 +621,7 @@ _EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
         "proxy_tongyi",
         "proxy_qianfan",
         "rerank_proxy_http_openapi",
-        "rerank_proxy_silicon_flow",
+        "rerank_proxy_siliconflow",
     ]
 }

@@ -12,10 +12,10 @@ if TYPE_CHECKING:
     ClientType = Union[AsyncAzureOpenAI, AsyncOpenAI]


-_SILICON_FLOW_DEFAULT_MODEL = "Qwen/Qwen2.5-Coder-32B-Instruct"
+_SILICONFLOW_DEFAULT_MODEL = "Qwen/Qwen2.5-Coder-32B-Instruct"


-async def silicon_flow_generate_stream(
+async def siliconflow_generate_stream(
     model: ProxyModel, tokenizer, params, device, context_len=2048
 ):
     client: SiliconFlowLLMClient = model.proxy_llm_client
@@ -39,7 +39,7 @@ class SiliconFlowLLMClient(OpenAILLMClient):
         model: Optional[str] = None,
         proxies: Optional["ProxiesTypes"] = None,
         timeout: Optional[int] = 240,
-        model_alias: Optional[str] = "silicon_flow_proxyllm",
+        model_alias: Optional[str] = "siliconflow_proxyllm",
         context_length: Optional[int] = None,
         openai_client: Optional["ClientType"] = None,
         openai_kwargs: Optional[Dict[str, Any]] = None,
@@ -47,11 +47,11 @@ class SiliconFlowLLMClient(OpenAILLMClient):
     ):
         api_base = (
             api_base
-            or os.getenv("SILICON_FLOW_API_BASE")
+            or os.getenv("SILICONFLOW_API_BASE")
             or "https://api.siliconflow.cn/v1"
         )
-        api_key = api_key or os.getenv("SILICON_FLOW_API_KEY")
-        model = model or _SILICON_FLOW_DEFAULT_MODEL
+        api_key = api_key or os.getenv("SILICONFLOW_API_KEY")
+        model = model or _SILICONFLOW_DEFAULT_MODEL
         if not context_length:
             if "200k" in model:
                 context_length = 200 * 1024
@@ -60,7 +60,7 @@ class SiliconFlowLLMClient(OpenAILLMClient):

         if not api_key:
             raise ValueError(
-                "SiliconFlow API key is required, please set 'SILICON_FLOW_API_KEY' in environment "
+                "SiliconFlow API key is required, please set 'SILICONFLOW_API_KEY' in environment "
                 "or pass it as an argument."
             )

@@ -83,5 +83,5 @@ class SiliconFlowLLMClient(OpenAILLMClient):
     def default_model(self) -> str:
         model = self._model
         if not model:
-            model = _SILICON_FLOW_DEFAULT_MODEL
+            model = _SILICONFLOW_DEFAULT_MODEL
         return model
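A minimal client sketch under the renamed defaults; it assumes SILICONFLOW_API_KEY is exported (otherwise the ValueError above is raised), while api_base and model fall back to the defaults shown in __init__:

import os

from dbgpt.model.proxy.llms.siliconflow import SiliconFlowLLMClient

client = SiliconFlowLLMClient(
    api_key=os.getenv("SILICONFLOW_API_KEY"),
    model="Qwen/Qwen2.5-Coder-32B-Instruct",  # optional; _SILICONFLOW_DEFAULT_MODEL otherwise
)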
@@ -1,4 +1,5 @@
 """Embedding Assembler."""

 from concurrent.futures import ThreadPoolExecutor
 from typing import Any, List, Optional

@@ -130,7 +131,11 @@ class EmbeddingAssembler(BaseAssembler):
         Returns:
             List[str]: List of chunk ids.
         """
-        return self._index_store.load_document(self._chunks)
+        max_chunks_once_load = kwargs.get("max_chunks_once_load", 10)
+        max_threads = kwargs.get("max_threads", 1)
+        return self._index_store.load_document_with_limit(
+            self._chunks, max_chunks_once_load, max_threads
+        )

     async def apersist(self, **kwargs) -> List[str]:
         """Persist chunks into store.
@@ -139,7 +144,11 @@ class EmbeddingAssembler(BaseAssembler):
             List[str]: List of chunk ids.
         """
         # persist chunks into vector store
-        return await self._index_store.aload_document(self._chunks)
+        max_chunks_once_load = kwargs.get("max_chunks_once_load", 10)
+        max_threads = kwargs.get("max_threads", 1)
+        return await self._index_store.aload_document_with_limit(
+            self._chunks, max_chunks_once_load, max_threads
+        )

     def _extract_info(self, chunks) -> List[Chunk]:
         """Extract info from chunks."""
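A usage sketch of the new keyword arguments, inside an async context; knowledge and index_store are assumed to be prepared elsewhere:

assembler = await EmbeddingAssembler.aload_from_knowledge(
    knowledge=knowledge,
    index_store=index_store,
)
# The limits are now forwarded to aload_document_with_limit instead of aload_document.
chunk_ids = await assembler.apersist(max_chunks_once_load=10, max_threads=2)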
@@ -175,10 +175,10 @@ class SiliconFlowRerankEmbeddings(OpenAPIRerankEmbeddings):
         """Initialize the SiliconFlowRerankEmbeddings."""
         # If the API key is not provided, try to get it from the environment
         if "api_key" not in kwargs:
-            kwargs["api_key"] = os.getenv("SILICON_FLOW_API_KEY")
+            kwargs["api_key"] = os.getenv("SILICONFLOW_API_KEY")

         if "api_url" not in kwargs:
-            env_api_url = os.getenv("SILICON_FLOW_API_BASE")
+            env_api_url = os.getenv("SILICONFLOW_API_BASE")
             if env_api_url:
                 env_api_url = env_api_url.rstrip("/")
                 kwargs["api_url"] = env_api_url + "/rerank"
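A short sketch of the environment fallback, assuming only the renamed variables are set; with no api_key or api_url kwargs, the constructor above derives the rerank endpoint from SILICONFLOW_API_BASE.rstrip("/") + "/rerank":

import os

from dbgpt.rag.embedding.rerank import SiliconFlowRerankEmbeddings

os.environ["SILICONFLOW_API_KEY"] = "sk-your-key"    # placeholder
os.environ["SILICONFLOW_API_BASE"] = "https://api.siliconflow.cn/v1"

rerank = SiliconFlowRerankEmbeddings()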
@@ -506,6 +506,7 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
             name=space.name,
             embedding_fn=embedding_fn,
             max_chunks_once_load=CFG.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD,
             max_threads=CFG.KNOWLEDGE_MAX_THREADS,
             llm_client=self.llm_client,
             model_name=None,
         )
@@ -567,6 +568,10 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
             doc.chunk_size = len(chunk_docs)
             vector_ids = [chunk.chunk_id for chunk in chunk_docs]
         else:
+            max_chunks_once_load = (
+                vector_store_connector._index_store_config.max_chunks_once_load
+            )
+            max_threads = vector_store_connector._index_store_config.max_threads
             assembler = await EmbeddingAssembler.aload_from_knowledge(
                 knowledge=knowledge,
                 index_store=vector_store_connector.index_client,
@@ -575,7 +580,10 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes

             chunk_docs = assembler.get_chunks()
             doc.chunk_size = len(chunk_docs)
-            vector_ids = await assembler.apersist()
+            vector_ids = await assembler.apersist(
+                max_chunks_once_load=max_chunks_once_load,
+                max_threads=max_threads,
+            )
         doc.status = SyncStatus.FINISHED.name
         doc.result = "document persist into index store success"
         if vector_ids is not None: