langchain/libs/text-splitters/langchain_text_splitters/base.py
Max Mulatz 058a64c563
Community[minor]: Add language parser for Elixir (#22742)
Hi 👋 

First off, thanks a ton for your work on this 💚 Really appreciate what
you're providing here for the community.

## Description

This PR adds a basic language parser for the
[Elixir](https://elixir-lang.org/) programming language. The parser code
is based upon the approach outlined in
https://github.com/langchain-ai/langchain/pull/13318: it's using
`tree-sitter` under the hood and aligns with all the other `tree-sitter`
based parses added that PR.

The `CHUNK_QUERY` I'm using here is probably not the most sophisticated
one, but it worked for my application. It's a starting point to provide
"core" parsing support for Elixir in LangChain. It enables people to use
the language parser out in real world applications which may then lead
to further tweaking of the queries. I consider this PR just the ground
work.

- **Dependencies:** requires `tree-sitter` and `tree-sitter-languages`
from the extended dependencies
- **Twitter handle:**`@bitcrowd`

## Checklist

- [x] **PR title**: "package: description"
- [x] **Add tests and docs**
- [x] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified.

<!-- If no one reviews your PR within a few days, please @-mention one
of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17. -->
2024-06-10 15:56:57 +00:00

328 lines
11 KiB
Python

from __future__ import annotations
import copy
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import (
AbstractSet,
Any,
Callable,
Collection,
Iterable,
List,
Literal,
Optional,
Sequence,
Type,
TypeVar,
Union,
)
from langchain_core.documents import BaseDocumentTransformer, Document
logger = logging.getLogger(__name__)
TS = TypeVar("TS", bound="TextSplitter")
class TextSplitter(BaseDocumentTransformer, ABC):
"""Interface for splitting text into chunks."""
def __init__(
self,
chunk_size: int = 4000,
chunk_overlap: int = 200,
length_function: Callable[[str], int] = len,
keep_separator: Union[bool, Literal["start", "end"]] = False,
add_start_index: bool = False,
strip_whitespace: bool = True,
) -> None:
"""Create a new TextSplitter.
Args:
chunk_size: Maximum size of chunks to return
chunk_overlap: Overlap in characters between chunks
length_function: Function that measures the length of given chunks
keep_separator: Whether to keep the separator and where to place it
in each corresponding chunk (True='start')
add_start_index: If `True`, includes chunk's start index in metadata
strip_whitespace: If `True`, strips whitespace from the start and end of
every document
"""
if chunk_overlap > chunk_size:
raise ValueError(
f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
f"({chunk_size}), should be smaller."
)
self._chunk_size = chunk_size
self._chunk_overlap = chunk_overlap
self._length_function = length_function
self._keep_separator = keep_separator
self._add_start_index = add_start_index
self._strip_whitespace = strip_whitespace
@abstractmethod
def split_text(self, text: str) -> List[str]:
"""Split text into multiple components."""
def create_documents(
self, texts: List[str], metadatas: Optional[List[dict]] = None
) -> List[Document]:
"""Create documents from a list of texts."""
_metadatas = metadatas or [{}] * len(texts)
documents = []
for i, text in enumerate(texts):
index = 0
previous_chunk_len = 0
for chunk in self.split_text(text):
metadata = copy.deepcopy(_metadatas[i])
if self._add_start_index:
offset = index + previous_chunk_len - self._chunk_overlap
index = text.find(chunk, max(0, offset))
metadata["start_index"] = index
previous_chunk_len = len(chunk)
new_doc = Document(page_content=chunk, metadata=metadata)
documents.append(new_doc)
return documents
def split_documents(self, documents: Iterable[Document]) -> List[Document]:
"""Split documents."""
texts, metadatas = [], []
for doc in documents:
texts.append(doc.page_content)
metadatas.append(doc.metadata)
return self.create_documents(texts, metadatas=metadatas)
def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
text = separator.join(docs)
if self._strip_whitespace:
text = text.strip()
if text == "":
return None
else:
return text
def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
# We now want to combine these smaller pieces into medium size
# chunks to send to the LLM.
separator_len = self._length_function(separator)
docs = []
current_doc: List[str] = []
total = 0
for d in splits:
_len = self._length_function(d)
if (
total + _len + (separator_len if len(current_doc) > 0 else 0)
> self._chunk_size
):
if total > self._chunk_size:
logger.warning(
f"Created a chunk of size {total}, "
f"which is longer than the specified {self._chunk_size}"
)
if len(current_doc) > 0:
doc = self._join_docs(current_doc, separator)
if doc is not None:
docs.append(doc)
# Keep on popping if:
# - we have a larger chunk than in the chunk overlap
# - or if we still have any chunks and the length is long
while total > self._chunk_overlap or (
total + _len + (separator_len if len(current_doc) > 0 else 0)
> self._chunk_size
and total > 0
):
total -= self._length_function(current_doc[0]) + (
separator_len if len(current_doc) > 1 else 0
)
current_doc = current_doc[1:]
current_doc.append(d)
total += _len + (separator_len if len(current_doc) > 1 else 0)
doc = self._join_docs(current_doc, separator)
if doc is not None:
docs.append(doc)
return docs
@classmethod
def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
"""Text splitter that uses HuggingFace tokenizer to count length."""
try:
from transformers import PreTrainedTokenizerBase
if not isinstance(tokenizer, PreTrainedTokenizerBase):
raise ValueError(
"Tokenizer received was not an instance of PreTrainedTokenizerBase"
)
def _huggingface_tokenizer_length(text: str) -> int:
return len(tokenizer.encode(text))
except ImportError:
raise ValueError(
"Could not import transformers python package. "
"Please install it with `pip install transformers`."
)
return cls(length_function=_huggingface_tokenizer_length, **kwargs)
@classmethod
def from_tiktoken_encoder(
cls: Type[TS],
encoding_name: str = "gpt2",
model_name: Optional[str] = None,
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
**kwargs: Any,
) -> TS:
"""Text splitter that uses tiktoken encoder to count length."""
try:
import tiktoken
except ImportError:
raise ImportError(
"Could not import tiktoken python package. "
"This is needed in order to calculate max_tokens_for_prompt. "
"Please install it with `pip install tiktoken`."
)
if model_name is not None:
enc = tiktoken.encoding_for_model(model_name)
else:
enc = tiktoken.get_encoding(encoding_name)
def _tiktoken_encoder(text: str) -> int:
return len(
enc.encode(
text,
allowed_special=allowed_special,
disallowed_special=disallowed_special,
)
)
if issubclass(cls, TokenTextSplitter):
extra_kwargs = {
"encoding_name": encoding_name,
"model_name": model_name,
"allowed_special": allowed_special,
"disallowed_special": disallowed_special,
}
kwargs = {**kwargs, **extra_kwargs}
return cls(length_function=_tiktoken_encoder, **kwargs)
def transform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
"""Transform sequence of documents by splitting them."""
return self.split_documents(list(documents))
class TokenTextSplitter(TextSplitter):
"""Splitting text to tokens using model tokenizer."""
def __init__(
self,
encoding_name: str = "gpt2",
model_name: Optional[str] = None,
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
**kwargs: Any,
) -> None:
"""Create a new TextSplitter."""
super().__init__(**kwargs)
try:
import tiktoken
except ImportError:
raise ImportError(
"Could not import tiktoken python package. "
"This is needed in order to for TokenTextSplitter. "
"Please install it with `pip install tiktoken`."
)
if model_name is not None:
enc = tiktoken.encoding_for_model(model_name)
else:
enc = tiktoken.get_encoding(encoding_name)
self._tokenizer = enc
self._allowed_special = allowed_special
self._disallowed_special = disallowed_special
def split_text(self, text: str) -> List[str]:
def _encode(_text: str) -> List[int]:
return self._tokenizer.encode(
_text,
allowed_special=self._allowed_special,
disallowed_special=self._disallowed_special,
)
tokenizer = Tokenizer(
chunk_overlap=self._chunk_overlap,
tokens_per_chunk=self._chunk_size,
decode=self._tokenizer.decode,
encode=_encode,
)
return split_text_on_tokens(text=text, tokenizer=tokenizer)
class Language(str, Enum):
"""Enum of the programming languages."""
CPP = "cpp"
GO = "go"
JAVA = "java"
KOTLIN = "kotlin"
JS = "js"
TS = "ts"
PHP = "php"
PROTO = "proto"
PYTHON = "python"
RST = "rst"
RUBY = "ruby"
RUST = "rust"
SCALA = "scala"
SWIFT = "swift"
MARKDOWN = "markdown"
LATEX = "latex"
HTML = "html"
SOL = "sol"
CSHARP = "csharp"
COBOL = "cobol"
C = "c"
LUA = "lua"
PERL = "perl"
HASKELL = "haskell"
ELIXIR = "elixir"
@dataclass(frozen=True)
class Tokenizer:
"""Tokenizer data class."""
chunk_overlap: int
"""Overlap in tokens between chunks"""
tokens_per_chunk: int
"""Maximum number of tokens per chunk"""
decode: Callable[[List[int]], str]
""" Function to decode a list of token ids to a string"""
encode: Callable[[str], List[int]]
""" Function to encode a string to a list of token ids"""
def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]:
"""Split incoming text and return chunks using tokenizer."""
splits: List[str] = []
input_ids = tokenizer.encode(text)
start_idx = 0
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
while start_idx < len(input_ids):
splits.append(tokenizer.decode(chunk_ids))
if cur_idx == len(input_ids):
break
start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
return splits