fix: Fix retrieve error

This commit is contained in:
Fangyin Cheng 2025-03-17 15:37:02 +08:00
parent b2dd66dc6d
commit 9f2e747698
13 changed files with 22 additions and 26 deletions

View File

@ -26,11 +26,11 @@ rerank_top_k=3
[rag.storage]
[rag.storage.vector]
type = "Chroma"
type = "chroma"
persist_path = "pilot/data"
[rag.storage.full_text]
type = "ElasticSearch"
type = "elasticsearch"
host="127.0.0.1"
port=9200

View File

@ -147,7 +147,7 @@ class RegisterParameters(abc.ABC, metaclass=PolymorphicMeta):
"""Register a subclass with this base class using a type value."""
if not hasattr(cls, "_type_registry"):
cls._type_registry = {}
cls._type_registry[type_value] = subclass
cls._type_registry[type_value.lower()] = subclass
@classmethod
def get_type_value(cls) -> str:
@ -238,7 +238,7 @@ def _get_concrete_class(base_class: Type[T], data: Dict[str, Any]) -> Type[T]:
if not type_value:
return base_class
type_value = _resolve_env_vars(type_value)
real_cls = base_class.get_subclass(type_value)
real_cls = base_class.get_subclass(type_value.lower())
if not real_cls:
raise ValueError(
f"Unknown type value: {type_value}, known types: "

View File

@ -1,17 +1,18 @@
"""Neo4j store."""
import logging
from dataclasses import dataclass
from dbgpt._private.pydantic import ConfigDict
from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
logger = logging.getLogger(__name__)
@dataclass
class Neo4jStoreConfig(GraphStoreConfig):
"""Neo4j store config."""
model_config = ConfigDict(arbitrary_types_allowed=True)
__type__ = "neo4j"
class Neo4jStore(GraphStoreBase):

View File

@ -18,7 +18,7 @@ logger = logging.getLogger(__name__)
class TuGraphStoreConfig(GraphStoreConfig):
"""TuGraph store config."""
__type__ = "TuGraph"
__type__ = "tugraph"
host: str = field(
default="127.0.0.1",

View File

@ -103,7 +103,7 @@ GRAPH_PARAMETERS = [
class BuiltinKnowledgeGraphConfig(KnowledgeGraphConfig):
"""Builtin knowledge graph config."""
__type__ = "TuGraph"
__type__ = "tugraph"
llm_model: Optional[str] = field(
default=None, metadata={"description": "llm model name."}

View File

@ -1,17 +1,18 @@
"""OpenSPG class."""
import logging
from dataclasses import dataclass
from dbgpt._private.pydantic import ConfigDict
from dbgpt.storage.knowledge_graph.base import KnowledgeGraphBase, KnowledgeGraphConfig
logger = logging.getLogger(__name__)
@dataclass
class OpenSPGConfig(KnowledgeGraphConfig):
"""OpenSPG config."""
model_config = ConfigDict(arbitrary_types_allowed=True)
__type__ = "openspg"
class OpenSPG(KnowledgeGraphBase):

View File

@ -40,7 +40,7 @@ logger = logging.getLogger(__name__)
class ChromaVectorConfig(VectorStoreConfig):
"""Chroma vector store config."""
__type__ = "Chroma"
__type__ = "chroma"
persist_path: Optional[str] = field(
default=os.getenv("CHROMA_PERSIST_PATH", None),
@ -249,16 +249,6 @@ class ChromaStore(VectorStoreBase):
# Delete collection if it exists
self._chroma_client.delete_collection(self._collection.name)
SharedSystemClient.clear_system_cache()
# Clean persist folder if it exists
if os.path.exists(self.persist_dir):
try:
self._clean_persist_folder()
except Exception as e:
logger.error(f"Failed to clean persist folder: {e}")
# Even if folder cleanup fails, collection deletion succeeded
return True
return True
except Exception as e:

View File

@ -77,7 +77,7 @@ logger = logging.getLogger(__name__)
class ElasticsearchStoreConfig(VectorStoreConfig):
"""Elasticsearch vector store config."""
__type__ = "ElasticSearch"
__type__ = "elasticsearch"
uri: str = field(
default="localhost",

View File

@ -98,7 +98,7 @@ logger = logging.getLogger(__name__)
class MilvusVectorConfig(VectorStoreConfig):
"""Milvus vector store config."""
__type__ = "Milvus"
__type__ = "milvus"
uri: str = field(
default=None,

View File

@ -126,7 +126,7 @@ def _normalize(vector: List[float]) -> List[float]:
class OceanBaseConfig(VectorStoreConfig):
"""OceanBase vector store config."""
__type__ = "OceanBase"
__type__ = "oceanbase"
ob_host: Optional[str] = field(
default=None,

View File

@ -41,6 +41,8 @@ logger = logging.getLogger(__name__)
class PGVectorConfig(VectorStoreConfig):
"""PG vector store config."""
__type__ = "pgvector"
connection_string: str = field(
default=None,
metadata={
@ -50,7 +52,7 @@ class PGVectorConfig(VectorStoreConfig):
)
def create_store(self, **kwargs) -> "PGVectorStore":
"""Create Milvus Store."""
"""Create a PGVectorStore instance."""
return PGVectorStore(vector_store_config=self, **kwargs)

View File

@ -49,6 +49,8 @@ logger = logging.getLogger(__name__)
class WeaviateVectorConfig(VectorStoreConfig):
"""Weaviate vector store config."""
__type__ = "weaviate"
weaviate_url: str = field(
default=os.getenv("WEAVIATE_URL", None),
metadata={

View File

@ -39,7 +39,7 @@ class StorageManager(BaseComponent):
"""Get storage connector."""
supported_vector_types = self.get_vector_supported_types
storage_config = self.storage_config()
if storage_type in supported_vector_types:
if storage_type.lower() in supported_vector_types:
return self.create_vector_store(index_name)
elif storage_type == "KnowledgeGraph":
if not storage_config.graph: