add save and load tfidf vectorizer and docs for TFIDFRetriever (#8112)

This is to add save_local and load_local to tfidf_vectorizer and docs in
tfidf_retriever to make the vectorizer reusable.

<!-- Thank you for contributing to LangChain!

Replace this comment with:
- Description: add save_local and load_local to tfidf_vectorizer and
docs in tfidf_retriever
  - Issue: None
  - Dependencies: None
  - Tag maintainer: @rlancemartin, @eyurtsev
  - Twitter handle: @MlopsJ

Please make sure you're PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
  2. an example notebook showing its use.

Maintainer responsibilities:
  - General / Misc / if you don't know who to tag: @baskaryan
  - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev
  - Models / Prompts: @hwchase17, @baskaryan
  - Memory: @hwchase17
  - Agents / Tools / Toolkits: @hinthornw
  - Tracing / Callbacks: @agola11
  - Async: @agola11

If no one reviews your PR within a few days, feel free to @-mention the
same people again.

See contribution guidelines for more information on how to write/run
tests, lint, etc:
https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md
 -->

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
shibuiwilliam
2023-08-04 15:06:27 +09:00
committed by GitHub
parent 0f68054401
commit 2759e2d857
3 changed files with 141 additions and 4 deletions

View File

@@ -1,5 +1,7 @@
from __future__ import annotations
import pickle
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
@@ -76,3 +78,49 @@ class TFIDFRetriever(BaseRetriever):
) # Op -- (n_docs,1) -- Cosine Sim with each doc
return_docs = [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
return return_docs
def save_local(
self,
folder_path: str,
file_name: str = "tfidf_vectorizer",
) -> None:
try:
import joblib
except ImportError:
raise ImportError(
"Could not import joblib, please install with `pip install joblib`."
)
path = Path(folder_path)
path.mkdir(exist_ok=True, parents=True)
# Save vectorizer with joblib dump.
joblib.dump(self.vectorizer, path / f"{file_name}.joblib")
# Save docs and tfidf array as pickle.
with open(path / f"{file_name}.pkl", "wb") as f:
pickle.dump((self.docs, self.tfidf_array), f)
@classmethod
def load_local(
cls,
folder_path: str,
file_name: str = "tfidf_vectorizer",
) -> TFIDFRetriever:
try:
import joblib
except ImportError:
raise ImportError(
"Could not import joblib, please install with `pip install joblib`."
)
path = Path(folder_path)
# Load vectorizer with joblib load.
vectorizer = joblib.load(path / f"{file_name}.joblib")
# Load docs and tfidf array as pickle.
with open(path / f"{file_name}.pkl", "rb") as f:
docs, tfidf_array = pickle.load(f)
return cls(vectorizer=vectorizer, docs=docs, tfidf_array=tfidf_array)

View File

@@ -1,3 +1,7 @@
import os
from datetime import datetime
from tempfile import TemporaryDirectory
import pytest
from langchain.retrievers.tfidf import TFIDFRetriever
@@ -32,3 +36,26 @@ def test_from_documents() -> None:
tfidf_retriever = TFIDFRetriever.from_documents(documents=input_docs)
assert len(tfidf_retriever.docs) == 3
assert tfidf_retriever.tfidf_array.toarray().shape == (3, 5)
@pytest.mark.requires("sklearn")
def test_save_local_load_local() -> None:
input_texts = ["I have a pen.", "Do you have a pen?", "I have a bag."]
tfidf_retriever = TFIDFRetriever.from_texts(texts=input_texts)
file_name = "tfidf_vectorizer"
temp_timestamp = datetime.utcnow().strftime("%Y%m%d-%H%M%S")
with TemporaryDirectory(suffix="_" + temp_timestamp + "/") as temp_folder:
tfidf_retriever.save_local(
folder_path=temp_folder,
file_name=file_name,
)
assert os.path.exists(os.path.join(temp_folder, f"{file_name}.joblib"))
assert os.path.exists(os.path.join(temp_folder, f"{file_name}.pkl"))
loaded_tfidf_retriever = TFIDFRetriever.load_local(
folder_path=temp_folder,
file_name=file_name,
)
assert len(loaded_tfidf_retriever.docs) == 3
assert loaded_tfidf_retriever.tfidf_array.toarray().shape == (3, 5)