fix: Fix retrieve error

This commit is contained in:
Fangyin Cheng 2025-03-17 15:37:02 +08:00
parent b2dd66dc6d
commit 9f2e747698
13 changed files with 22 additions and 26 deletions

View File

@ -26,11 +26,11 @@ rerank_top_k=3
[rag.storage] [rag.storage]
[rag.storage.vector] [rag.storage.vector]
type = "Chroma" type = "chroma"
persist_path = "pilot/data" persist_path = "pilot/data"
[rag.storage.full_text] [rag.storage.full_text]
type = "ElasticSearch" type = "elasticsearch"
host="127.0.0.1" host="127.0.0.1"
port=9200 port=9200

View File

@ -147,7 +147,7 @@ class RegisterParameters(abc.ABC, metaclass=PolymorphicMeta):
"""Register a subclass with this base class using a type value.""" """Register a subclass with this base class using a type value."""
if not hasattr(cls, "_type_registry"): if not hasattr(cls, "_type_registry"):
cls._type_registry = {} cls._type_registry = {}
cls._type_registry[type_value] = subclass cls._type_registry[type_value.lower()] = subclass
@classmethod @classmethod
def get_type_value(cls) -> str: def get_type_value(cls) -> str:
@ -238,7 +238,7 @@ def _get_concrete_class(base_class: Type[T], data: Dict[str, Any]) -> Type[T]:
if not type_value: if not type_value:
return base_class return base_class
type_value = _resolve_env_vars(type_value) type_value = _resolve_env_vars(type_value)
real_cls = base_class.get_subclass(type_value) real_cls = base_class.get_subclass(type_value.lower())
if not real_cls: if not real_cls:
raise ValueError( raise ValueError(
f"Unknown type value: {type_value}, known types: " f"Unknown type value: {type_value}, known types: "

View File

@ -1,17 +1,18 @@
"""Neo4j store.""" """Neo4j store."""
import logging import logging
from dataclasses import dataclass
from dbgpt._private.pydantic import ConfigDict
from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@dataclass
class Neo4jStoreConfig(GraphStoreConfig): class Neo4jStoreConfig(GraphStoreConfig):
"""Neo4j store config.""" """Neo4j store config."""
model_config = ConfigDict(arbitrary_types_allowed=True) __type__ = "neo4j"
class Neo4jStore(GraphStoreBase): class Neo4jStore(GraphStoreBase):

View File

@ -18,7 +18,7 @@ logger = logging.getLogger(__name__)
class TuGraphStoreConfig(GraphStoreConfig): class TuGraphStoreConfig(GraphStoreConfig):
"""TuGraph store config.""" """TuGraph store config."""
__type__ = "TuGraph" __type__ = "tugraph"
host: str = field( host: str = field(
default="127.0.0.1", default="127.0.0.1",

View File

@ -103,7 +103,7 @@ GRAPH_PARAMETERS = [
class BuiltinKnowledgeGraphConfig(KnowledgeGraphConfig): class BuiltinKnowledgeGraphConfig(KnowledgeGraphConfig):
"""Builtin knowledge graph config.""" """Builtin knowledge graph config."""
__type__ = "TuGraph" __type__ = "tugraph"
llm_model: Optional[str] = field( llm_model: Optional[str] = field(
default=None, metadata={"description": "llm model name."} default=None, metadata={"description": "llm model name."}

View File

@ -1,17 +1,18 @@
"""OpenSPG class.""" """OpenSPG class."""
import logging import logging
from dataclasses import dataclass
from dbgpt._private.pydantic import ConfigDict
from dbgpt.storage.knowledge_graph.base import KnowledgeGraphBase, KnowledgeGraphConfig from dbgpt.storage.knowledge_graph.base import KnowledgeGraphBase, KnowledgeGraphConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@dataclass
class OpenSPGConfig(KnowledgeGraphConfig): class OpenSPGConfig(KnowledgeGraphConfig):
"""OpenSPG config.""" """OpenSPG config."""
model_config = ConfigDict(arbitrary_types_allowed=True) __type__ = "openspg"
class OpenSPG(KnowledgeGraphBase): class OpenSPG(KnowledgeGraphBase):

View File

@ -40,7 +40,7 @@ logger = logging.getLogger(__name__)
class ChromaVectorConfig(VectorStoreConfig): class ChromaVectorConfig(VectorStoreConfig):
"""Chroma vector store config.""" """Chroma vector store config."""
__type__ = "Chroma" __type__ = "chroma"
persist_path: Optional[str] = field( persist_path: Optional[str] = field(
default=os.getenv("CHROMA_PERSIST_PATH", None), default=os.getenv("CHROMA_PERSIST_PATH", None),
@ -249,16 +249,6 @@ class ChromaStore(VectorStoreBase):
# Delete collection if it exists # Delete collection if it exists
self._chroma_client.delete_collection(self._collection.name) self._chroma_client.delete_collection(self._collection.name)
SharedSystemClient.clear_system_cache() SharedSystemClient.clear_system_cache()
# Clean persist folder if it exists
if os.path.exists(self.persist_dir):
try:
self._clean_persist_folder()
except Exception as e:
logger.error(f"Failed to clean persist folder: {e}")
# Even if folder cleanup fails, collection deletion succeeded
return True
return True return True
except Exception as e: except Exception as e:

View File

@ -77,7 +77,7 @@ logger = logging.getLogger(__name__)
class ElasticsearchStoreConfig(VectorStoreConfig): class ElasticsearchStoreConfig(VectorStoreConfig):
"""Elasticsearch vector store config.""" """Elasticsearch vector store config."""
__type__ = "ElasticSearch" __type__ = "elasticsearch"
uri: str = field( uri: str = field(
default="localhost", default="localhost",

View File

@ -98,7 +98,7 @@ logger = logging.getLogger(__name__)
class MilvusVectorConfig(VectorStoreConfig): class MilvusVectorConfig(VectorStoreConfig):
"""Milvus vector store config.""" """Milvus vector store config."""
__type__ = "Milvus" __type__ = "milvus"
uri: str = field( uri: str = field(
default=None, default=None,

View File

@ -126,7 +126,7 @@ def _normalize(vector: List[float]) -> List[float]:
class OceanBaseConfig(VectorStoreConfig): class OceanBaseConfig(VectorStoreConfig):
"""OceanBase vector store config.""" """OceanBase vector store config."""
__type__ = "OceanBase" __type__ = "oceanbase"
ob_host: Optional[str] = field( ob_host: Optional[str] = field(
default=None, default=None,

View File

@ -41,6 +41,8 @@ logger = logging.getLogger(__name__)
class PGVectorConfig(VectorStoreConfig): class PGVectorConfig(VectorStoreConfig):
"""PG vector store config.""" """PG vector store config."""
__type__ = "pgvector"
connection_string: str = field( connection_string: str = field(
default=None, default=None,
metadata={ metadata={
@ -50,7 +52,7 @@ class PGVectorConfig(VectorStoreConfig):
) )
def create_store(self, **kwargs) -> "PGVectorStore": def create_store(self, **kwargs) -> "PGVectorStore":
"""Create Milvus Store.""" """Create a PGVectorStore instance."""
return PGVectorStore(vector_store_config=self, **kwargs) return PGVectorStore(vector_store_config=self, **kwargs)

View File

@ -49,6 +49,8 @@ logger = logging.getLogger(__name__)
class WeaviateVectorConfig(VectorStoreConfig): class WeaviateVectorConfig(VectorStoreConfig):
"""Weaviate vector store config.""" """Weaviate vector store config."""
__type__ = "weaviate"
weaviate_url: str = field( weaviate_url: str = field(
default=os.getenv("WEAVIATE_URL", None), default=os.getenv("WEAVIATE_URL", None),
metadata={ metadata={

View File

@ -39,7 +39,7 @@ class StorageManager(BaseComponent):
"""Get storage connector.""" """Get storage connector."""
supported_vector_types = self.get_vector_supported_types supported_vector_types = self.get_vector_supported_types
storage_config = self.storage_config() storage_config = self.storage_config()
if storage_type in supported_vector_types: if storage_type.lower() in supported_vector_types:
return self.create_vector_store(index_name) return self.create_vector_store(index_name)
elif storage_type == "KnowledgeGraph": elif storage_type == "KnowledgeGraph":
if not storage_config.graph: if not storage_config.graph: