From 55db737302af5c27b3ee1e89ce1c9f520feb9ce1 Mon Sep 17 00:00:00 2001 From: miri-bar <160584887+miri-bar@users.noreply.github.com> Date: Tue, 26 Mar 2024 03:39:37 +0200 Subject: [PATCH] ai21[minor]: AI21 Labs Semantic Text Splitter support (#19510) Description: Added support for AI21 Labs model - Segmentation, as a Text Splitter Dependencies: ai21, langchain-text-splitter Twitter handle: https://github.com/AI21Labs --------- Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Bagatur --- .../ai21_semantic_text_splitter.ipynb | 466 ++++++++++++++++++ .../document_transformers/index.mdx | 1 + libs/partners/ai21/README.md | 16 + libs/partners/ai21/langchain_ai21/__init__.py | 2 + .../langchain_ai21/semantic_text_splitter.py | 158 ++++++ libs/partners/ai21/poetry.lock | 46 +- libs/partners/ai21/pyproject.toml | 1 + .../test_semantic_text_splitter.py | 130 +++++ .../ai21/tests/unit_tests/conftest.py | 41 +- .../ai21/tests/unit_tests/test_imports.py | 1 + .../unit_tests/test_semantic_text_splitter.py | 129 +++++ 11 files changed, 976 insertions(+), 15 deletions(-) create mode 100644 docs/docs/integrations/document_transformers/ai21_semantic_text_splitter.ipynb create mode 100644 libs/partners/ai21/langchain_ai21/semantic_text_splitter.py create mode 100644 libs/partners/ai21/tests/integration_tests/test_semantic_text_splitter.py create mode 100644 libs/partners/ai21/tests/unit_tests/test_semantic_text_splitter.py diff --git a/docs/docs/integrations/document_transformers/ai21_semantic_text_splitter.ipynb b/docs/docs/integrations/document_transformers/ai21_semantic_text_splitter.ipynb new file mode 100644 index 00000000000..ce67c73809e --- /dev/null +++ b/docs/docs/integrations/document_transformers/ai21_semantic_text_splitter.ipynb @@ -0,0 +1,466 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "b9bba344bbe0b4bd", + "metadata": { + "collapsed": false + }, + "source": [ + "# AI21SemanticTextSplitter\n", + "\n", + "This example goes over how to use AI21SemanticTextSplitter in LangChain." + ] + }, + { + "cell_type": "markdown", + "id": "d8e4cdb63fbc34ec", + "metadata": { + "collapsed": false + }, + "source": [ + "## Installation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b09bb1cd2c7e036a", + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "pip install langchain-ai21" + ] + }, + { + "cell_type": "markdown", + "id": "ba1d80fe8d82be89", + "metadata": { + "collapsed": false + }, + "source": [ + "## Environment Setup\n", + "\n", + "We'll need to get a AI21 API key and set the AI21_API_KEY environment variable:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "844b8f744d22bcb6", + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import os\n", + "from getpass import getpass\n", + "\n", + "os.environ[\"AI21_API_KEY\"] = getpass()" + ] + }, + { + "cell_type": "markdown", + "id": "3e670b278e6b2b9e", + "metadata": { + "collapsed": false + }, + "source": [ + "## Example Usages" + ] + }, + { + "cell_type": "markdown", + "id": "f61c5c981f01ad31", + "metadata": { + "collapsed": false + }, + "source": [ + "### Splitting text by semantic meaning" + ] + }, + { + "cell_type": "markdown", + "id": "e7da988112712cf3", + "metadata": { + "collapsed": false + }, + "source": [ + "This example shows how to use AI21SemanticTextSplitter to split a text into chunks based on semantic meaning." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1d82b65c9b8684f3", + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from langchain_ai21 import AI21SemanticTextSplitter\n", + "\n", + "TEXT = (\n", + " \"We’ve all experienced reading long, tedious, and boring pieces of text - financial reports, \"\n", + " \"legal documents, or terms and conditions (though, who actually reads those terms and conditions to be honest?).\\n\"\n", + " \"Imagine a company that employs hundreds of thousands of employees. In today's information \"\n", + " \"overload age, nearly 30% of the workday is spent dealing with documents. There's no surprise \"\n", + " \"here, given that some of these documents are long and convoluted on purpose (did you know that \"\n", + " \"reading through all your privacy policies would take almost a quarter of a year?). Aside from \"\n", + " \"inefficiency, workers may simply refrain from reading some documents (for example, Only 16% of \"\n", + " \"Employees Read Their Employment Contracts Entirely Before Signing!).\\nThis is where AI-driven summarization \"\n", + " \"tools can be helpful: instead of reading entire documents, which is tedious and time-consuming, \"\n", + " \"users can (ideally) quickly extract relevant information from a text. With large language models, \"\n", + " \"the development of those tools is easier than ever, and you can offer your users a summary that is \"\n", + " \"specifically tailored to their preferences.\\nLarge language models naturally follow patterns in input \"\n", + " \"(prompt), and provide coherent completion that follows the same patterns. For that, we want to feed \"\n", + " 'them with several examples in the input (\"few-shot prompt\"), so they can follow through. '\n", + " \"The process of creating the correct prompt for your problem is called prompt engineering, \"\n", + " \"and you can read more about it here.\"\n", + ")\n", + "\n", + "semantic_text_splitter = AI21SemanticTextSplitter()\n", + "chunks = semantic_text_splitter.split_text(TEXT)\n", + "\n", + "print(f\"The text has been split into {len(chunks)} chunks.\")\n", + "for chunk in chunks:\n", + " print(chunk)\n", + " print(\"====\")" + ] + }, + { + "cell_type": "markdown", + "id": "2e8d1fcf818a8a81", + "metadata": { + "collapsed": false + }, + "source": [ + "### Splitting text by semantic meaning with merge" + ] + }, + { + "cell_type": "markdown", + "id": "c307abbc216fe89f", + "metadata": { + "collapsed": false + }, + "source": [ + "This example shows how to use AI21SemanticTextSplitter to split a text into chunks based on semantic meaning, then merging the chunks based on `chunk_size`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5651c581fcc1ff02", + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from langchain_ai21 import AI21SemanticTextSplitter\n", + "\n", + "TEXT = (\n", + " \"We’ve all experienced reading long, tedious, and boring pieces of text - financial reports, \"\n", + " \"legal documents, or terms and conditions (though, who actually reads those terms and conditions to be honest?).\\n\"\n", + " \"Imagine a company that employs hundreds of thousands of employees. In today's information \"\n", + " \"overload age, nearly 30% of the workday is spent dealing with documents. 
There's no surprise \"\n", + " \"here, given that some of these documents are long and convoluted on purpose (did you know that \"\n", + " \"reading through all your privacy policies would take almost a quarter of a year?). Aside from \"\n", + " \"inefficiency, workers may simply refrain from reading some documents (for example, Only 16% of \"\n", + " \"Employees Read Their Employment Contracts Entirely Before Signing!).\\nThis is where AI-driven summarization \"\n", + " \"tools can be helpful: instead of reading entire documents, which is tedious and time-consuming, \"\n", + " \"users can (ideally) quickly extract relevant information from a text. With large language models, \"\n", + " \"the development of those tools is easier than ever, and you can offer your users a summary that is \"\n", + " \"specifically tailored to their preferences.\\nLarge language models naturally follow patterns in input \"\n", + " \"(prompt), and provide coherent completion that follows the same patterns. For that, we want to feed \"\n", + " 'them with several examples in the input (\"few-shot prompt\"), so they can follow through. '\n", + " \"The process of creating the correct prompt for your problem is called prompt engineering, \"\n", + " \"and you can read more about it here.\"\n", + ")\n", + "\n", + "semantic_text_splitter_chunks = AI21SemanticTextSplitter(chunk_size=1000)\n", + "chunks = semantic_text_splitter_chunks.split_text(TEXT)\n", + "\n", + "print(f\"The text has been split into {len(chunks)} chunks.\")\n", + "for chunk in chunks:\n", + " print(chunk)\n", + " print(\"====\")" + ] + }, + { + "cell_type": "markdown", + "id": "b464db855e547cbb", + "metadata": { + "collapsed": false + }, + "source": [ + "### Splitting text to documents" + ] + }, + { + "cell_type": "markdown", + "id": "4410e8467012b193", + "metadata": { + "collapsed": false + }, + "source": [ + "This example shows how to use AI21SemanticTextSplitter to split a text into Documents based on semantic meaning. The metadata will contain a type for each document." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3cf131d9be910115", + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from langchain_ai21 import AI21SemanticTextSplitter\n", + "\n", + "TEXT = (\n", + " \"We’ve all experienced reading long, tedious, and boring pieces of text - financial reports, \"\n", + " \"legal documents, or terms and conditions (though, who actually reads those terms and conditions to be honest?).\\n\"\n", + " \"Imagine a company that employs hundreds of thousands of employees. In today's information \"\n", + " \"overload age, nearly 30% of the workday is spent dealing with documents. There's no surprise \"\n", + " \"here, given that some of these documents are long and convoluted on purpose (did you know that \"\n", + " \"reading through all your privacy policies would take almost a quarter of a year?). Aside from \"\n", + " \"inefficiency, workers may simply refrain from reading some documents (for example, Only 16% of \"\n", + " \"Employees Read Their Employment Contracts Entirely Before Signing!).\\nThis is where AI-driven summarization \"\n", + " \"tools can be helpful: instead of reading entire documents, which is tedious and time-consuming, \"\n", + " \"users can (ideally) quickly extract relevant information from a text. 
With large language models, \"\n", + " \"the development of those tools is easier than ever, and you can offer your users a summary that is \"\n", + " \"specifically tailored to their preferences.\\nLarge language models naturally follow patterns in input \"\n", + " \"(prompt), and provide coherent completion that follows the same patterns. For that, we want to feed \"\n", + " 'them with several examples in the input (\"few-shot prompt\"), so they can follow through. '\n", + " \"The process of creating the correct prompt for your problem is called prompt engineering, \"\n", + " \"and you can read more about it here.\"\n", + ")\n", + "\n", + "semantic_text_splitter = AI21SemanticTextSplitter()\n", + "documents = semantic_text_splitter.split_text_to_documents(TEXT)\n", + "\n", + "print(f\"The text has been split into {len(documents)} Documents.\")\n", + "for doc in documents:\n", + " print(f\"type: {doc.metadata['source_type']}\")\n", + " print(f\"text: {doc.page_content}\")\n", + " print(\"====\")" + ] + }, + { + "cell_type": "markdown", + "id": "b544ba21335d01a6", + "metadata": { + "collapsed": false + }, + "source": [ + "### Creating Documents with Metadata" + ] + }, + { + "cell_type": "markdown", + "id": "c67f8c3ad89b8ad2", + "metadata": { + "collapsed": false + }, + "source": [ + "This example shows how to use AI21SemanticTextSplitter to create Documents from texts, and adding custom Metadata to each Document." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fe222d0e85249bda", + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from langchain_ai21 import AI21SemanticTextSplitter\n", + "\n", + "TEXT = (\n", + " \"We’ve all experienced reading long, tedious, and boring pieces of text - financial reports, \"\n", + " \"legal documents, or terms and conditions (though, who actually reads those terms and conditions to be honest?).\\n\"\n", + " \"Imagine a company that employs hundreds of thousands of employees. In today's information \"\n", + " \"overload age, nearly 30% of the workday is spent dealing with documents. There's no surprise \"\n", + " \"here, given that some of these documents are long and convoluted on purpose (did you know that \"\n", + " \"reading through all your privacy policies would take almost a quarter of a year?). Aside from \"\n", + " \"inefficiency, workers may simply refrain from reading some documents (for example, Only 16% of \"\n", + " \"Employees Read Their Employment Contracts Entirely Before Signing!).\\nThis is where AI-driven summarization \"\n", + " \"tools can be helpful: instead of reading entire documents, which is tedious and time-consuming, \"\n", + " \"users can (ideally) quickly extract relevant information from a text. With large language models, \"\n", + " \"the development of those tools is easier than ever, and you can offer your users a summary that is \"\n", + " \"specifically tailored to their preferences.\\nLarge language models naturally follow patterns in input \"\n", + " \"(prompt), and provide coherent completion that follows the same patterns. For that, we want to feed \"\n", + " 'them with several examples in the input (\"few-shot prompt\"), so they can follow through. 
'\n", + " \"The process of creating the correct prompt for your problem is called prompt engineering, \"\n", + " \"and you can read more about it here.\"\n", + ")\n", + "\n", + "semantic_text_splitter = AI21SemanticTextSplitter()\n", + "texts = [TEXT]\n", + "documents = semantic_text_splitter.create_documents(\n", + " texts=texts, metadatas=[{\"pikachu\": \"pika pika\"}]\n", + ")\n", + "\n", + "print(f\"The text has been split into {len(documents)} Documents.\")\n", + "for doc in documents:\n", + " print(f\"metadata: {doc.metadata}\")\n", + " print(f\"text: {doc.page_content}\")\n", + " print(\"====\")" + ] + }, + { + "cell_type": "markdown", + "id": "f8b5682c34142319", + "metadata": { + "collapsed": false + }, + "source": [ + "### Splitting text to documents with start index" + ] + }, + { + "cell_type": "markdown", + "id": "359ea797c03ece85", + "metadata": { + "collapsed": false + }, + "source": [ + "This example shows how to use AI21SemanticTextSplitter to split a text into Documents based on semantic meaning. The metadata will contain a start index for each document.\n", + "**Note** that the start index provides an indication of the order of the chunks rather than the actual start index for each chunk." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2dc39002f0c25784", + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from langchain_ai21 import AI21SemanticTextSplitter\n", + "\n", + "TEXT = (\n", + " \"We’ve all experienced reading long, tedious, and boring pieces of text - financial reports, \"\n", + " \"legal documents, or terms and conditions (though, who actually reads those terms and conditions to be honest?).\\n\"\n", + " \"Imagine a company that employs hundreds of thousands of employees. In today's information \"\n", + " \"overload age, nearly 30% of the workday is spent dealing with documents. There's no surprise \"\n", + " \"here, given that some of these documents are long and convoluted on purpose (did you know that \"\n", + " \"reading through all your privacy policies would take almost a quarter of a year?). Aside from \"\n", + " \"inefficiency, workers may simply refrain from reading some documents (for example, Only 16% of \"\n", + " \"Employees Read Their Employment Contracts Entirely Before Signing!).\\nThis is where AI-driven summarization \"\n", + " \"tools can be helpful: instead of reading entire documents, which is tedious and time-consuming, \"\n", + " \"users can (ideally) quickly extract relevant information from a text. With large language models, \"\n", + " \"the development of those tools is easier than ever, and you can offer your users a summary that is \"\n", + " \"specifically tailored to their preferences.\\nLarge language models naturally follow patterns in input \"\n", + " \"(prompt), and provide coherent completion that follows the same patterns. For that, we want to feed \"\n", + " 'them with several examples in the input (\"few-shot prompt\"), so they can follow through. 
'\n", + " \"The process of creating the correct prompt for your problem is called prompt engineering, \"\n", + " \"and you can read more about it here.\"\n", + ")\n", + "\n", + "semantic_text_splitter = AI21SemanticTextSplitter(add_start_index=True)\n", + "documents = semantic_text_splitter.create_documents(texts=[TEXT])\n", + "print(f\"The text has been split into {len(documents)} Documents.\")\n", + "for doc in documents:\n", + " print(f\"start_index: {doc.metadata['start_index']}\")\n", + " print(f\"text: {doc.page_content}\")\n", + " print(\"====\")" + ] + }, + { + "cell_type": "markdown", + "id": "b62939cc5803b9fb", + "metadata": { + "collapsed": false + }, + "source": [ + "### Splitting documents" + ] + }, + { + "cell_type": "markdown", + "id": "44162d340c0de5fb", + "metadata": { + "collapsed": false + }, + "source": [ + "This example shows how to use AI21SemanticTextSplitter to split a list of Documents into chunks based on semantic meaning." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8950c8e4e1208bf6", + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from langchain_ai21 import AI21SemanticTextSplitter\n", + "from langchain_core.documents import Document\n", + "\n", + "TEXT = (\n", + " \"We’ve all experienced reading long, tedious, and boring pieces of text - financial reports, \"\n", + " \"legal documents, or terms and conditions (though, who actually reads those terms and conditions to be honest?).\\n\"\n", + " \"Imagine a company that employs hundreds of thousands of employees. In today's information \"\n", + " \"overload age, nearly 30% of the workday is spent dealing with documents. There's no surprise \"\n", + " \"here, given that some of these documents are long and convoluted on purpose (did you know that \"\n", + " \"reading through all your privacy policies would take almost a quarter of a year?). Aside from \"\n", + " \"inefficiency, workers may simply refrain from reading some documents (for example, Only 16% of \"\n", + " \"Employees Read Their Employment Contracts Entirely Before Signing!).\\nThis is where AI-driven summarization \"\n", + " \"tools can be helpful: instead of reading entire documents, which is tedious and time-consuming, \"\n", + " \"users can (ideally) quickly extract relevant information from a text. With large language models, \"\n", + " \"the development of those tools is easier than ever, and you can offer your users a summary that is \"\n", + " \"specifically tailored to their preferences.\\nLarge language models naturally follow patterns in input \"\n", + " \"(prompt), and provide coherent completion that follows the same patterns. For that, we want to feed \"\n", + " 'them with several examples in the input (\"few-shot prompt\"), so they can follow through. 
'\n", + " \"The process of creating the correct prompt for your problem is called prompt engineering, \"\n", + " \"and you can read more about it here.\"\n", + ")\n", + "\n", + "semantic_text_splitter = AI21SemanticTextSplitter()\n", + "document = Document(page_content=TEXT, metadata={\"hello\": \"goodbye\"})\n", + "documents = semantic_text_splitter.split_documents([document])\n", + "print(f\"The document list has been split into {len(documents)} Documents.\")\n", + "for doc in documents:\n", + " print(f\"text: {doc.page_content}\")\n", + " print(f\"metadata: {doc.metadata}\")\n", + " print(\"====\")" + ] + }, + { + "cell_type": "markdown", + "id": "f8f911b8d9ec22e5", + "metadata": { + "collapsed": false + }, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/docs/modules/data_connection/document_transformers/index.mdx b/docs/docs/modules/data_connection/document_transformers/index.mdx index d0e5ee20636..d350ea7e79e 100644 --- a/docs/docs/modules/data_connection/document_transformers/index.mdx +++ b/docs/docs/modules/data_connection/document_transformers/index.mdx @@ -44,6 +44,7 @@ LangChain offers many different types of text splitters. These all live in the ` | Token | Tokens | | Splits text on tokens. There exist a few different ways to measure tokens. | | Character | A user defined character | | Splits text based on a user defined character. One of the simpler methods. | | [Experimental] Semantic Chunker | Sentences | | First splits on sentences. Then combines ones next to each other if they are semantically similar enough. Taken from [Greg Kamradt](https://github.com/FullStackRetrieval-com/RetrievalTutorials/blob/main/5_Levels_Of_Text_Splitting.ipynb) | +| [AI21 Semantic Text Splitter](/docs/integrations/document_transformers/ai21_semantic_text_splitter) | Semantics | ✅ | Identifies distinct topics that form coherent pieces of text and splits along those. | ## Evaluate text splitters diff --git a/libs/partners/ai21/README.md b/libs/partners/ai21/README.md index d4001791be4..508810a81ad 100644 --- a/libs/partners/ai21/README.md +++ b/libs/partners/ai21/README.md @@ -102,4 +102,20 @@ chain = tsm | StrOutputParser() response = chain.invoke( {"context": "Your context", "question": "Your question"}, ) +``` + +## Text Splitters + +### Semantic Text Splitter + +You can use AI21's semantic text splitter to split a text into segments. +Instead of merely using punctuation and newlines to divide the text, it identifies distinct topics that form coherent pieces of text. + +For a list of examples, see [this page](https://github.com/langchain-ai/langchain/blob/master/docs/docs/integrations/document_transformers/ai21_semantic_text_splitter.ipynb).
+ +```python +from langchain_ai21 import AI21SemanticTextSplitter + +splitter = AI21SemanticTextSplitter() +response = splitter.split_text("Your text") ``` \ No newline at end of file diff --git a/libs/partners/ai21/langchain_ai21/__init__.py b/libs/partners/ai21/langchain_ai21/__init__.py index ca8ae8e220f..fd766fb075c 100644 --- a/libs/partners/ai21/langchain_ai21/__init__.py +++ b/libs/partners/ai21/langchain_ai21/__init__.py @@ -2,10 +2,12 @@ from langchain_ai21.chat_models import ChatAI21 from langchain_ai21.contextual_answers import AI21ContextualAnswers from langchain_ai21.embeddings import AI21Embeddings from langchain_ai21.llms import AI21LLM +from langchain_ai21.semantic_text_splitter import AI21SemanticTextSplitter __all__ = [ "AI21LLM", "ChatAI21", "AI21Embeddings", "AI21ContextualAnswers", + "AI21SemanticTextSplitter", ] diff --git a/libs/partners/ai21/langchain_ai21/semantic_text_splitter.py b/libs/partners/ai21/langchain_ai21/semantic_text_splitter.py new file mode 100644 index 00000000000..cda0bba1b1f --- /dev/null +++ b/libs/partners/ai21/langchain_ai21/semantic_text_splitter.py @@ -0,0 +1,158 @@ +import copy +import logging +import re +from typing import ( + Any, + Iterable, + List, + Optional, +) + +from ai21.models import DocumentType +from langchain_core.documents import Document +from langchain_core.pydantic_v1 import SecretStr +from langchain_text_splitters import TextSplitter + +from langchain_ai21.ai21_base import AI21Base + +logger = logging.getLogger(__name__) + + +class AI21SemanticTextSplitter(TextSplitter): + """Splitting text into coherent and readable units, + based on distinct topics and lines + """ + + def __init__( + self, + chunk_size: int = 0, + chunk_overlap: int = 0, + client: Optional[Any] = None, + api_key: Optional[SecretStr] = None, + api_host: Optional[str] = None, + timeout_sec: Optional[float] = None, + num_retries: Optional[int] = None, + **kwargs: Any, + ) -> None: + """Create a new TextSplitter.""" + super().__init__( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + **kwargs, + ) + + self._segmentation = AI21Base( + client=client, + api_key=api_key, + api_host=api_host, + timeout_sec=timeout_sec, + num_retries=num_retries, + ).client.segmentation + + def split_text(self, source: str) -> List[str]: + """Split text into multiple components. + + Args: + source: Specifies the text input for text segmentation + """ + response = self._segmentation.create( + source=source, source_type=DocumentType.TEXT + ) + + segments = [segment.segment_text for segment in response.segments] + + if self._chunk_size > 0: + return self._merge_splits_no_seperator(segments) + + return segments + + def split_text_to_documents(self, source: str) -> List[Document]: + """Split text into multiple documents. 
+ + Args: + source: Specifies the text input for text segmentation + """ + response = self._segmentation.create( + source=source, source_type=DocumentType.TEXT + ) + + return [ + Document( + page_content=segment.segment_text, + metadata={"source_type": segment.segment_type}, + ) + for segment in response.segments + ] + + def create_documents( + self, texts: List[str], metadatas: Optional[List[dict]] = None + ) -> List[Document]: + """Create documents from a list of texts.""" + _metadatas = metadatas or [{}] * len(texts) + documents = [] + + for i, text in enumerate(texts): + normalized_text = self._normalized_text(text) + index = 0 + previous_chunk_len = 0 + + for chunk in self.split_text_to_documents(text): + # merge metadata from user (if exists) and from segmentation api + metadata = copy.deepcopy(_metadatas[i]) + metadata.update(chunk.metadata) + + if self._add_start_index: + # find the start index of the chunk + offset = index + previous_chunk_len - self._chunk_overlap + normalized_chunk = self._normalized_text(chunk.page_content) + index = normalized_text.find(normalized_chunk, max(0, offset)) + metadata["start_index"] = index + previous_chunk_len = len(normalized_chunk) + + documents.append( + Document( + page_content=chunk.page_content, + metadata=metadata, + ) + ) + + return documents + + def _normalized_text(self, string: str) -> str: + """Use a regular expression to strip all whitespace sequences.""" + return re.sub(r"\s+", "", string) + + def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]: + """This method overrides the default implementation of TextSplitter""" + return self._merge_splits_no_seperator(splits) + + def _merge_splits_no_seperator(self, splits: Iterable[str]) -> List[str]: + """Merge splits into chunks. + If the segment size is bigger than chunk_size, + it will be left as is (won't be cut to match to chunk_size). + If the segment size is smaller than chunk_size, + it will be merged with the next segment until the chunk_size is reached. + """ + chunks = [] + current_chunk = "" + + for split in splits: + split_len = self._length_function(split) + + if split_len > self._chunk_size: + logger.warning( + f"Split of length {split_len} " + f"exceeds chunk size {self._chunk_size}."
+ ) + + if self._length_function(current_chunk) + split_len > self._chunk_size: + if current_chunk != "": + chunks.append(current_chunk) + current_chunk = "" + + current_chunk += split + + if current_chunk != "": + chunks.append(current_chunk) + + return chunks diff --git a/libs/partners/ai21/poetry.lock b/libs/partners/ai21/poetry.lock index c6bd28c0898..93518f41535 100644 --- a/libs/partners/ai21/poetry.lock +++ b/libs/partners/ai21/poetry.lock @@ -300,7 +300,7 @@ files = [ [[package]] name = "langchain-core" -version = "0.1.30" +version = "0.1.33" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" @@ -324,15 +324,32 @@ extended-testing = ["jinja2 (>=3,<4)"] type = "directory" url = "../../core" +[[package]] +name = "langchain-text-splitters" +version = "0.0.1" +description = "LangChain text splitting utilities" +optional = false +python-versions = ">=3.8.1,<4.0" +files = [ + {file = "langchain_text_splitters-0.0.1-py3-none-any.whl", hash = "sha256:f5b802f873f5ff6a8b9259ff34d53ed989666ef4e1582e6d1adb3b5520e3839a"}, + {file = "langchain_text_splitters-0.0.1.tar.gz", hash = "sha256:ac459fa98799f5117ad5425a9330b21961321e30bc19a2a2f9f761ddadd62aa1"}, +] + +[package.dependencies] +langchain-core = ">=0.1.28,<0.2.0" + +[package.extras] +extended-testing = ["lxml (>=5.1.0,<6.0.0)"] + [[package]] name = "langsmith" -version = "0.1.10" +version = "0.1.23" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." optional = false python-versions = ">=3.8.1,<4.0" files = [ - {file = "langsmith-0.1.10-py3-none-any.whl", hash = "sha256:2997a80aea60ed235d83502a7ccdc1f62ffb4dd6b3b7dd4218e8fa4de68a6725"}, - {file = "langsmith-0.1.10.tar.gz", hash = "sha256:13e7e8b52e694aa4003370cefbb9e79cce3540c65dbf1517902bf7aa4dbbb653"}, + {file = "langsmith-0.1.23-py3-none-any.whl", hash = "sha256:69984268b9867cb31b875965b3f86b6f56ba17dd5454d487d3a1a999bdaeea69"}, + {file = "langsmith-0.1.23.tar.gz", hash = "sha256:327c66ec0de8c1bc57bfa47bbc70a29ef749e97c3e5571b9baf754d1e0644220"}, ] [package.dependencies] @@ -342,13 +359,13 @@ requests = ">=2,<3" [[package]] name = "marshmallow" -version = "3.21.0" +version = "3.21.1" description = "A lightweight library for converting complex datatypes to and from native Python datatypes." 
optional = false python-versions = ">=3.8" files = [ - {file = "marshmallow-3.21.0-py3-none-any.whl", hash = "sha256:e7997f83571c7fd476042c2c188e4ee8a78900ca5e74bd9c8097afa56624e9bd"}, - {file = "marshmallow-3.21.0.tar.gz", hash = "sha256:20f53be28c6e374a711a16165fb22a8dc6003e3f7cda1285e3ca777b9193885b"}, + {file = "marshmallow-3.21.1-py3-none-any.whl", hash = "sha256:f085493f79efb0644f270a9bf2892843142d80d7174bbbd2f3713f2a589dc633"}, + {file = "marshmallow-3.21.1.tar.gz", hash = "sha256:4e65e9e0d80fc9e609574b9983cf32579f305c718afb30d7233ab818571768c3"}, ] [package.dependencies] @@ -507,13 +524,13 @@ testing = ["pytest", "pytest-benchmark"] [[package]] name = "pydantic" -version = "2.6.3" +version = "2.6.4" description = "Data validation using Python type hints" optional = false python-versions = ">=3.8" files = [ - {file = "pydantic-2.6.3-py3-none-any.whl", hash = "sha256:72c6034df47f46ccdf81869fddb81aade68056003900a8724a4f160700016a2a"}, - {file = "pydantic-2.6.3.tar.gz", hash = "sha256:e07805c4c7f5c6826e33a1d4c9d47950d7eaf34868e2690f8594d2e30241f11f"}, + {file = "pydantic-2.6.4-py3-none-any.whl", hash = "sha256:cc46fce86607580867bdc3361ad462bab9c222ef042d3da86f2fb333e1d916c5"}, + {file = "pydantic-2.6.4.tar.gz", hash = "sha256:b1704e0847db01817624a6b86766967f552dd9dbf3afba4004409f908dcc84e6"}, ] [package.dependencies] @@ -689,13 +706,13 @@ watchdog = ">=2.0.0" [[package]] name = "python-dateutil" -version = "2.8.2" +version = "2.9.0.post0" description = "Extensions to the standard Python datetime module" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" files = [ - {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, - {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, + {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, + {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, ] [package.dependencies] @@ -726,6 +743,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -1009,4 +1027,4 @@ watchmedo = ["PyYAML (>=3.10)"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "3073522be06765f2acb7efea6ed1fcc49eaa05e82534d96fa914899dbbbb541f" +content-hash = "6ba91e0cf81e177c01efe980cbeedc2fe5a267599ce91c15acbcf2cd34df33dc" diff --git 
a/libs/partners/ai21/pyproject.toml b/libs/partners/ai21/pyproject.toml index 31df1cd978d..6b77ee5286e 100644 --- a/libs/partners/ai21/pyproject.toml +++ b/libs/partners/ai21/pyproject.toml @@ -8,6 +8,7 @@ readme = "README.md" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" langchain-core = "^0.1.22" +langchain-text-splitters = "^0.0.1" ai21 = "^2.1.2" [tool.poetry.group.test] diff --git a/libs/partners/ai21/tests/integration_tests/test_semantic_text_splitter.py b/libs/partners/ai21/tests/integration_tests/test_semantic_text_splitter.py new file mode 100644 index 00000000000..ec3aebd3b10 --- /dev/null +++ b/libs/partners/ai21/tests/integration_tests/test_semantic_text_splitter.py @@ -0,0 +1,130 @@ +from ai21 import AI21Client +from langchain_core.documents import Document + +from langchain_ai21 import AI21SemanticTextSplitter + +TEXT = ( + "The original full name of the franchise is Pocket Monsters (ポケットモンスター, " + "Poketto Monsutā), which was abbreviated to " + "Pokemon during development of the original games.\n" + "When the franchise was released internationally, the short form of the title was " + "used, with an acute accent (´) " + "over the e to aid in pronunciation.\n" + "Pokémon refers to both the franchise itself and the creatures within its " + "fictional universe.\n" + "As a noun, it is identical in both the singular and plural, as is every " + "individual species name;[10] it is " + 'grammatically correct to say "one Pokémon" and "many Pokémon", as well ' + 'as "one Pikachu" and "many Pikachu".\n' + "In English, Pokémon may be pronounced either /'powkɛmon/ (poe-keh-mon) or " + "/'powkɪmon/ (poe-key-mon).\n" + "The Pokémon franchise is set in a world in which humans coexist with creatures " + "known as Pokémon.\n" + "Pokémon Red and Blue contain 151 Pokémon species, with new ones being introduced " + "in subsequent games; as of December 2023, 1,025 Pokémon species have been " + "introduced.\n[b] Most Pokémon are inspired by real-world animals;[12] for example," + "Pikachu are a yellow mouse-like species[13] with lightning bolt-shaped tails[14] " + "that possess electrical abilities.[15]\nThe player character takes the role of a " + "Pokémon Trainer.\nThe Trainer has three primary goals: travel and explore the " + "Pokémon world; discover and catch each Pokémon species in order to complete their" + "Pokédex; and train a team of up to six Pokémon at a time and have them engage " + "in battles.\nMost Pokémon can be caught with spherical devices known as Poké " + "Balls.\nOnce the opposing Pokémon is sufficiently weakened, the Trainer throws " + "the Poké Ball against the Pokémon, which is then transformed into a form of " + "energy and transported into the device.\nOnce the catch is successful, " + "the Pokémon is tamed and is under the Trainer's command from then on.\n" + "If the Poké Ball is thrown again, the Pokémon re-materializes into its " + "original state.\nThe Trainer's Pokémon can engage in battles against opposing " + "Pokémon, including those in the wild or owned by other Trainers.\nBecause the " + "franchise is aimed at children, these battles are never presented as overtly " + "violent and contain no blood or gore.[I]\nPokémon never die in battle, instead " + "fainting upon being defeated.[20][21][22]\nAfter a Pokémon wins a battle, it " + "gains experience and becomes stronger.[23] After gaining a certain amount of " + "experience points, its level increases, as well as one or more of its " + "statistics.\nAs its level increases, the Pokémon can learn new offensive 
" + "and defensive moves to use in battle.[24][25] Furthermore, many species can " + "undergo a form of spontaneous metamorphosis called Pokémon evolution, and " + "transform into stronger forms.[26] Most Pokémon will evolve at a certain level, " + "while others evolve through different means, such as exposure to a certain " + "item.[27]\n" +) + + +def test_split_text_to_document() -> None: + segmentation = AI21SemanticTextSplitter() + segments = segmentation.split_text_to_documents(source=TEXT) + assert len(segments) > 0 + for segment in segments: + assert segment.page_content is not None + assert segment.metadata is not None + + +def test_split_text() -> None: + segmentation = AI21SemanticTextSplitter() + segments = segmentation.split_text(source=TEXT) + assert len(segments) > 0 + + +def test_split_text__when_chunk_size_is_large__should_merge_segments() -> None: + segmentation_no_merge = AI21SemanticTextSplitter() + segments_no_merge = segmentation_no_merge.split_text(source=TEXT) + segmentation_merge = AI21SemanticTextSplitter(chunk_size=1000) + segments_merge = segmentation_merge.split_text(source=TEXT) + # Assert that a merge did happen + assert len(segments_no_merge) > len(segments_merge) + reconstructed_text_merged = "".join(segments_merge) + reconstructed_text_non_merged = "".join(segments_no_merge) + # Assert that the merge did not change the content + assert reconstructed_text_merged == reconstructed_text_non_merged + + +def test_split_text__chunk_size_is_too_small__should_return_non_merged_segments() -> ( + None +): + segmentation_no_merge = AI21SemanticTextSplitter() + segments_no_merge = segmentation_no_merge.split_text(source=TEXT) + segmentation_merge = AI21SemanticTextSplitter(chunk_size=10) + segments_merge = segmentation_merge.split_text(source=TEXT) + # Assert that a merge did happen + assert len(segments_no_merge) == len(segments_merge) + reconstructed_text_merged = "".join(segments_merge) + reconstructed_text_non_merged = "".join(segments_no_merge) + # Assert that the merge did not change the content + assert reconstructed_text_merged == reconstructed_text_non_merged + + +def test_split_text__when_chunk_size_set_with_ai21_tokenizer() -> None: + segmentation_no_merge = AI21SemanticTextSplitter( + length_function=AI21Client().count_tokens + ) + segments_no_merge = segmentation_no_merge.split_text(source=TEXT) + segmentation_merge = AI21SemanticTextSplitter( + chunk_size=1000, length_function=AI21Client().count_tokens + ) + segments_merge = segmentation_merge.split_text(source=TEXT) + # Assert that a merge did happen + assert len(segments_no_merge) > len(segments_merge) + reconstructed_text_merged = "".join(segments_merge) + reconstructed_text_non_merged = "".join(segments_no_merge) + # Assert that the merge did not change the content + assert reconstructed_text_merged == reconstructed_text_non_merged + + +def test_create_documents() -> None: + texts = [TEXT] + segmentation = AI21SemanticTextSplitter() + documents = segmentation.create_documents(texts=texts) + assert len(documents) > 0 + for document in documents: + assert document.page_content is not None + assert document.metadata is not None + + +def test_split_documents() -> None: + documents = [Document(page_content=TEXT, metadata={"foo": "bar"})] + segmentation = AI21SemanticTextSplitter() + segments = segmentation.split_documents(documents=documents) + assert len(segments) > 0 + for segment in segments: + assert segment.page_content is not None + assert segment.metadata is not None diff --git 
a/libs/partners/ai21/tests/unit_tests/conftest.py b/libs/partners/ai21/tests/unit_tests/conftest.py index 858417a7da7..d8b032ac6e2 100644 --- a/libs/partners/ai21/tests/unit_tests/conftest.py +++ b/libs/partners/ai21/tests/unit_tests/conftest.py @@ -16,12 +16,13 @@ from ai21.models import ( FinishReason, Penalty, RoleType, + SegmentationResponse, ) +from ai21.models.responses.segmentation_response import Segment from pytest_mock import MockerFixture DUMMY_API_KEY = "test_api_key" - BASIC_EXAMPLE_LLM_PARAMETERS = { "num_results": 3, "max_tokens": 20, @@ -38,6 +39,32 @@ BASIC_EXAMPLE_LLM_PARAMETERS = { ), } +SEGMENTS = [ + Segment( + segment_type="normal_text", + segment_text=( + "The original full name of the franchise is Pocket Monsters " + "(ポケットモンスター, Poketto Monsutā), which was abbreviated to " + "Pokemon during development of the original games.\n\nWhen the " + "franchise was released internationally, the short form of the " + "title was used, with an acute accent (´) over the e to aid " + "in pronunciation." + ), + ), + Segment( + segment_type="normal_text", + segment_text=( + "Pokémon refers to both the franchise itself and the creatures " + "within its fictional universe.\n\nAs a noun, it is identical in " + "both the singular and plural, as is every individual species " + 'name;[10] it is grammatically correct to say "one Pokémon" ' + 'and "many Pokémon", as well as "one Pikachu" and "many ' + 'Pikachu".\n\nIn English, Pokémon may be pronounced either ' + "/'powkɛmon/ (poe-keh-mon) or /'powkɪmon/ (poe-key-mon)." + ), + ), +] + BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT = { "num_results": 3, @@ -125,3 +152,15 @@ def mock_client_with_contextual_answers(mocker: MockerFixture) -> Mock: ) return mock_client + + +@pytest.fixture +def mock_client_with_semantic_text_splitter(mocker: MockerFixture) -> Mock: + mock_client = mocker.MagicMock(spec=AI21Client) + mock_client.segmentation = mocker.MagicMock() + mock_client.segmentation.create.return_value = SegmentationResponse( + id="12345", + segments=SEGMENTS, + ) + + return mock_client diff --git a/libs/partners/ai21/tests/unit_tests/test_imports.py b/libs/partners/ai21/tests/unit_tests/test_imports.py index 92577219c38..0130415bbff 100644 --- a/libs/partners/ai21/tests/unit_tests/test_imports.py +++ b/libs/partners/ai21/tests/unit_tests/test_imports.py @@ -5,6 +5,7 @@ EXPECTED_ALL = [ "ChatAI21", "AI21Embeddings", "AI21ContextualAnswers", + "AI21SemanticTextSplitter", ] diff --git a/libs/partners/ai21/tests/unit_tests/test_semantic_text_splitter.py b/libs/partners/ai21/tests/unit_tests/test_semantic_text_splitter.py new file mode 100644 index 00000000000..1beb4fba4d8 --- /dev/null +++ b/libs/partners/ai21/tests/unit_tests/test_semantic_text_splitter.py @@ -0,0 +1,129 @@ +from unittest.mock import Mock + +import pytest + +from langchain_ai21 import AI21SemanticTextSplitter +from tests.unit_tests.conftest import SEGMENTS + +TEXT = ( + "The original full name of the franchise is Pocket Monsters (ポケットモンスター, " + "Poketto Monsutā), which was abbreviated to " + "Pokemon during development of the original games.\n" + "When the franchise was released internationally, the short form of the title was " + "used, with an acute accent (´) " + "over the e to aid in pronunciation.\n" + "Pokémon refers to both the franchise itself and the creatures within its " + "fictional universe.\n" + "As a noun, it is identical in both the singular and plural, as is every " + "individual species name;[10] it is " + 'grammatically correct to say "one Pokémon" and "many 
Pokémon", as well ' + 'as "one Pikachu" and "many Pikachu".\n' + "In English, Pokémon may be pronounced either /'powkɛmon/ (poe-keh-mon) or " + "/'powkɪmon/ (poe-key-mon).\n" + "The Pokémon franchise is set in a world in which humans coexist with creatures " + "known as Pokémon.\n" + "Pokémon Red and Blue contain 151 Pokémon species, with new ones being introduced " + "in subsequent games; as of December 2023, 1,025 Pokémon species have been " + "introduced.\n[b] Most Pokémon are inspired by real-world animals;[12] for example," + "Pikachu are a yellow mouse-like species[13] with lightning bolt-shaped tails[14] " + "that possess electrical abilities.[15]" +) + + +@pytest.mark.parametrize( + ids=[ + "when_chunk_size_is_zero", + "when_chunk_size_is_large", + "when_chunk_size_is_small", + ], + argnames=["chunk_size", "expected_segmentation_len"], + argvalues=[ + (0, 2), + (1000, 1), + (10, 2), + ], +) +def test_split_text__on_chunk_size( + chunk_size: int, + expected_segmentation_len: int, + mock_client_with_semantic_text_splitter: Mock, +) -> None: + sts = AI21SemanticTextSplitter( + chunk_size=chunk_size, + client=mock_client_with_semantic_text_splitter, + ) + segments = sts.split_text("This is a test") + assert len(segments) == expected_segmentation_len + + +def test_split_text__on_large_chunk_size__should_merge_chunks( + mock_client_with_semantic_text_splitter: Mock, +) -> None: + sts_no_merge = AI21SemanticTextSplitter( + client=mock_client_with_semantic_text_splitter + ) + sts_merge = AI21SemanticTextSplitter( + client=mock_client_with_semantic_text_splitter, + chunk_size=1000, + ) + segments_no_merge = sts_no_merge.split_text("This is a test") + segments_merge = sts_merge.split_text("This is a test") + assert len(segments_merge) > 0 + assert len(segments_no_merge) > 0 + assert len(segments_no_merge) > len(segments_merge) + + +def test_split_text__on_small_chunk_size__should_not_merge_chunks( + mock_client_with_semantic_text_splitter: Mock, +) -> None: + sts_no_merge = AI21SemanticTextSplitter( + client=mock_client_with_semantic_text_splitter + ) + segments = sts_no_merge.split_text("This is a test") + assert len(segments) == 2 + for index in range(2): + assert segments[index] == SEGMENTS[index].segment_text + + +def test_create_documents__on_start_index__should_should_add_start_index( + mock_client_with_semantic_text_splitter: Mock, +) -> None: + sts = AI21SemanticTextSplitter( + client=mock_client_with_semantic_text_splitter, + add_start_index=True, + ) + + response = sts.create_documents(texts=[TEXT]) + assert len(response) > 0 + for segment in response: + assert segment.page_content is not None + assert segment.metadata is not None + assert "start_index" in segment.metadata + assert segment.metadata["start_index"] > -1 + + +def test_create_documents__when_metadata_from_user__should_add_metadata( + mock_client_with_semantic_text_splitter: Mock, +) -> None: + sts = AI21SemanticTextSplitter(client=mock_client_with_semantic_text_splitter) + metadatas = [{"hello": "world"}] + response = sts.create_documents(texts=[TEXT], metadatas=metadatas) + assert len(response) > 0 + for index in range(len(response)): + assert response[index].page_content == SEGMENTS[index].segment_text + assert len(response[index].metadata) == 2 + assert response[index].metadata["source_type"] == SEGMENTS[index].segment_type + assert response[index].metadata["hello"] == "world" + + +def test_split_text_to_documents__when_metadata_not_passed__should_contain_source_type( + mock_client_with_semantic_text_splitter: Mock, 
+) -> None: + sts = AI21SemanticTextSplitter(client=mock_client_with_semantic_text_splitter) + response = sts.split_text_to_documents(TEXT) + assert len(response) > 0 + for segment in response: + assert segment.page_content is not None + assert segment.metadata is not None + assert "source_type" in segment.metadata + assert segment.metadata["source_type"] is not None
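
The snippets in this patch show how text gets split; a natural next step is feeding the resulting documents into a retriever. Below is a minimal, hypothetical sketch (not part of this PR) that pairs `AI21SemanticTextSplitter` with `AI21Embeddings` from the same package and a FAISS vector store; it assumes `langchain-community` and `faiss-cpu` are installed and that `AI21_API_KEY` is set in the environment.

```python
# Hypothetical usage sketch (not included in this PR): index AI21-segmented
# documents in a FAISS vector store and retrieve them.
# Assumes: pip install langchain-ai21 langchain-community faiss-cpu
# and that AI21_API_KEY is set in the environment.
from langchain_ai21 import AI21Embeddings, AI21SemanticTextSplitter
from langchain_community.vectorstores import FAISS

# Split the source text into topic-coherent documents, merging segments
# until they approach roughly 1000 characters each.
splitter = AI21SemanticTextSplitter(chunk_size=1000)
documents = splitter.create_documents(texts=["Your long source text goes here."])

# Embed and index the documents, then retrieve the most relevant ones.
vector_store = FAISS.from_documents(documents, AI21Embeddings())
retriever = vector_store.as_retriever(search_kwargs={"k": 2})
print(retriever.get_relevant_documents("Your question"))
```

Setting `chunk_size=1000` here keeps segments large enough to embed efficiently while preserving the topic boundaries produced by the segmentation API; leaving it at the default of 0 would index each raw segment as its own document.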