Update tree-sitter code segmenters to use tree-sitter-language-pack

This commit is contained in:
Lochlann Andrews 2025-03-27 17:46:12 +10:00
parent d7d0bca2bc
commit 6f790d5c22
34 changed files with 216 additions and 109 deletions

View File

@ -30,7 +30,7 @@
"- Scala (*)\n",
"- TypeScript (*)\n",
"\n",
"Items marked with (*) require the packages `tree_sitter` and `tree_sitter_languages`.\n",
"Items marked with (*) require the packages `tree_sitter` and `tree-sitter-language-pack`.\n",
"It is straightforward to add support for additional languages using `tree_sitter`,\n",
"although this currently requires modifying LangChain.\n",
"\n",
@ -47,9 +47,7 @@
"id": "7fa47b2e",
"metadata": {},
"outputs": [],
"source": [
"%pip install -qU esprima esprima tree_sitter tree_sitter_languages"
]
"source": "%pip install -qU esprima tree_sitter tree-sitter-language-pack"
},
{
"cell_type": "code",

View File

@ -91,8 +91,8 @@ tidb-vector>=0.0.3,<1.0.0
timescale-vector==0.0.1
tqdm>=4.48.0
tiktoken>=0.8.0
tree-sitter>=0.20.2,<0.21
tree-sitter-languages>=1.8.0,<2
tree-sitter>=0.23.2,<1
tree-sitter-language-pack>=0.6.1,<1
upstash-redis>=1.1.0,<2
upstash-ratelimit>=1.1.0,<2
vdms>=0.0.20

View File

@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
)
if TYPE_CHECKING:
from tree_sitter import Language
from tree_sitter import Language, Parser
CHUNK_QUERY = """
@ -25,10 +25,15 @@ class CSegmenter(TreeSitterSegmenter):
"""Code segmenter for C."""
def get_language(self) -> "Language":
from tree_sitter_languages import get_language
from tree_sitter_language_pack import get_language
return get_language("c")
def get_parser(self) -> "Parser":
from tree_sitter_language_pack import get_parser
return get_parser("c")
def get_chunk_query(self) -> str:
return CHUNK_QUERY

View File

@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
)
if TYPE_CHECKING:
from tree_sitter import Language
from tree_sitter import Language, Parser
CHUNK_QUERY = """
@ -25,10 +25,15 @@ class CPPSegmenter(TreeSitterSegmenter):
"""Code segmenter for C++."""
def get_language(self) -> "Language":
from tree_sitter_languages import get_language
from tree_sitter_language_pack import get_language
return get_language("cpp")
def get_parser(self) -> "Parser":
from tree_sitter_language_pack import get_parser
return get_parser("cpp")
def get_chunk_query(self) -> str:
return CHUNK_QUERY

View File

@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
)
if TYPE_CHECKING:
from tree_sitter import Language
from tree_sitter import Language, Parser
CHUNK_QUERY = """
@ -25,9 +25,14 @@ class CSharpSegmenter(TreeSitterSegmenter):
"""Code segmenter for C#."""
def get_language(self) -> "Language":
from tree_sitter_languages import get_language
from tree_sitter_language_pack import get_language
return get_language("c_sharp")
return get_language("csharp")
def get_parser(self) -> "Parser":
from tree_sitter_language_pack import get_parser
return get_parser("csharp")
def get_chunk_query(self) -> str:
return CHUNK_QUERY

View File

