diff --git a/libs/community/langchain_community/vectorstores/tencentvectordb.py b/libs/community/langchain_community/vectorstores/tencentvectordb.py index c3bda890fe2..7d408c21132 100644 --- a/libs/community/langchain_community/vectorstores/tencentvectordb.py +++ b/libs/community/langchain_community/vectorstores/tencentvectordb.py @@ -6,7 +6,18 @@ import json import logging import time from enum import Enum -from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union, cast +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Sequence, + Tuple, + Union, + cast, +) import numpy as np from langchain_core.documents import Document @@ -168,8 +179,8 @@ class TencentVectorDB(VectorStore): tcvectordb = guard_import("tcvectordb") tcollection = guard_import("tcvectordb.model.collection") enum = guard_import("tcvectordb.model.enum") - - if t_vdb_embedding: + self.embedding_model = None + if embedding is None and t_vdb_embedding: embedding_model = [ model for model in enum.EmbeddingModel @@ -566,3 +577,17 @@ class TencentVectorDB(VectorStore): ) # Reorder the values and return. return [documents[x] for x in new_ordering if x != -1] + + def _select_relevance_score_fn(self) -> Callable[[float], float]: + metric_type = self.index_params.metric_type + if metric_type == "COSINE": + return self._cosine_relevance_score_fn + elif metric_type == "L2": + return self._euclidean_relevance_score_fn + elif metric_type == "IP": + return self._max_inner_product_relevance_score_fn + else: + raise ValueError( + "No supported normalization function" + f" for distance metric of type: {metric_type}." + )