mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-27 08:58:48 +00:00
Harrison/tfidf parameters (#3481)
Co-authored-by: pao <go5kuramubon@gmail.com> Co-authored-by: KyoHattori <kyo.hattori@abejainc.com>
This commit is contained in:
parent
eda69b13f3
commit
7257f9e015
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
Largely based on
|
Largely based on
|
||||||
https://github.com/asvskartheek/Text-Retrieval/blob/master/TF-IDF%20Search%20Engine%20(SKLEARN).ipynb"""
|
https://github.com/asvskartheek/Text-Retrieval/blob/master/TF-IDF%20Search%20Engine%20(SKLEARN).ipynb"""
|
||||||
from typing import Any, List
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@ -21,10 +21,16 @@ class TFIDFRetriever(BaseRetriever, BaseModel):
|
|||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_texts(cls, texts: List[str], **kwargs: Any) -> "TFIDFRetriever":
|
def from_texts(
|
||||||
|
cls,
|
||||||
|
texts: List[str],
|
||||||
|
tfidf_params: Optional[Dict[str, Any]] = None,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> "TFIDFRetriever":
|
||||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||||
|
|
||||||
vectorizer = TfidfVectorizer()
|
tfidf_params = tfidf_params or {}
|
||||||
|
vectorizer = TfidfVectorizer(**tfidf_params)
|
||||||
tfidf_array = vectorizer.fit_transform(texts)
|
tfidf_array = vectorizer.fit_transform(texts)
|
||||||
docs = [Document(page_content=t) for t in texts]
|
docs = [Document(page_content=t) for t in texts]
|
||||||
return cls(vectorizer=vectorizer, docs=docs, tfidf_array=tfidf_array, **kwargs)
|
return cls(vectorizer=vectorizer, docs=docs, tfidf_array=tfidf_array, **kwargs)
|
||||||
|
17
tests/integration_tests/retrievers/test_tfidf.py
Normal file
17
tests/integration_tests/retrievers/test_tfidf.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
from langchain.retrievers.tfidf import TFIDFRetriever
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_texts() -> None:
|
||||||
|
input_texts = ["I have a pen.", "Do you have a pen?", "I have a bag."]
|
||||||
|
tfidf_retriever = TFIDFRetriever.from_texts(texts=input_texts)
|
||||||
|
assert len(tfidf_retriever.docs) == 3
|
||||||
|
assert tfidf_retriever.tfidf_array.toarray().shape == (3, 5)
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_texts_with_tfidf_params() -> None:
|
||||||
|
input_texts = ["I have a pen.", "Do you have a pen?", "I have a bag."]
|
||||||
|
tfidf_retriever = TFIDFRetriever.from_texts(
|
||||||
|
texts=input_texts, tfidf_params={"min_df": 2}
|
||||||
|
)
|
||||||
|
# should count only multiple words (have, pan)
|
||||||
|
assert tfidf_retriever.tfidf_array.toarray().shape == (3, 2)
|
Loading…
Reference in New Issue
Block a user