@ -5,17 +5,49 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
)
if TYPE_CHECKING:
from tree_sitter import Language
from tree_sitter import Language, Parser
CHUNK_QUERY = """
[
(call target: ((identifier) @_identifier
(#any-of? @_identifier "defmodule" "defprotocol" "defimpl"))) @module
(call target: ((identifier) @_identifier
(#any-of? @_identifier "def" "defmacro" "defmacrop" "defp"))) @function
(unary_operator operator: "@" operand: (call target: ((identifier) @_identifier
(#any-of? @_identifier "moduledoc" "typedoc""doc")))) @comment
(unary_operator
operator: "@"
operand: (call
target: (identifier)
(arguments
[
(string)
(charlist)
(sigil
quoted_start: _
quoted_end: _
)
(boolean)
]
)
)
) @comment
(call
target: (identifier)
(arguments (alias))
) @module
(call
target: (identifier)
(arguments
[
; zero-arity functions with no parentheses
(identifier)
; regular function clause
(call target: (identifier))
; function clause with a guard clause
(binary_operator
left: (call target: (identifier))
operator: "when"
)
]
)
) @function
]
""".strip()
@ -24,10 +56,15 @@ class ElixirSegmenter(TreeSitterSegmenter):
"""Code segmenter for Elixir."""
def get_language(self) -> "Language":
from tree_sitter_languages import get_language
from tree_sitter_language_pack import get_language
return get_language("elixir")
def get_parser(self) -> "Parser":
from tree_sitter_language_pack import get_parser
return get_parser("elixir")
def get_chunk_query(self) -> str:
return CHUNK_QUERY

View File

@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
)
if TYPE_CHECKING:
from tree_sitter import Language
from tree_sitter import Language, Parser
CHUNK_QUERY = """
@ -20,10 +20,15 @@ class GoSegmenter(TreeSitterSegmenter):
"""Code segmenter for Go."""
def get_language(self) -> "Language":
from tree_sitter_languages import get_language
from tree_sitter_language_pack import get_language
return get_language("go")
def get_parser(self) -> "Parser":
from tree_sitter_language_pack import get_parser
return get_parser("go")
def get_chunk_query(self) -> str:
return CHUNK_QUERY

View File

@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
)
if TYPE_CHECKING:
from tree_sitter import Language
from tree_sitter import Language, Parser
CHUNK_QUERY = """
@ -21,10 +21,15 @@ class JavaSegmenter(TreeSitterSegmenter):
"""Code segmenter for Java."""
def get_language(self) -> "Language":
from tree_sitter_languages import get_language
from tree_sitter_language_pack import get_language
return get_language("java")
def get_parser(self) -> "Parser":
from tree_sitter_language_pack import get_parser
return get_parser("java")
def get_chunk_query(self) -> str:
return CHUNK_QUERY

View File

@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
)
if TYPE_CHECKING:
from tree_sitter import Language
from tree_sitter import Language, Parser
CHUNK_QUERY = """
@ -20,10 +20,15 @@ class KotlinSegmenter(TreeSitterSegmenter):
"""Code segmenter for Kotlin."""
def get_language(self) -> "Language":
from tree_sitter_languages import get_language
from tree_sitter_language_pack import get_language
return get_language("kotlin")
def get_parser(self) -> "Parser":
from tree_sitter_language_pack import get_parser
return get_parser("kotlin")
def get_chunk_query(self) -> str:
return CHUNK_QUERY

View File

@ -130,7 +130,7 @@ class LanguageParser(BaseBlobParser):
- TypeScript: "ts" (*)
Items marked with (*) require the packages `tree_sitter` and
`tree_sitter_languages`. It is straightforward to add support for additional
`tree-sitter-language-pack`. It is straightforward to add support for additional
languages using `tree_sitter`, although this currently requires modifying LangChain.
The language used for parsing can be configured, along with the minimum number of

View File

@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
)
if TYPE_CHECKING:
from tree_sitter import Language
from tree_sitter import Language, Parser
CHUNK_QUERY = """
@ -22,10 +22,15 @@ class LuaSegmenter(TreeSitterSegmenter):
"""Code segmenter for Lua."""
def get_language(self) -> "Language":
from tree_sitter_languages import get_language
from tree_sitter_language_pack import get_language
return get_language("lua")
def get_parser(self) -> "Parser":
from tree_sitter_language_pack import get_parser
return get_parser("lua")
def get_chunk_query(self) -> str:
return CHUNK_QUERY

View File

@ -5,12 +5,12 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
)
if TYPE_CHECKING:
from tree_sitter import Language
from tree_sitter import Language, Parser
CHUNK_QUERY = """
[
(function_definition) @subroutine
(subroutine_declaration_statement) @subroutine
]
""".strip()
@ -19,10 +19,15 @@ class PerlSegmenter(TreeSitterSegmenter):
"""Code segmenter for Perl."""
def get_language(self) -> "Language":
from tree_sitter_languages import get_language
from tree_sitter_language_pack import get_language
return get_language("perl")
def get_parser(self) -> "Parser":
from tree_sitter_language_pack import get_parser
return get_parser("perl")
def get_chunk_query(self) -> str:
return CHUNK_QUERY

