diff --git a/libs/community/extended_testing_deps.txt b/libs/community/extended_testing_deps.txt
index eb3ad5b0d02..198e8954f82 100644
--- a/libs/community/extended_testing_deps.txt
+++ b/libs/community/extended_testing_deps.txt
@@ -38,6 +38,7 @@ javelin-sdk>=0.1.8,<0.2
 jinja2>=3,<4
 jq>=1.4.1,<2
 jsonschema>1
+keybert>=0.8.5
 lxml>=4.9.3,<6.0
 markdownify>=0.11.6,<0.12
 motor>=3.3.1,<4
diff --git a/libs/community/langchain_community/graph_vectorstores/extractors/__init__.py b/libs/community/langchain_community/graph_vectorstores/extractors/__init__.py
index 4f7d4036f97..a78eb3807be 100644
--- a/libs/community/langchain_community/graph_vectorstores/extractors/__init__.py
+++ b/libs/community/langchain_community/graph_vectorstores/extractors/__init__.py
@@ -10,6 +10,10 @@ from langchain_community.graph_vectorstores.extractors.html_link_extractor impor
     HtmlInput,
     HtmlLinkExtractor,
 )
+from langchain_community.graph_vectorstores.extractors.keybert_link_extractor import (
+    KeybertInput,
+    KeybertLinkExtractor,
+)
 from langchain_community.graph_vectorstores.extractors.link_extractor import (
     LinkExtractor,
 )
@@ -24,6 +28,8 @@ __all__ = [
     "HierarchyLinkExtractor",
     "HtmlInput",
     "HtmlLinkExtractor",
+    "KeybertInput",
+    "KeybertLinkExtractor",
     "LinkExtractor",
     "LinkExtractorAdapter",
 ]
