ai21[minor]: AI21 Labs Semantic Text Splitter support (#19510)

Description: Added support for the AI21 Labs Segmentation model as a text splitter.
Dependencies: ai21, langchain-text-splitters
Twitter handle: https://github.com/AI21Labs
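
A minimal usage sketch (illustrative; assumes the `AI21_API_KEY` environment variable is set):

```python
from langchain_ai21 import AI21SemanticTextSplitter

splitter = AI21SemanticTextSplitter()
chunks = splitter.split_text("Your long text goes here...")
print(f"Split into {len(chunks)} chunks.")
```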

---------

Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
miri-bar 2024-03-26 03:39:37 +02:00 committed by GitHub
parent b2a11ce686
commit 55db737302
11 changed files with 976 additions and 15 deletions

View File

@ -0,0 +1,466 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "b9bba344bbe0b4bd",
"metadata": {
"collapsed": false
},
"source": [
"# AI21SemanticTextSplitter\n",
"\n",
"This example goes over how to use AI21SemanticTextSplitter in LangChain."
]
},
{
"cell_type": "markdown",
"id": "d8e4cdb63fbc34ec",
"metadata": {
"collapsed": false
},
"source": [
"## Installation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b09bb1cd2c7e036a",
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"pip install langchain-ai21"
]
},
{
"cell_type": "markdown",
"id": "ba1d80fe8d82be89",
"metadata": {
"collapsed": false
},
"source": [
"## Environment Setup\n",
"\n",
"We'll need to get a AI21 API key and set the AI21_API_KEY environment variable:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "844b8f744d22bcb6",
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import os\n",
"from getpass import getpass\n",
"\n",
"os.environ[\"AI21_API_KEY\"] = getpass()"
]
},
{
"cell_type": "markdown",
"id": "3e670b278e6b2b9e",
"metadata": {
"collapsed": false
},
"source": [
"## Example Usages"
]
},
{
"cell_type": "markdown",
"id": "f61c5c981f01ad31",
"metadata": {
"collapsed": false
},
"source": [
"### Splitting text by semantic meaning"
]
},
{
"cell_type": "markdown",
"id": "e7da988112712cf3",
"metadata": {
"collapsed": false
},
"source": [
"This example shows how to use AI21SemanticTextSplitter to split a text into chunks based on semantic meaning."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1d82b65c9b8684f3",
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from langchain_ai21 import AI21SemanticTextSplitter\n",
"\n",
"TEXT = (\n",
" \"Weve all experienced reading long, tedious, and boring pieces of text - financial reports, \"\n",
" \"legal documents, or terms and conditions (though, who actually reads those terms and conditions to be honest?).\\n\"\n",
" \"Imagine a company that employs hundreds of thousands of employees. In today's information \"\n",
" \"overload age, nearly 30% of the workday is spent dealing with documents. There's no surprise \"\n",
" \"here, given that some of these documents are long and convoluted on purpose (did you know that \"\n",
" \"reading through all your privacy policies would take almost a quarter of a year?). Aside from \"\n",
" \"inefficiency, workers may simply refrain from reading some documents (for example, Only 16% of \"\n",
" \"Employees Read Their Employment Contracts Entirely Before Signing!).\\nThis is where AI-driven summarization \"\n",
" \"tools can be helpful: instead of reading entire documents, which is tedious and time-consuming, \"\n",
" \"users can (ideally) quickly extract relevant information from a text. With large language models, \"\n",
" \"the development of those tools is easier than ever, and you can offer your users a summary that is \"\n",
" \"specifically tailored to their preferences.\\nLarge language models naturally follow patterns in input \"\n",
" \"(prompt), and provide coherent completion that follows the same patterns. For that, we want to feed \"\n",
" 'them with several examples in the input (\"few-shot prompt\"), so they can follow through. '\n",
" \"The process of creating the correct prompt for your problem is called prompt engineering, \"\n",
" \"and you can read more about it here.\"\n",
")\n",
"\n",
"semantic_text_splitter = AI21SemanticTextSplitter()\n",
"chunks = semantic_text_splitter.split_text(TEXT)\n",
"\n",
"print(f\"The text has been split into {len(chunks)} chunks.\")\n",
"for chunk in chunks:\n",
" print(chunk)\n",
" print(\"====\")"
]
},
{
"cell_type": "markdown",
"id": "2e8d1fcf818a8a81",
"metadata": {
"collapsed": false
},
"source": [
"### Splitting text by semantic meaning with merge"
]
},
{
"cell_type": "markdown",
"id": "c307abbc216fe89f",
"metadata": {
"collapsed": false
},
"source": [
"This example shows how to use AI21SemanticTextSplitter to split a text into chunks based on semantic meaning, then merging the chunks based on `chunk_size`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5651c581fcc1ff02",
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from langchain_ai21 import AI21SemanticTextSplitter\n",
"\n",
"TEXT = (\n",
" \"Weve all experienced reading long, tedious, and boring pieces of text - financial reports, \"\n",
" \"legal documents, or terms and conditions (though, who actually reads those terms and conditions to be honest?).\\n\"\n",
" \"Imagine a company that employs hundreds of thousands of employees. In today's information \"\n",
" \"overload age, nearly 30% of the workday is spent dealing with documents. There's no surprise \"\n",
" \"here, given that some of these documents are long and convoluted on purpose (did you know that \"\n",
" \"reading through all your privacy policies would take almost a quarter of a year?). Aside from \"\n",
" \"inefficiency, workers may simply refrain from reading some documents (for example, Only 16% of \"\n",
" \"Employees Read Their Employment Contracts Entirely Before Signing!).\\nThis is where AI-driven summarization \"\n",
" \"tools can be helpful: instead of reading entire documents, which is tedious and time-consuming, \"\n",
" \"users can (ideally) quickly extract relevant information from a text. With large language models, \"\n",
" \"the development of those tools is easier than ever, and you can offer your users a summary that is \"\n",
" \"specifically tailored to their preferences.\\nLarge language models naturally follow patterns in input \"\n",
" \"(prompt), and provide coherent completion that follows the same patterns. For that, we want to feed \"\n",
" 'them with several examples in the input (\"few-shot prompt\"), so they can follow through. '\n",
" \"The process of creating the correct prompt for your problem is called prompt engineering, \"\n",
" \"and you can read more about it here.\"\n",
")\n",
"\n",
"semantic_text_splitter_chunks = AI21SemanticTextSplitter(chunk_size=1000)\n",
"chunks = semantic_text_splitter_chunks.split_text(TEXT)\n",
"\n",
"print(f\"The text has been split into {len(chunks)} chunks.\")\n",
"for chunk in chunks:\n",
" print(chunk)\n",
" print(\"====\")"
]
},
{
"cell_type": "markdown",
"id": "b464db855e547cbb",
"metadata": {
"collapsed": false
},
"source": [
"### Splitting text to documents"
]
},
{
"cell_type": "markdown",
"id": "4410e8467012b193",
"metadata": {
"collapsed": false
},
"source": [
"This example shows how to use AI21SemanticTextSplitter to split a text into Documents based on semantic meaning. The metadata will contain a type for each document."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3cf131d9be910115",
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from langchain_ai21 import AI21SemanticTextSplitter\n",
"\n",
"TEXT = (\n",
" \"Weve all experienced reading long, tedious, and boring pieces of text - financial reports, \"\n",
" \"legal documents, or terms and conditions (though, who actually reads those terms and conditions to be honest?).\\n\"\n",
" \"Imagine a company that employs hundreds of thousands of employees. In today's information \"\n",
" \"overload age, nearly 30% of the workday is spent dealing with documents. There's no surprise \"\n",
" \"here, given that some of these documents are long and convoluted on purpose (did you know that \"\n",
" \"reading through all your privacy policies would take almost a quarter of a year?). Aside from \"\n",
" \"inefficiency, workers may simply refrain from reading some documents (for example, Only 16% of \"\n",
" \"Employees Read Their Employment Contracts Entirely Before Signing!).\\nThis is where AI-driven summarization \"\n",
" \"tools can be helpful: instead of reading entire documents, which is tedious and time-consuming, \"\n",
" \"users can (ideally) quickly extract relevant information from a text. With large language models, \"\n",
" \"the development of those tools is easier than ever, and you can offer your users a summary that is \"\n",
" \"specifically tailored to their preferences.\\nLarge language models naturally follow patterns in input \"\n",
" \"(prompt), and provide coherent completion that follows the same patterns. For that, we want to feed \"\n",
" 'them with several examples in the input (\"few-shot prompt\"), so they can follow through. '\n",
" \"The process of creating the correct prompt for your problem is called prompt engineering, \"\n",
" \"and you can read more about it here.\"\n",
")\n",
"\n",
"semantic_text_splitter = AI21SemanticTextSplitter()\n",
"documents = semantic_text_splitter.split_text_to_documents(TEXT)\n",
"\n",
"print(f\"The text has been split into {len(documents)} Documents.\")\n",
"for doc in documents:\n",
" print(f\"type: {doc.metadata['source_type']}\")\n",
" print(f\"text: {doc.page_content}\")\n",
" print(\"====\")"
]
},
{
"cell_type": "markdown",
"id": "b544ba21335d01a6",
"metadata": {
"collapsed": false
},
"source": [
"### Creating Documents with Metadata"
]
},
{
"cell_type": "markdown",
"id": "c67f8c3ad89b8ad2",
"metadata": {
"collapsed": false
},
"source": [
"This example shows how to use AI21SemanticTextSplitter to create Documents from texts, and adding custom Metadata to each Document."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fe222d0e85249bda",
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from langchain_ai21 import AI21SemanticTextSplitter\n",
"\n",
"TEXT = (\n",
" \"Weve all experienced reading long, tedious, and boring pieces of text - financial reports, \"\n",
" \"legal documents, or terms and conditions (though, who actually reads those terms and conditions to be honest?).\\n\"\n",
" \"Imagine a company that employs hundreds of thousands of employees. In today's information \"\n",
" \"overload age, nearly 30% of the workday is spent dealing with documents. There's no surprise \"\n",
" \"here, given that some of these documents are long and convoluted on purpose (did you know that \"\n",
" \"reading through all your privacy policies would take almost a quarter of a year?). Aside from \"\n",
" \"inefficiency, workers may simply refrain from reading some documents (for example, Only 16% of \"\n",
" \"Employees Read Their Employment Contracts Entirely Before Signing!).\\nThis is where AI-driven summarization \"\n",
" \"tools can be helpful: instead of reading entire documents, which is tedious and time-consuming, \"\n",
" \"users can (ideally) quickly extract relevant information from a text. With large language models, \"\n",
" \"the development of those tools is easier than ever, and you can offer your users a summary that is \"\n",
" \"specifically tailored to their preferences.\\nLarge language models naturally follow patterns in input \"\n",
" \"(prompt), and provide coherent completion that follows the same patterns. For that, we want to feed \"\n",
" 'them with several examples in the input (\"few-shot prompt\"), so they can follow through. '\n",
" \"The process of creating the correct prompt for your problem is called prompt engineering, \"\n",
" \"and you can read more about it here.\"\n",
")\n",
"\n",
"semantic_text_splitter = AI21SemanticTextSplitter()\n",
"texts = [TEXT]\n",
"documents = semantic_text_splitter.create_documents(\n",
" texts=texts, metadatas=[{\"pikachu\": \"pika pika\"}]\n",
")\n",
"\n",
"print(f\"The text has been split into {len(documents)} Documents.\")\n",
"for doc in documents:\n",
" print(f\"metadata: {doc.metadata}\")\n",
" print(f\"text: {doc.page_content}\")\n",
" print(\"====\")"
]
},
{
"cell_type": "markdown",
"id": "f8b5682c34142319",
"metadata": {
"collapsed": false
},
"source": [
"### Splitting text to documents with start index"
]
},
{
"cell_type": "markdown",
"id": "359ea797c03ece85",
"metadata": {
"collapsed": false
},
"source": [
"This example shows how to use AI21SemanticTextSplitter to split a text into Documents based on semantic meaning. The metadata will contain a start index for each document.\n",
"**Note** that the start index provides an indication of the order of the chunks rather than the actual start index for each chunk."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2dc39002f0c25784",
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from langchain_ai21 import AI21SemanticTextSplitter\n",
"\n",
"TEXT = (\n",
" \"Weve all experienced reading long, tedious, and boring pieces of text - financial reports, \"\n",
" \"legal documents, or terms and conditions (though, who actually reads those terms and conditions to be honest?).\\n\"\n",
" \"Imagine a company that employs hundreds of thousands of employees. In today's information \"\n",
" \"overload age, nearly 30% of the workday is spent dealing with documents. There's no surprise \"\n",
" \"here, given that some of these documents are long and convoluted on purpose (did you know that \"\n",
" \"reading through all your privacy policies would take almost a quarter of a year?). Aside from \"\n",
" \"inefficiency, workers may simply refrain from reading some documents (for example, Only 16% of \"\n",
" \"Employees Read Their Employment Contracts Entirely Before Signing!).\\nThis is where AI-driven summarization \"\n",
" \"tools can be helpful: instead of reading entire documents, which is tedious and time-consuming, \"\n",
" \"users can (ideally) quickly extract relevant information from a text. With large language models, \"\n",
" \"the development of those tools is easier than ever, and you can offer your users a summary that is \"\n",
" \"specifically tailored to their preferences.\\nLarge language models naturally follow patterns in input \"\n",
" \"(prompt), and provide coherent completion that follows the same patterns. For that, we want to feed \"\n",
" 'them with several examples in the input (\"few-shot prompt\"), so they can follow through. '\n",
" \"The process of creating the correct prompt for your problem is called prompt engineering, \"\n",
" \"and you can read more about it here.\"\n",
")\n",
"\n",
"semantic_text_splitter = AI21SemanticTextSplitter(add_start_index=True)\n",
"documents = semantic_text_splitter.create_documents(texts=[TEXT])\n",
"print(f\"The text has been split into {len(documents)} Documents.\")\n",
"for doc in documents:\n",
" print(f\"start_index: {doc.metadata['start_index']}\")\n",
" print(f\"text: {doc.page_content}\")\n",
" print(\"====\")"
]
},
{
"cell_type": "markdown",
"id": "b62939cc5803b9fb",
"metadata": {
"collapsed": false
},
"source": [
"### Splitting documents"
]
},
{
"cell_type": "markdown",
"id": "44162d340c0de5fb",
"metadata": {
"collapsed": false
},
"source": [
"This example shows how to use AI21SemanticTextSplitter to split a list of Documents into chunks based on semantic meaning."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8950c8e4e1208bf6",
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from langchain_ai21 import AI21SemanticTextSplitter\n",
"from langchain_core.documents import Document\n",
"\n",
"TEXT = (\n",
" \"Weve all experienced reading long, tedious, and boring pieces of text - financial reports, \"\n",
" \"legal documents, or terms and conditions (though, who actually reads those terms and conditions to be honest?).\\n\"\n",
" \"Imagine a company that employs hundreds of thousands of employees. In today's information \"\n",
" \"overload age, nearly 30% of the workday is spent dealing with documents. There's no surprise \"\n",
" \"here, given that some of these documents are long and convoluted on purpose (did you know that \"\n",
" \"reading through all your privacy policies would take almost a quarter of a year?). Aside from \"\n",
" \"inefficiency, workers may simply refrain from reading some documents (for example, Only 16% of \"\n",
" \"Employees Read Their Employment Contracts Entirely Before Signing!).\\nThis is where AI-driven summarization \"\n",
" \"tools can be helpful: instead of reading entire documents, which is tedious and time-consuming, \"\n",
" \"users can (ideally) quickly extract relevant information from a text. With large language models, \"\n",
" \"the development of those tools is easier than ever, and you can offer your users a summary that is \"\n",
" \"specifically tailored to their preferences.\\nLarge language models naturally follow patterns in input \"\n",
" \"(prompt), and provide coherent completion that follows the same patterns. For that, we want to feed \"\n",
" 'them with several examples in the input (\"few-shot prompt\"), so they can follow through. '\n",
" \"The process of creating the correct prompt for your problem is called prompt engineering, \"\n",
" \"and you can read more about it here.\"\n",
")\n",
"\n",
"semantic_text_splitter = AI21SemanticTextSplitter()\n",
"document = Document(page_content=TEXT, metadata={\"hello\": \"goodbye\"})\n",
"documents = semantic_text_splitter.split_documents([document])\n",
"print(f\"The document list has been split into {len(documents)} Documents.\")\n",
"for doc in documents:\n",
" print(f\"text: {doc.page_content}\")\n",
" print(f\"metadata: {doc.metadata}\")\n",
" print(\"====\")"
]
},
{
"cell_type": "markdown",
"id": "f8f911b8d9ec22e5",
"metadata": {
"collapsed": false
},
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -44,6 +44,7 @@ LangChain offers many different types of text splitters. These all live in the `
| Token | Tokens | | Splits text on tokens. There exist a few different ways to measure tokens. |
| Character | A user defined character | | Splits text based on a user defined character. One of the simpler methods. |
| [Experimental] Semantic Chunker | Sentences | | First splits on sentences. Then combines ones next to each other if they are semantically similar enough. Taken from [Greg Kamradt](https://github.com/FullStackRetrieval-com/RetrievalTutorials/blob/main/5_Levels_Of_Text_Splitting.ipynb) |
| [AI21 Semantic Text Splitter](/docs/integrations/document_transformers/ai21_semantic_text_splitter) | Semantics | ✅ | Identifies distinct topics that form coherent pieces of text and splits along those. |
## Evaluate text splitters

View File

@ -102,4 +102,20 @@ chain = tsm | StrOutputParser()
response = chain.invoke(
{"context": "Your context", "question": "Your question"},
)
```
## Text Splitters
### Semantic Text Splitter
You can use AI21's semantic text splitter to split a text into segments.
Instead of merely using punctuation and newlines to divide the text, it identifies distinct topics that form coherent pieces of text and splits along them.
For a full list of examples, see [this page](https://github.com/langchain-ai/langchain/blob/master/docs/docs/modules/data_connection/document_transformers/semantic_text_splitter.ipynb).
```python
from langchain_ai21 import AI21SemanticTextSplitter
splitter = AI21SemanticTextSplitter()
response = splitter.split_text("Your text")
```
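
The splitter can also return LangChain `Document` objects; each one's metadata carries the segment's `source_type`. A short sketch using the methods added in this PR:

```python
from langchain_ai21 import AI21SemanticTextSplitter

splitter = AI21SemanticTextSplitter()
documents = splitter.split_text_to_documents("Your text")
for doc in documents:
    print(doc.metadata["source_type"], doc.page_content)
```

Passing `chunk_size` to the constructor (for example, `AI21SemanticTextSplitter(chunk_size=1000)`) makes `split_text` merge adjacent segments until each chunk approaches that size.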

View File

@ -2,10 +2,12 @@ from langchain_ai21.chat_models import ChatAI21
from langchain_ai21.contextual_answers import AI21ContextualAnswers
from langchain_ai21.embeddings import AI21Embeddings
from langchain_ai21.llms import AI21LLM
from langchain_ai21.semantic_text_splitter import AI21SemanticTextSplitter
__all__ = [
"AI21LLM",
"ChatAI21",
"AI21Embeddings",
"AI21ContextualAnswers",
"AI21SemanticTextSplitter",
]

View File

@ -0,0 +1,158 @@
import copy
import logging
import re
from typing import (
Any,
Iterable,
List,
Optional,
)
from ai21.models import DocumentType
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import SecretStr
from langchain_text_splitters import TextSplitter
from langchain_ai21.ai21_base import AI21Base
logger = logging.getLogger(__name__)
class AI21SemanticTextSplitter(TextSplitter):
"""Splitting text into coherent and readable units,
based on distinct topics and lines
"""
def __init__(
self,
chunk_size: int = 0,
chunk_overlap: int = 0,
client: Optional[Any] = None,
api_key: Optional[SecretStr] = None,
api_host: Optional[str] = None,
timeout_sec: Optional[float] = None,
num_retries: Optional[int] = None,
**kwargs: Any,
) -> None:
"""Create a new TextSplitter."""
super().__init__(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
**kwargs,
)
self._segmentation = AI21Base(
client=client,
api_key=api_key,
api_host=api_host,
timeout_sec=timeout_sec,
num_retries=num_retries,
).client.segmentation
def split_text(self, source: str) -> List[str]:
"""Split text into multiple components.
Args:
source: Specifies the text input for text segmentation
"""
response = self._segmentation.create(
source=source, source_type=DocumentType.TEXT
)
segments = [segment.segment_text for segment in response.segments]
if self._chunk_size > 0:
return self._merge_splits_no_seperator(segments)
return segments
def split_text_to_documents(self, source: str) -> List[Document]:
"""Split text into multiple documents.
Args:
source: Specifies the text input for text segmentation
"""
response = self._segmentation.create(
source=source, source_type=DocumentType.TEXT
)
return [
Document(
page_content=segment.segment_text,
metadata={"source_type": segment.segment_type},
)
for segment in response.segments
]
def create_documents(
self, texts: List[str], metadatas: Optional[List[dict]] = None
) -> List[Document]:
"""Create documents from a list of texts."""
_metadatas = metadatas or [{}] * len(texts)
documents = []
for i, text in enumerate(texts):
normalized_text = self._normalized_text(text)
index = 0
previous_chunk_len = 0
for chunk in self.split_text_to_documents(text):
# merge metadata from user (if exists) and from segmentation api
metadata = copy.deepcopy(_metadatas[i])
metadata.update(chunk.metadata)
if self._add_start_index:
# find the start index of the chunk
offset = index + previous_chunk_len - self._chunk_overlap
normalized_chunk = self._normalized_text(chunk.page_content)
index = normalized_text.find(normalized_chunk, max(0, offset))
metadata["start_index"] = index
previous_chunk_len = len(normalized_chunk)
documents.append(
Document(
page_content=chunk.page_content,
metadata=metadata,
)
)
return documents
def _normalized_text(self, string: str) -> str:
"""Use regular expression to replace sequences of '\n'"""
return re.sub(r"\s+", "", string)
def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
"""This method overrides the default implementation of TextSplitter"""
return self._merge_splits_no_seperator(splits)
def _merge_splits_no_seperator(self, splits: Iterable[str]) -> List[str]:
"""Merge splits into chunks.
If a segment is longer than chunk_size,
it will be left as is (it won't be cut to match chunk_size).
If a segment is shorter than chunk_size,
it will be merged with the following segments until chunk_size is reached.
"""
chunks = []
current_chunk = ""
for split in splits:
split_len = self._length_function(split)
if split_len > self._chunk_size:
logger.warning(
f"Split of length {split_len}"
f"exceeds chunk size {self._chunk_size}."
)
if self._length_function(current_chunk) + split_len > self._chunk_size:
if current_chunk != "":
chunks.append(current_chunk)
current_chunk = ""
current_chunk += split
if current_chunk != "":
chunks.append(current_chunk)
return chunks

View File

@ -300,7 +300,7 @@ files = [
[[package]]
name = "langchain-core"
version = "0.1.30"
version = "0.1.33"
description = "Building applications with LLMs through composability"
optional = false
python-versions = ">=3.8.1,<4.0"
@ -324,15 +324,32 @@ extended-testing = ["jinja2 (>=3,<4)"]
type = "directory"
url = "../../core"
[[package]]
name = "langchain-text-splitters"
version = "0.0.1"
description = "LangChain text splitting utilities"
optional = false
python-versions = ">=3.8.1,<4.0"
files = [
{file = "langchain_text_splitters-0.0.1-py3-none-any.whl", hash = "sha256:f5b802f873f5ff6a8b9259ff34d53ed989666ef4e1582e6d1adb3b5520e3839a"},
{file = "langchain_text_splitters-0.0.1.tar.gz", hash = "sha256:ac459fa98799f5117ad5425a9330b21961321e30bc19a2a2f9f761ddadd62aa1"},
]
[package.dependencies]
langchain-core = ">=0.1.28,<0.2.0"
[package.extras]
extended-testing = ["lxml (>=5.1.0,<6.0.0)"]
[[package]]
name = "langsmith"
version = "0.1.10"
version = "0.1.23"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
optional = false
python-versions = ">=3.8.1,<4.0"
files = [
{file = "langsmith-0.1.10-py3-none-any.whl", hash = "sha256:2997a80aea60ed235d83502a7ccdc1f62ffb4dd6b3b7dd4218e8fa4de68a6725"},
{file = "langsmith-0.1.10.tar.gz", hash = "sha256:13e7e8b52e694aa4003370cefbb9e79cce3540c65dbf1517902bf7aa4dbbb653"},
{file = "langsmith-0.1.23-py3-none-any.whl", hash = "sha256:69984268b9867cb31b875965b3f86b6f56ba17dd5454d487d3a1a999bdaeea69"},
{file = "langsmith-0.1.23.tar.gz", hash = "sha256:327c66ec0de8c1bc57bfa47bbc70a29ef749e97c3e5571b9baf754d1e0644220"},
]
[package.dependencies]
@ -342,13 +359,13 @@ requests = ">=2,<3"
[[package]]
name = "marshmallow"
version = "3.21.0"
version = "3.21.1"
description = "A lightweight library for converting complex datatypes to and from native Python datatypes."
optional = false
python-versions = ">=3.8"
files = [
{file = "marshmallow-3.21.0-py3-none-any.whl", hash = "sha256:e7997f83571c7fd476042c2c188e4ee8a78900ca5e74bd9c8097afa56624e9bd"},
{file = "marshmallow-3.21.0.tar.gz", hash = "sha256:20f53be28c6e374a711a16165fb22a8dc6003e3f7cda1285e3ca777b9193885b"},
{file = "marshmallow-3.21.1-py3-none-any.whl", hash = "sha256:f085493f79efb0644f270a9bf2892843142d80d7174bbbd2f3713f2a589dc633"},
{file = "marshmallow-3.21.1.tar.gz", hash = "sha256:4e65e9e0d80fc9e609574b9983cf32579f305c718afb30d7233ab818571768c3"},
]
[package.dependencies]
@ -507,13 +524,13 @@ testing = ["pytest", "pytest-benchmark"]
[[package]]
name = "pydantic"
version = "2.6.3"
version = "2.6.4"
description = "Data validation using Python type hints"
optional = false
python-versions = ">=3.8"
files = [
{file = "pydantic-2.6.3-py3-none-any.whl", hash = "sha256:72c6034df47f46ccdf81869fddb81aade68056003900a8724a4f160700016a2a"},
{file = "pydantic-2.6.3.tar.gz", hash = "sha256:e07805c4c7f5c6826e33a1d4c9d47950d7eaf34868e2690f8594d2e30241f11f"},
{file = "pydantic-2.6.4-py3-none-any.whl", hash = "sha256:cc46fce86607580867bdc3361ad462bab9c222ef042d3da86f2fb333e1d916c5"},
{file = "pydantic-2.6.4.tar.gz", hash = "sha256:b1704e0847db01817624a6b86766967f552dd9dbf3afba4004409f908dcc84e6"},
]
[package.dependencies]
@ -689,13 +706,13 @@ watchdog = ">=2.0.0"
[[package]]
name = "python-dateutil"
version = "2.8.2"
version = "2.9.0.post0"
description = "Extensions to the standard Python datetime module"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
files = [
{file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"},
{file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"},
{file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"},
{file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"},
]
[package.dependencies]
@ -726,6 +743,7 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@ -1009,4 +1027,4 @@ watchmedo = ["PyYAML (>=3.10)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "3073522be06765f2acb7efea6ed1fcc49eaa05e82534d96fa914899dbbbb541f"
content-hash = "6ba91e0cf81e177c01efe980cbeedc2fe5a267599ce91c15acbcf2cd34df33dc"

View File

@ -8,6 +8,7 @@ readme = "README.md"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = "^0.1.22"
langchain-text-splitters = "^0.0.1"
ai21 = "^2.1.2"
[tool.poetry.group.test]

View File

@ -0,0 +1,130 @@
from ai21 import AI21Client
from langchain_core.documents import Document
from langchain_ai21 import AI21SemanticTextSplitter
TEXT = (
"The original full name of the franchise is Pocket Monsters (ポケットモンスター, "
"Poketto Monsutā), which was abbreviated to "
"Pokemon during development of the original games.\n"
"When the franchise was released internationally, the short form of the title was "
"used, with an acute accent (´) "
"over the e to aid in pronunciation.\n"
"Pokémon refers to both the franchise itself and the creatures within its "
"fictional universe.\n"
"As a noun, it is identical in both the singular and plural, as is every "
"individual species name;[10] it is "
'grammatically correct to say "one Pokémon" and "many Pokémon", as well '
'as "one Pikachu" and "many Pikachu".\n'
"In English, Pokémon may be pronounced either /'powkɛmon/ (poe-keh-mon) or "
"/'powkɪmon/ (poe-key-mon).\n"
"The Pokémon franchise is set in a world in which humans coexist with creatures "
"known as Pokémon.\n"
"Pokémon Red and Blue contain 151 Pokémon species, with new ones being introduced "
"in subsequent games; as of December 2023, 1,025 Pokémon species have been "
"introduced.\n[b] Most Pokémon are inspired by real-world animals;[12] for example,"
"Pikachu are a yellow mouse-like species[13] with lightning bolt-shaped tails[14] "
"that possess electrical abilities.[15]\nThe player character takes the role of a "
"Pokémon Trainer.\nThe Trainer has three primary goals: travel and explore the "
"Pokémon world; discover and catch each Pokémon species in order to complete their"
"Pokédex; and train a team of up to six Pokémon at a time and have them engage "
"in battles.\nMost Pokémon can be caught with spherical devices known as Poké "
"Balls.\nOnce the opposing Pokémon is sufficiently weakened, the Trainer throws "
"the Poké Ball against the Pokémon, which is then transformed into a form of "
"energy and transported into the device.\nOnce the catch is successful, "
"the Pokémon is tamed and is under the Trainer's command from then on.\n"
"If the Poké Ball is thrown again, the Pokémon re-materializes into its "
"original state.\nThe Trainer's Pokémon can engage in battles against opposing "
"Pokémon, including those in the wild or owned by other Trainers.\nBecause the "
"franchise is aimed at children, these battles are never presented as overtly "
"violent and contain no blood or gore.[I]\nPokémon never die in battle, instead "
"fainting upon being defeated.[20][21][22]\nAfter a Pokémon wins a battle, it "
"gains experience and becomes stronger.[23] After gaining a certain amount of "
"experience points, its level increases, as well as one or more of its "
"statistics.\nAs its level increases, the Pokémon can learn new offensive "
"and defensive moves to use in battle.[24][25] Furthermore, many species can "
"undergo a form of spontaneous metamorphosis called Pokémon evolution, and "
"transform into stronger forms.[26] Most Pokémon will evolve at a certain level, "
"while others evolve through different means, such as exposure to a certain "
"item.[27]\n"
)
def test_split_text_to_document() -> None:
segmentation = AI21SemanticTextSplitter()
segments = segmentation.split_text_to_documents(source=TEXT)
assert len(segments) > 0
for segment in segments:
assert segment.page_content is not None
assert segment.metadata is not None
def test_split_text() -> None:
segmentation = AI21SemanticTextSplitter()
segments = segmentation.split_text(source=TEXT)
assert len(segments) > 0
def test_split_text__when_chunk_size_is_large__should_merge_segments() -> None:
segmentation_no_merge = AI21SemanticTextSplitter()
segments_no_merge = segmentation_no_merge.split_text(source=TEXT)
segmentation_merge = AI21SemanticTextSplitter(chunk_size=1000)
segments_merge = segmentation_merge.split_text(source=TEXT)
# Assert that a merge did happen
assert len(segments_no_merge) > len(segments_merge)
reconstructed_text_merged = "".join(segments_merge)
reconstructed_text_non_merged = "".join(segments_no_merge)
# Assert that the merge did not change the content
assert reconstructed_text_merged == reconstructed_text_non_merged
def test_split_text__chunk_size_is_too_small__should_return_non_merged_segments() -> (
None
):
segmentation_no_merge = AI21SemanticTextSplitter()
segments_no_merge = segmentation_no_merge.split_text(source=TEXT)
segmentation_merge = AI21SemanticTextSplitter(chunk_size=10)
segments_merge = segmentation_merge.split_text(source=TEXT)
# Assert that no merge happened
assert len(segments_no_merge) == len(segments_merge)
reconstructed_text_merged = "".join(segments_merge)
reconstructed_text_non_merged = "".join(segments_no_merge)
# Assert that the merge did not change the content
assert reconstructed_text_merged == reconstructed_text_non_merged
def test_split_text__when_chunk_size_set_with_ai21_tokenizer() -> None:
segmentation_no_merge = AI21SemanticTextSplitter(
length_function=AI21Client().count_tokens
)
segments_no_merge = segmentation_no_merge.split_text(source=TEXT)
segmentation_merge = AI21SemanticTextSplitter(
chunk_size=1000, length_function=AI21Client().count_tokens
)
segments_merge = segmentation_merge.split_text(source=TEXT)
# Assert that a merge did happen
assert len(segments_no_merge) > len(segments_merge)
reconstructed_text_merged = "".join(segments_merge)
reconstructed_text_non_merged = "".join(segments_no_merge)
# Assert that the merge did not change the content
assert reconstructed_text_merged == reconstructed_text_non_merged
def test_create_documents() -> None:
texts = [TEXT]
segmentation = AI21SemanticTextSplitter()
documents = segmentation.create_documents(texts=texts)
assert len(documents) > 0
for document in documents:
assert document.page_content is not None
assert document.metadata is not None
def test_split_documents() -> None:
documents = [Document(page_content=TEXT, metadata={"foo": "bar"})]
segmentation = AI21SemanticTextSplitter()
segments = segmentation.split_documents(documents=documents)
assert len(segments) > 0
for segment in segments:
assert segment.page_content is not None
assert segment.metadata is not None

View File

@ -16,12 +16,13 @@ from ai21.models import (
FinishReason,
Penalty,
RoleType,
SegmentationResponse,
)
from ai21.models.responses.segmentation_response import Segment
from pytest_mock import MockerFixture
DUMMY_API_KEY = "test_api_key"
BASIC_EXAMPLE_LLM_PARAMETERS = {
"num_results": 3,
"max_tokens": 20,
@ -38,6 +39,32 @@ BASIC_EXAMPLE_LLM_PARAMETERS = {
),
}
SEGMENTS = [
Segment(
segment_type="normal_text",
segment_text=(
"The original full name of the franchise is Pocket Monsters "
"(ポケットモンスター, Poketto Monsutā), which was abbreviated to "
"Pokemon during development of the original games.\n\nWhen the "
"franchise was released internationally, the short form of the "
"title was used, with an acute accent (´) over the e to aid "
"in pronunciation."
),
),
Segment(
segment_type="normal_text",
segment_text=(
"Pokémon refers to both the franchise itself and the creatures "
"within its fictional universe.\n\nAs a noun, it is identical in "
"both the singular and plural, as is every individual species "
'name;[10] it is grammatically correct to say "one Pokémon" '
'and "many Pokémon", as well as "one Pikachu" and "many '
'Pikachu".\n\nIn English, Pokémon may be pronounced either '
"/'powkɛmon/ (poe-keh-mon) or /'powkɪmon/ (poe-key-mon)."
),
),
]
BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT = {
"num_results": 3,
@ -125,3 +152,15 @@ def mock_client_with_contextual_answers(mocker: MockerFixture) -> Mock:
)
return mock_client
@pytest.fixture
def mock_client_with_semantic_text_splitter(mocker: MockerFixture) -> Mock:
mock_client = mocker.MagicMock(spec=AI21Client)
mock_client.segmentation = mocker.MagicMock()
mock_client.segmentation.create.return_value = SegmentationResponse(
id="12345",
segments=SEGMENTS,
)
return mock_client

View File

@ -5,6 +5,7 @@ EXPECTED_ALL = [
"ChatAI21",
"AI21Embeddings",
"AI21ContextualAnswers",
"AI21SemanticTextSplitter",
]

View File

@ -0,0 +1,129 @@
from unittest.mock import Mock
import pytest
from langchain_ai21 import AI21SemanticTextSplitter
from tests.unit_tests.conftest import SEGMENTS
TEXT = (
"The original full name of the franchise is Pocket Monsters (ポケットモンスター, "
"Poketto Monsutā), which was abbreviated to "
"Pokemon during development of the original games.\n"
"When the franchise was released internationally, the short form of the title was "
"used, with an acute accent (´) "
"over the e to aid in pronunciation.\n"
"Pokémon refers to both the franchise itself and the creatures within its "
"fictional universe.\n"
"As a noun, it is identical in both the singular and plural, as is every "
"individual species name;[10] it is "
'grammatically correct to say "one Pokémon" and "many Pokémon", as well '
'as "one Pikachu" and "many Pikachu".\n'
"In English, Pokémon may be pronounced either /'powkɛmon/ (poe-keh-mon) or "
"/'powkɪmon/ (poe-key-mon).\n"
"The Pokémon franchise is set in a world in which humans coexist with creatures "
"known as Pokémon.\n"
"Pokémon Red and Blue contain 151 Pokémon species, with new ones being introduced "
"in subsequent games; as of December 2023, 1,025 Pokémon species have been "
"introduced.\n[b] Most Pokémon are inspired by real-world animals;[12] for example,"
"Pikachu are a yellow mouse-like species[13] with lightning bolt-shaped tails[14] "
"that possess electrical abilities.[15]"
)
@pytest.mark.parametrize(
ids=[
"when_chunk_size_is_zero",
"when_chunk_size_is_large",
"when_chunk_size_is_small",
],
argnames=["chunk_size", "expected_segmentation_len"],
argvalues=[
(0, 2),
(1000, 1),
(10, 2),
],
)
def test_split_text__on_chunk_size(
chunk_size: int,
expected_segmentation_len: int,
mock_client_with_semantic_text_splitter: Mock,
) -> None:
sts = AI21SemanticTextSplitter(
chunk_size=chunk_size,
client=mock_client_with_semantic_text_splitter,
)
segments = sts.split_text("This is a test")
assert len(segments) == expected_segmentation_len
def test_split_text__on_large_chunk_size__should_merge_chunks(
mock_client_with_semantic_text_splitter: Mock,
) -> None:
sts_no_merge = AI21SemanticTextSplitter(
client=mock_client_with_semantic_text_splitter
)
sts_merge = AI21SemanticTextSplitter(
client=mock_client_with_semantic_text_splitter,
chunk_size=1000,
)
segments_no_merge = sts_no_merge.split_text("This is a test")
segments_merge = sts_merge.split_text("This is a test")
assert len(segments_merge) > 0
assert len(segments_no_merge) > 0
assert len(segments_no_merge) > len(segments_merge)
def test_split_text__on_small_chunk_size__should_not_merge_chunks(
mock_client_with_semantic_text_splitter: Mock,
) -> None:
sts_no_merge = AI21SemanticTextSplitter(
client=mock_client_with_semantic_text_splitter
)
segments = sts_no_merge.split_text("This is a test")
assert len(segments) == 2
for index in range(2):
assert segments[index] == SEGMENTS[index].segment_text
def test_create_documents__on_start_index__should_add_start_index(
mock_client_with_semantic_text_splitter: Mock,
) -> None:
sts = AI21SemanticTextSplitter(
client=mock_client_with_semantic_text_splitter,
add_start_index=True,
)
response = sts.create_documents(texts=[TEXT])
assert len(response) > 0
for segment in response:
assert segment.page_content is not None
assert segment.metadata is not None
assert "start_index" in segment.metadata
assert segment.metadata["start_index"] > -1
def test_create_documents__when_metadata_from_user__should_add_metadata(
mock_client_with_semantic_text_splitter: Mock,
) -> None:
sts = AI21SemanticTextSplitter(client=mock_client_with_semantic_text_splitter)
metadatas = [{"hello": "world"}]
response = sts.create_documents(texts=[TEXT], metadatas=metadatas)
assert len(response) > 0
for index in range(len(response)):
assert response[index].page_content == SEGMENTS[index].segment_text
assert len(response[index].metadata) == 2
assert response[index].metadata["source_type"] == SEGMENTS[index].segment_type
assert response[index].metadata["hello"] == "world"
def test_split_text_to_documents__when_metadata_not_passed__should_contain_source_type(
mock_client_with_semantic_text_splitter: Mock,
) -> None:
sts = AI21SemanticTextSplitter(client=mock_client_with_semantic_text_splitter)
response = sts.split_text_to_documents(TEXT)
assert len(response) > 0
for segment in response:
assert segment.page_content is not None
assert segment.metadata is not None
assert "source_type" in segment.metadata
assert segment.metadata["source_type"] is not None