View File

@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
)
if TYPE_CHECKING:
from tree_sitter import Language
from tree_sitter import Language, Parser
CHUNK_QUERY = """
@ -24,10 +24,15 @@ class PHPSegmenter(TreeSitterSegmenter):
"""Code segmenter for PHP."""
def get_language(self) -> "Language":
from tree_sitter_languages import get_language
from tree_sitter_language_pack import get_language
return get_language("php")
def get_parser(self) -> "Parser":
from tree_sitter_language_pack import get_parser
return get_parser("php")
def get_chunk_query(self) -> str:
return CHUNK_QUERY

View File

@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
)
if TYPE_CHECKING:
from tree_sitter import Language
from tree_sitter import Language, Parser
CHUNK_QUERY = """
@ -21,10 +21,15 @@ class RubySegmenter(TreeSitterSegmenter):
"""Code segmenter for Ruby."""
def get_language(self) -> "Language":
from tree_sitter_languages import get_language
from tree_sitter_language_pack import get_language
return get_language("ruby")
def get_parser(self) -> "Parser":
from tree_sitter_language_pack import get_parser
return get_parser("ruby")
def get_chunk_query(self) -> str:
return CHUNK_QUERY

View File

@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
)
if TYPE_CHECKING:
from tree_sitter import Language
from tree_sitter import Language, Parser
CHUNK_QUERY = """
@ -23,10 +23,15 @@ class RustSegmenter(TreeSitterSegmenter):
"""Code segmenter for Rust."""
def get_language(self) -> "Language":
from tree_sitter_languages import get_language
from tree_sitter_language_pack import get_language
return get_language("rust")
def get_parser(self) -> "Parser":
from tree_sitter_language_pack import get_parser
return get_parser("rust")
def get_chunk_query(self) -> str:
return CHUNK_QUERY

View File

@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
)
if TYPE_CHECKING:
from tree_sitter import Language
from tree_sitter import Language, Parser
CHUNK_QUERY = """
@ -22,10 +22,15 @@ class ScalaSegmenter(TreeSitterSegmenter):
"""Code segmenter for Scala."""
def get_language(self) -> "Language":
from tree_sitter_languages import get_language
from tree_sitter_language_pack import get_language
return get_language("scala")
def get_parser(self) -> "Parser":
from tree_sitter_language_pack import get_parser
return get_parser("scala")
def get_chunk_query(self) -> str:
return CHUNK_QUERY

View File

@ -5,15 +5,20 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
)
if TYPE_CHECKING:
from tree_sitter import Language
from tree_sitter import Language, Parser
# CHUNK_QUERY = """
# [
# (create_table) @create
# (_select_statement) @select
# (insert) @insert
# (update) @update
# (_delete_statement) @delete
# ]
# """
CHUNK_QUERY = """
[
(create_table_statement) @create
(select_statement) @select
(insert_statement) @insert
(update_statement) @update
(delete_statement) @delete
(statement) @statement
]
"""
@ -28,10 +33,15 @@ class SQLSegmenter(TreeSitterSegmenter):
def get_language(self) -> "Language":
"""Return the SQL language grammar for Tree-sitter."""
from tree_sitter_languages import get_language
from tree_sitter_language_pack import get_language
return get_language("sql")
def get_parser(self) -> "Parser":
from tree_sitter_language_pack import get_parser
return get_parser("sql")
def get_chunk_query(self) -> str:
"""Return the Tree-sitter query for SQL segmentation."""
return CHUNK_QUERY

View File