diff --git a/libs/community/langchain_community/graph_vectorstores/extractors/keybert_link_extractor.py b/libs/community/langchain_community/graph_vectorstores/extractors/keybert_link_extractor.py
new file mode 100644
index 00000000000..aee7898a5c1
--- /dev/null
+++ b/libs/community/langchain_community/graph_vectorstores/extractors/keybert_link_extractor.py
@@ -0,0 +1,79 @@
+from typing import Any, Dict, Iterable, Optional, Set, Union
+
+from langchain_core.documents import Document
+from langchain_core.graph_vectorstores.links import Link
+
+from langchain_community.graph_vectorstores.extractors.link_extractor import (
+    LinkExtractor,
+)
+
+KeybertInput = Union[str, Document]
+
+
+class KeybertLinkExtractor(LinkExtractor[KeybertInput]):
+    def __init__(
+        self,
+        *,
+        kind: str = "kw",
+        embedding_model: str = "all-MiniLM-L6-v2",
+        extract_keywords_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        """Extract keywords using KeyBERT.
+
+        Example:
+
+            .. code-block:: python
+
+                extractor = KeybertLinkExtractor()
+
+                results = extractor.extract_one(PAGE_1)
+
+        Args:
+            kind: Kind of links to produce with this extractor.
+            embedding_model: Name of the embedding model to use with KeyBERT.
+            extract_keywords_kwargs: Keyword arguments to pass to KeyBERT's
+                `extract_keywords` method.
+
+        Raises:
+            ImportError: If the `keybert` package is not installed.
+        """
+        try:
+            import keybert
+        except ImportError:
+            raise ImportError(
+                "keybert is required for KeybertLinkExtractor. "
+                "Please install it with `pip install keybert`."
+            ) from None
+
+        # Build the model outside the `try` so that errors raised while loading
+        # it are not misreported as "keybert is not installed".
+        self._kw_model = keybert.KeyBERT(model=embedding_model)
+        self._kind = kind
+        self._extract_keywords_kwargs = extract_keywords_kwargs or {}
+
+    def extract_one(self, input: KeybertInput) -> Set[Link]:  # noqa: A002
+        """Return one bidirectional link per keyword extracted from `input`."""
+        keywords = self._kw_model.extract_keywords(
+            input if isinstance(input, str) else input.page_content,
+            **self._extract_keywords_kwargs,
+        )
+        return {Link.bidir(kind=self._kind, tag=kw[0]) for kw in keywords}
+
+    def extract_many(
+        self,
+        inputs: Iterable[KeybertInput],
+    ) -> Iterable[Set[Link]]:
+        """Yield a set of keyword links for each input; nothing if empty."""
+        inputs = list(inputs)
+        if len(inputs) == 1:
+            # Even though we pass a list, if it contains one item, keybert will
+            # flatten it. This means it's easier to just call the special case
+            # for one item.
+            yield self.extract_one(inputs[0])
+        elif len(inputs) > 1:
+            strs = [i if isinstance(i, str) else i.page_content for i in inputs]
+            extracted = self._kw_model.extract_keywords(
+                strs, **self._extract_keywords_kwargs
+            )
+            for keywords in extracted:
+                yield {Link.bidir(kind=self._kind, tag=kw[0]) for kw in keywords}
diff --git a/libs/community/tests/integration_tests/graph_vectorstores/extractors/test_keybert_link_extractor.py b/libs/community/tests/integration_tests/graph_vectorstores/extractors/test_keybert_link_extractor.py
new file mode 100644
index 00000000000..041fd37122c
--- /dev/null
+++ b/libs/community/tests/integration_tests/graph_vectorstores/extractors/test_keybert_link_extractor.py
@@ -0,0 +1,64 @@
+import pytest
+from langchain_core.graph_vectorstores.links import Link
+
+from langchain_community.graph_vectorstores.extractors import KeybertLinkExtractor
+
+PAGE_1 = """
+Supervised learning is the machine learning task of learning a function that
+maps an input to an output based on example input-output pairs. It infers a
+function from labeled training data consisting of a set of training examples. In
+supervised learning, each example is a pair consisting of an input object
+(typically a vector) and a desired output value (also called the supervisory
+signal). A supervised learning algorithm analyzes the training data and produces
+an inferred function, which can be used for mapping new examples. An optimal
+scenario will allow for the algorithm to correctly determine the class labels
+for unseen instances. This requires the learning algorithm to generalize from
+the training data to unseen situations in a 'reasonable' way (see inductive
+bias).
+"""
+
+PAGE_2 = """
+KeyBERT is a minimal and easy-to-use keyword extraction technique that leverages
+BERT embeddings to create keywords and keyphrases that are most similar to a
+document.
+"""
+
+
+@pytest.mark.requires("keybert")
+def test_one_from_keywords() -> None:
+    extractor = KeybertLinkExtractor()
+
+    results = extractor.extract_one(PAGE_1)
+    assert results == {
+        Link.bidir(kind="kw", tag="supervised"),
+        Link.bidir(kind="kw", tag="labels"),
+        Link.bidir(kind="kw", tag="labeled"),
+        Link.bidir(kind="kw", tag="learning"),
+        Link.bidir(kind="kw", tag="training"),
+    }
+
+
+@pytest.mark.requires("keybert")
+def test_many_from_keyphrases() -> None:
+    extractor = KeybertLinkExtractor(
+        extract_keywords_kwargs={
+            "keyphrase_ngram_range": (1, 2),
+        }
+    )
+
+    results = list(extractor.extract_many([PAGE_1, PAGE_2]))
+    assert results[0] == {
+        Link.bidir(kind="kw", tag="supervised"),
+        Link.bidir(kind="kw", tag="labeled training"),
+        Link.bidir(kind="kw", tag="supervised learning"),
+        Link.bidir(kind="kw", tag="examples supervised"),
+        Link.bidir(kind="kw", tag="signal supervised"),
+    }
+
+    assert results[1] == {
+        Link.bidir(kind="kw", tag="keyphrases"),
+        Link.bidir(kind="kw", tag="keyword extraction"),
+        Link.bidir(kind="kw", tag="keybert"),
+        Link.bidir(kind="kw", tag="keywords keyphrases"),
+        Link.bidir(kind="kw", tag="keybert minimal"),
+    }