community: implement _select_relevance_score_fn for tencent vectordb (#28036)

implement _select_relevance_score_fn for tencent vectordb
fix use external embedding for tencent vectordb

Co-authored-by: wlleiiwang <wlleiiwang@tencent.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
wlleiiwang
2024-12-04 11:03:00 +08:00
committed by GitHub
parent d34bf78f3b
commit 6151ea78d5

View File

@@ -6,7 +6,18 @@ import json
import logging import logging
import time import time
from enum import Enum from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union, cast from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Tuple,
Union,
cast,
)
import numpy as np import numpy as np
from langchain_core.documents import Document from langchain_core.documents import Document
@@ -168,8 +179,8 @@ class TencentVectorDB(VectorStore):
tcvectordb = guard_import("tcvectordb") tcvectordb = guard_import("tcvectordb")
tcollection = guard_import("tcvectordb.model.collection") tcollection = guard_import("tcvectordb.model.collection")
enum = guard_import("tcvectordb.model.enum") enum = guard_import("tcvectordb.model.enum")
self.embedding_model = None
if t_vdb_embedding: if embedding is None and t_vdb_embedding:
embedding_model = [ embedding_model = [
model model
for model in enum.EmbeddingModel for model in enum.EmbeddingModel
@@ -566,3 +577,17 @@ class TencentVectorDB(VectorStore):
) )
# Reorder the values and return. # Reorder the values and return.
return [documents[x] for x in new_ordering if x != -1] return [documents[x] for x in new_ordering if x != -1]
def _select_relevance_score_fn(self) -> Callable[[float], float]:
metric_type = self.index_params.metric_type
if metric_type == "COSINE":
return self._cosine_relevance_score_fn
elif metric_type == "L2":
return self._euclidean_relevance_score_fn
elif metric_type == "IP":
return self._max_inner_product_relevance_score_fn
else:
raise ValueError(
"No supported normalization function"
f" for distance metric of type: {metric_type}."
)