@ -6,7 +6,7 @@ from langchain_community.document_loaders.parsers.language.code_segmenter import
)
if TYPE_CHECKING:
from tree_sitter import Language, Parser
from tree_sitter import Language, Node, Parser
class TreeSitterSegmenter(CodeSegmenter):
@ -18,12 +18,12 @@ class TreeSitterSegmenter(CodeSegmenter):
try:
import tree_sitter # noqa: F401
import tree_sitter_languages # noqa: F401
import tree_sitter_language_pack # noqa: F401
except ImportError:
raise ImportError(
"Could not import tree_sitter/tree_sitter_languages Python packages. "
"Please install them with "
"`pip install tree-sitter tree-sitter-languages`."
"Could not import tree_sitter/tree_sitter_language_pack "
"Python packages. Please install them with "
"`pip install tree-sitter tree-sitter-language-pack`."
)
def is_valid(self) -> bool:
@ -35,48 +35,50 @@ class TreeSitterSegmenter(CodeSegmenter):
return len(error_query.captures(tree.root_node)) == 0
def extract_functions_classes(self) -> List[str]:
def _get_top_level_nodes(self) -> List["Node"]:
language = self.get_language()
query = language.query(self.get_chunk_query())
parser = self.get_parser()
tree = parser.parse(bytes(self.code, encoding="UTF-8"))
captures = query.captures(tree.root_node)
processed_lines = set()
chunks = []
for node, name in captures:
start_line = node.start_point[0]
end_line = node.end_point[0]
lines = list(range(start_line, end_line + 1))
if any(line in processed_lines for line in lines):
top_level_nodes = {}
for node_type, nodes in captures.items():
for node in nodes:
cursor = node.parent
is_child = False
while cursor is not None:
if cursor.id in top_level_nodes:
is_child = True
break
cursor = cursor.parent
if is_child:
continue
top_level_nodes[node.id] = node
processed_lines.update(lines)
chunk_text = node.text.decode("UTF-8")
chunks.append(chunk_text)
children = node.children
for child in children:
if child.id in top_level_nodes:
del top_level_nodes[child.id]
children.extend(child.children)
top_level_nodes_list = list(top_level_nodes.values())
top_level_nodes_list.sort(key=lambda n: n.start_point[0])
return top_level_nodes_list
return chunks
def extract_functions_classes(self) -> List[str]:
top_level_nodes = self._get_top_level_nodes()
return [
node.text.decode("UTF-8")
for node in top_level_nodes
if node.text is not None
]
def simplify_code(self) -> str:
language = self.get_language()
query = language.query(self.get_chunk_query())
parser = self.get_parser()
tree = parser.parse(bytes(self.code, encoding="UTF-8"))
processed_lines = set()
simplified_lines = self.source_lines[:]
for node, name in query.captures(tree.root_node):
top_level_nodes = self._get_top_level_nodes()
for node in top_level_nodes:
start_line = node.start_point[0]
end_line = node.end_point[0]
lines = list(range(start_line, end_line + 1))
if any(line in processed_lines for line in lines):
continue
simplified_lines[start_line] = self.make_line_comment(
f"Code for: {self.source_lines[start_line]}"
)
@ -84,16 +86,11 @@ class TreeSitterSegmenter(CodeSegmenter):
for line_num in range(start_line + 1, end_line + 1):
simplified_lines[line_num] = None # type: ignore
processed_lines.update(lines)
return "\n".join(line for line in simplified_lines if line is not None)
@abstractmethod
def get_parser(self) -> "Parser":
from tree_sitter import Parser
parser = Parser()
parser.set_language(self.get_language())
return parser
raise NotImplementedError()
@abstractmethod
def get_language(self) -> "Language":

View File

