From 6151ea78d50e8617f2562d8eaec0dcdd29b2ee5c Mon Sep 17 00:00:00 2001 From: wlleiiwang <1025164922@qq.com> Date: Wed, 4 Dec 2024 11:03:00 +0800 Subject: [PATCH] community: implement _select_relevance_score_fn for tencent vectordb (#28036) implement _select_relevance_score_fn for tencent vectordb fix use external embedding for tencent vectordb Co-authored-by: wlleiiwang Co-authored-by: Erick Friis --- .../vectorstores/tencentvectordb.py | 31 +++++++++++++++++-- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/libs/community/langchain_community/vectorstores/tencentvectordb.py b/libs/community/langchain_community/vectorstores/tencentvectordb.py index c3bda890fe2..7d408c21132 100644 --- a/libs/community/langchain_community/vectorstores/tencentvectordb.py +++ b/libs/community/langchain_community/vectorstores/tencentvectordb.py @@ -6,7 +6,18 @@ import json import logging import time from enum import Enum -from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union, cast +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Sequence, + Tuple, + Union, + cast, +) import numpy as np from langchain_core.documents import Document @@ -168,8 +179,8 @@ class TencentVectorDB(VectorStore): tcvectordb = guard_import("tcvectordb") tcollection = guard_import("tcvectordb.model.collection") enum = guard_import("tcvectordb.model.enum") - - if t_vdb_embedding: + self.embedding_model = None + if embedding is None and t_vdb_embedding: embedding_model = [ model for model in enum.EmbeddingModel @@ -566,3 +577,17 @@ class TencentVectorDB(VectorStore): ) # Reorder the values and return. return [documents[x] for x in new_ordering if x != -1] + + def _select_relevance_score_fn(self) -> Callable[[float], float]: + metric_type = self.index_params.metric_type + if metric_type == "COSINE": + return self._cosine_relevance_score_fn + elif metric_type == "L2": + return self._euclidean_relevance_score_fn + elif metric_type == "IP": + return self._max_inner_product_relevance_score_fn + else: + raise ValueError( + "No supported normalization function" + f" for distance metric of type: {metric_type}." + )