diff --git a/docs/docs/integrations/document_loaders/source_code.ipynb b/docs/docs/integrations/document_loaders/source_code.ipynb index 043feaf75da..e05138d6b27 100644 --- a/docs/docs/integrations/document_loaders/source_code.ipynb +++ b/docs/docs/integrations/document_loaders/source_code.ipynb @@ -17,6 +17,7 @@ "- C++ (*)\n", "- C# (*)\n", "- COBOL\n", + "- Elixir\n", "- Go (*)\n", "- Java (*)\n", "- JavaScript (requires package `esprima`)\n", diff --git a/libs/community/langchain_community/document_loaders/parsers/language/elixir.py b/libs/community/langchain_community/document_loaders/parsers/language/elixir.py new file mode 100644 index 00000000000..780209767d8 --- /dev/null +++ b/libs/community/langchain_community/document_loaders/parsers/language/elixir.py @@ -0,0 +1,35 @@ +from typing import TYPE_CHECKING + +from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter import ( # noqa: E501 + TreeSitterSegmenter, +) + +if TYPE_CHECKING: + from tree_sitter import Language + + +CHUNK_QUERY = """ + [ + (call target: ((identifier) @_identifier + (#any-of? @_identifier "defmodule" "defprotocol" "defimpl"))) @module + (call target: ((identifier) @_identifier + (#any-of? @_identifier "def" "defmacro" "defmacrop" "defp"))) @function + (unary_operator operator: "@" operand: (call target: ((identifier) @_identifier + (#any-of? 
@_identifier "moduledoc" "typedoc" "doc")))) @comment + ] +""".strip() + + +class ElixirSegmenter(TreeSitterSegmenter): + """Code segmenter for Elixir.""" + + def get_language(self) -> "Language": + from tree_sitter_languages import get_language + + return get_language("elixir") + + def get_chunk_query(self) -> str: + return CHUNK_QUERY + + def make_line_comment(self, text: str) -> str: + return f"# {text}" diff --git a/libs/community/langchain_community/document_loaders/parsers/language/language_parser.py b/libs/community/langchain_community/document_loaders/parsers/language/language_parser.py index 9405598d207..f44d74e6906 100644 --- a/libs/community/langchain_community/document_loaders/parsers/language/language_parser.py +++ b/libs/community/langchain_community/document_loaders/parsers/language/language_parser.py @@ -10,6 +10,7 @@ from langchain_community.document_loaders.parsers.language.c import CSegmenter from langchain_community.document_loaders.parsers.language.cobol import CobolSegmenter from langchain_community.document_loaders.parsers.language.cpp import CPPSegmenter from langchain_community.document_loaders.parsers.language.csharp import CSharpSegmenter +from langchain_community.document_loaders.parsers.language.elixir import ElixirSegmenter from langchain_community.document_loaders.parsers.language.go import GoSegmenter from langchain_community.document_loaders.parsers.language.java import JavaSegmenter from langchain_community.document_loaders.parsers.language.javascript import ( @@ -44,6 +45,8 @@ LANGUAGE_EXTENSIONS: Dict[str, str] = { "ts": "ts", "java": "java", "php": "php", + "ex": "elixir", + "exs": "elixir", } LANGUAGE_SEGMENTERS: Dict[str, Any] = { @@ -63,6 +66,7 @@ LANGUAGE_SEGMENTERS: Dict[str, Any] = { "ts": TypeScriptSegmenter, "java": JavaSegmenter, "php": PHPSegmenter, + "elixir": ElixirSegmenter, } Language = Literal[ @@ -89,6 +93,7 @@ Language = Literal[ "c", "lua", "perl", + "elixir", ] @@ -107,6 +112,7 @@ class
LanguageParser(BaseBlobParser): - C++: "cpp" (*) - C#: "csharp" (*) - COBOL: "cobol" + - Elixir: "elixir" - Go: "go" (*) - Java: "java" (*) - JavaScript: "js" (requires package `esprima`) diff --git a/libs/community/tests/unit_tests/document_loaders/parsers/language/test_elixir.py b/libs/community/tests/unit_tests/document_loaders/parsers/language/test_elixir.py new file mode 100644 index 00000000000..02d6af92656 --- /dev/null +++ b/libs/community/tests/unit_tests/document_loaders/parsers/language/test_elixir.py @@ -0,0 +1,57 @@ +import unittest + +import pytest + +from langchain_community.document_loaders.parsers.language.elixir import ElixirSegmenter + + +@pytest.mark.requires("tree_sitter", "tree_sitter_languages") +class TestElixirSegmenter(unittest.TestCase): + def setUp(self) -> None: + self.example_code = """@doc "some comment" +def foo do + i = 0 +end + +defmodule M do + def hi do + i = 2 + end + + defp wave do + :ok + end +end""" + + self.expected_simplified_code = """# Code for: @doc "some comment" +# Code for: def foo do + +# Code for: defmodule M do""" + + self.expected_extracted_code = [ + '@doc "some comment"', + "def foo do\n i = 0\nend", + "defmodule M do\n" + " def hi do\n" + " i = 2\n" + " end\n\n" + " defp wave do\n" + " :ok\n" + " end\n" + "end", + ] + + def test_is_valid(self) -> None: + self.assertTrue(ElixirSegmenter("def a do; end").is_valid()) + self.assertFalse(ElixirSegmenter("a b c 1 2 3").is_valid()) + + def test_extract_functions_classes(self) -> None: + segmenter = ElixirSegmenter(self.example_code) + extracted_code = segmenter.extract_functions_classes() + self.assertEqual(len(extracted_code), 3) + self.assertEqual(extracted_code, self.expected_extracted_code) + + def test_simplify_code(self) -> None: + segmenter = ElixirSegmenter(self.example_code) + simplified_code = segmenter.simplify_code() + self.assertEqual(simplified_code, self.expected_simplified_code) diff --git a/libs/text-splitters/langchain_text_splitters/base.py 
b/libs/text-splitters/langchain_text_splitters/base.py index bdf7ae7be2d..36de4bca09f 100644 --- a/libs/text-splitters/langchain_text_splitters/base.py +++ b/libs/text-splitters/langchain_text_splitters/base.py @@ -293,6 +293,7 @@ class Language(str, Enum): LUA = "lua" PERL = "perl" HASKELL = "haskell" + ELIXIR = "elixir" @dataclass(frozen=True) diff --git a/libs/text-splitters/langchain_text_splitters/character.py b/libs/text-splitters/langchain_text_splitters/character.py index a492bb01b38..14cccc3c664 100644 --- a/libs/text-splitters/langchain_text_splitters/character.py +++ b/libs/text-splitters/langchain_text_splitters/character.py @@ -343,6 +343,30 @@ class RecursiveCharacterTextSplitter(TextSplitter): " ", "", ] + elif language == Language.ELIXIR: + return [ + # Split along method function and module definition + "\ndef ", + "\ndefp ", + "\ndefmodule ", + "\ndefprotocol ", + "\ndefmacro ", + "\ndefmacrop ", + # Split along control flow statements + "\nif ", + "\nunless ", + "\nwhile ", + "\ncase ", + "\ncond ", + "\nwith ", + "\nfor ", + "\ndo ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] elif language == Language.RUST: return [ # Split along function definitions