@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
)
if TYPE_CHECKING:
from tree_sitter import Language
from tree_sitter import Language, Parser
CHUNK_QUERY = """
@ -22,10 +22,15 @@ class TypeScriptSegmenter(TreeSitterSegmenter):
"""Code segmenter for TypeScript."""
def get_language(self) -> "Language":
from tree_sitter_languages import get_language
from tree_sitter_language_pack import get_language
return get_language("typescript")
def get_parser(self) -> "Parser":
from tree_sitter_language_pack import get_parser
return get_parser("typescript")
def get_chunk_query(self) -> str:
return CHUNK_QUERY

View File

@ -5,7 +5,7 @@ import pytest
from langchain_community.document_loaders.parsers.language.c import CSegmenter
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
class TestCSegmenter(unittest.TestCase):
def setUp(self) -> None:
self.example_code = """int main() {

View File

@ -5,7 +5,7 @@ import pytest
from langchain_community.document_loaders.parsers.language.cpp import CPPSegmenter
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
class TestCPPSegmenter(unittest.TestCase):
def setUp(self) -> None:
self.example_code = """int foo() {
@ -55,9 +55,9 @@ auto T::bar() const -> int {
def test_extract_functions_classes(self) -> None:
segmenter = CPPSegmenter(self.example_code)
extracted_code = segmenter.extract_functions_classes()
self.assertEqual(extracted_code, self.expected_extracted_code)
self.assertEqual(self.expected_extracted_code, extracted_code)
def test_simplify_code(self) -> None:
segmenter = CPPSegmenter(self.example_code)
simplified_code = segmenter.simplify_code()
self.assertEqual(simplified_code, self.expected_simplified_code)
self.assertEqual(self.expected_simplified_code, simplified_code)

View File

@ -5,7 +5,7 @@ import pytest
from langchain_community.document_loaders.parsers.language.csharp import CSharpSegmenter
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
class TestCSharpSegmenter(unittest.TestCase):
def setUp(self) -> None:
self.example_code = """namespace World

View File

@ -5,7 +5,7 @@ import pytest
from langchain_community.document_loaders.parsers.language.elixir import ElixirSegmenter
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
class TestElixirSegmenter(unittest.TestCase):
def setUp(self) -> None:
self.example_code = """@doc "some comment"

View File

@ -5,7 +5,7 @@ import pytest
from langchain_community.document_loaders.parsers.language.go import GoSegmenter
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
class TestGoSegmenter(unittest.TestCase):
def setUp(self) -> None:
self.example_code = """func foo(a int) int {

View File

@ -5,7 +5,7 @@ import pytest
from langchain_community.document_loaders.parsers.language.java import JavaSegmenter
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
class TestJavaSegmenter(unittest.TestCase):
def setUp(self) -> None:
self.example_code = """class Hello

View File

@ -5,7 +5,7 @@ import pytest
from langchain_community.document_loaders.parsers.language.kotlin import KotlinSegmenter
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
class TestKotlinSegmenter(unittest.TestCase):
def setUp(self) -> None:
self.example_code = """fun foo(a: Int): Int {

View File

@ -5,7 +5,7 @@ import pytest
from langchain_community.document_loaders.parsers.language.lua import LuaSegmenter
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
class TestLuaSegmenter(unittest.TestCase):
def setUp(self) -> None:
self.example_code = """function F()

View File

@ -5,7 +5,7 @@ import pytest
from langchain_community.document_loaders.parsers.language.perl import PerlSegmenter
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
class TestPerlSegmenter(unittest.TestCase):
def setUp(self) -> None:
self.example_code = """sub Hello {

View File

@ -5,7 +5,7 @@ import pytest
from langchain_community.document_loaders.parsers.language.php import PHPSegmenter
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
class TestPHPSegmenter(unittest.TestCase):
def setUp(self) -> None:
self.example_code = """<?php

View File

@ -5,7 +5,7 @@ import pytest
from langchain_community.document_loaders.parsers.language.ruby import RubySegmenter
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
class TestRubySegmenter(unittest.TestCase):
def setUp(self) -> None:
self.example_code = """def foo

View File

@ -5,7 +5,7 @@ import pytest
from langchain_community.document_loaders.parsers.language.rust import RustSegmenter
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
class TestRustSegmenter(unittest.TestCase):
def setUp(self) -> None:
self.example_code = """fn foo() -> i32 {

View File

@ -5,7 +5,7 @@ import pytest
from langchain_community.document_loaders.parsers.language.scala import ScalaSegmenter
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
class TestScalaSegmenter(unittest.TestCase):
def setUp(self) -> None:
self.example_code = """def foo() {

View File

@ -5,7 +5,7 @@ import pytest
from langchain_community.document_loaders.parsers.language.sql import SQLSegmenter
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
class TestSQLSegmenter(unittest.TestCase):
"""Unit tests for the SQLSegmenter class."""

View File

@ -7,7 +7,7 @@ from langchain_community.document_loaders.parsers.language.typescript import (
)
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
class TestTypeScriptSegmenter(unittest.TestCase):
def setUp(self) -> None:
self.example_code = """function foo(): number