mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-29 06:23:20 +00:00
updated code parser to use tree sitter lang pack
This commit is contained in:
parent
d7d0bca2bc
commit
6f790d5c22
@ -30,7 +30,7 @@
|
||||
"- Scala (*)\n",
|
||||
"- TypeScript (*)\n",
|
||||
"\n",
|
||||
"Items marked with (*) require the packages `tree_sitter` and `tree_sitter_languages`.\n",
|
||||
"Items marked with (*) require the packages `tree_sitter` and `tree-sitter-language-pack`.\n",
|
||||
"It is straightforward to add support for additional languages using `tree_sitter`,\n",
|
||||
"although this currently requires modifying LangChain.\n",
|
||||
"\n",
|
||||
@ -47,9 +47,7 @@
|
||||
"id": "7fa47b2e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install -qU esprima esprima tree_sitter tree_sitter_languages"
|
||||
]
|
||||
"source": "%pip install -qU esprima esprima tree_sitter tree-sitter-language-pack"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
|
@ -91,8 +91,8 @@ tidb-vector>=0.0.3,<1.0.0
|
||||
timescale-vector==0.0.1
|
||||
tqdm>=4.48.0
|
||||
tiktoken>=0.8.0
|
||||
tree-sitter>=0.20.2,<0.21
|
||||
tree-sitter-languages>=1.8.0,<2
|
||||
tree-sitter>=0.23.2,<1
|
||||
tree-sitter-language-pack>=0.6.1,<1
|
||||
upstash-redis>=1.1.0,<2
|
||||
upstash-ratelimit>=1.1.0,<2
|
||||
vdms>=0.0.20
|
||||
|
@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tree_sitter import Language
|
||||
from tree_sitter import Language, Parser
|
||||
|
||||
|
||||
CHUNK_QUERY = """
|
||||
@ -25,10 +25,15 @@ class CSegmenter(TreeSitterSegmenter):
|
||||
"""Code segmenter for C."""
|
||||
|
||||
def get_language(self) -> "Language":
|
||||
from tree_sitter_languages import get_language
|
||||
from tree_sitter_language_pack import get_language
|
||||
|
||||
return get_language("c")
|
||||
|
||||
def get_parser(self) -> "Parser":
|
||||
from tree_sitter_language_pack import get_parser
|
||||
|
||||
return get_parser("c")
|
||||
|
||||
def get_chunk_query(self) -> str:
|
||||
return CHUNK_QUERY
|
||||
|
||||
|
@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tree_sitter import Language
|
||||
from tree_sitter import Language, Parser
|
||||
|
||||
|
||||
CHUNK_QUERY = """
|
||||
@ -25,10 +25,15 @@ class CPPSegmenter(TreeSitterSegmenter):
|
||||
"""Code segmenter for C++."""
|
||||
|
||||
def get_language(self) -> "Language":
|
||||
from tree_sitter_languages import get_language
|
||||
from tree_sitter_language_pack import get_language
|
||||
|
||||
return get_language("cpp")
|
||||
|
||||
def get_parser(self) -> "Parser":
|
||||
from tree_sitter_language_pack import get_parser
|
||||
|
||||
return get_parser("cpp")
|
||||
|
||||
def get_chunk_query(self) -> str:
|
||||
return CHUNK_QUERY
|
||||
|
||||
|
@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tree_sitter import Language
|
||||
from tree_sitter import Language, Parser
|
||||
|
||||
|
||||
CHUNK_QUERY = """
|
||||
@ -25,9 +25,14 @@ class CSharpSegmenter(TreeSitterSegmenter):
|
||||
"""Code segmenter for C#."""
|
||||
|
||||
def get_language(self) -> "Language":
|
||||
from tree_sitter_languages import get_language
|
||||
from tree_sitter_language_pack import get_language
|
||||
|
||||
return get_language("c_sharp")
|
||||
return get_language("csharp")
|
||||
|
||||
def get_parser(self) -> "Parser":
|
||||
from tree_sitter_language_pack import get_parser
|
||||
|
||||
return get_parser("csharp")
|
||||
|
||||
def get_chunk_query(self) -> str:
|
||||
return CHUNK_QUERY
|
||||
|
@ -5,17 +5,49 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tree_sitter import Language
|
||||
|
||||
from tree_sitter import Language, Parser
|
||||
|
||||
CHUNK_QUERY = """
|
||||
[
|
||||
(call target: ((identifier) @_identifier
|
||||
(#any-of? @_identifier "defmodule" "defprotocol" "defimpl"))) @module
|
||||
(call target: ((identifier) @_identifier
|
||||
(#any-of? @_identifier "def" "defmacro" "defmacrop" "defp"))) @function
|
||||
(unary_operator operator: "@" operand: (call target: ((identifier) @_identifier
|
||||
(#any-of? @_identifier "moduledoc" "typedoc""doc")))) @comment
|
||||
(unary_operator
|
||||
operator: "@"
|
||||
operand: (call
|
||||
target: (identifier)
|
||||
(arguments
|
||||
[
|
||||
(string)
|
||||
(charlist)
|
||||
(sigil
|
||||
quoted_start: _
|
||||
quoted_end: _
|
||||
)
|
||||
(boolean)
|
||||
]
|
||||
)
|
||||
)
|
||||
) @comment
|
||||
|
||||
(call
|
||||
target: (identifier)
|
||||
(arguments (alias))
|
||||
) @module
|
||||
|
||||
(call
|
||||
target: (identifier)
|
||||
(arguments
|
||||
[
|
||||
; zero-arity functions with no parentheses
|
||||
(identifier)
|
||||
; regular function clause
|
||||
(call target: (identifier))
|
||||
; function clause with a guard clause
|
||||
(binary_operator
|
||||
left: (call target: (identifier))
|
||||
operator: "when"
|
||||
)
|
||||
]
|
||||
)
|
||||
) @function
|
||||
]
|
||||
""".strip()
|
||||
|
||||
@ -24,10 +56,15 @@ class ElixirSegmenter(TreeSitterSegmenter):
|
||||
"""Code segmenter for Elixir."""
|
||||
|
||||
def get_language(self) -> "Language":
|
||||
from tree_sitter_languages import get_language
|
||||
from tree_sitter_language_pack import get_language
|
||||
|
||||
return get_language("elixir")
|
||||
|
||||
def get_parser(self) -> "Parser":
|
||||
from tree_sitter_language_pack import get_parser
|
||||
|
||||
return get_parser("elixir")
|
||||
|
||||
def get_chunk_query(self) -> str:
|
||||
return CHUNK_QUERY
|
||||
|
||||
|
@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tree_sitter import Language
|
||||
from tree_sitter import Language, Parser
|
||||
|
||||
|
||||
CHUNK_QUERY = """
|
||||
@ -20,10 +20,15 @@ class GoSegmenter(TreeSitterSegmenter):
|
||||
"""Code segmenter for Go."""
|
||||
|
||||
def get_language(self) -> "Language":
|
||||
from tree_sitter_languages import get_language
|
||||
from tree_sitter_language_pack import get_language
|
||||
|
||||
return get_language("go")
|
||||
|
||||
def get_parser(self) -> "Parser":
|
||||
from tree_sitter_language_pack import get_parser
|
||||
|
||||
return get_parser("go")
|
||||
|
||||
def get_chunk_query(self) -> str:
|
||||
return CHUNK_QUERY
|
||||
|
||||
|
@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tree_sitter import Language
|
||||
from tree_sitter import Language, Parser
|
||||
|
||||
|
||||
CHUNK_QUERY = """
|
||||
@ -21,10 +21,15 @@ class JavaSegmenter(TreeSitterSegmenter):
|
||||
"""Code segmenter for Java."""
|
||||
|
||||
def get_language(self) -> "Language":
|
||||
from tree_sitter_languages import get_language
|
||||
from tree_sitter_language_pack import get_language
|
||||
|
||||
return get_language("java")
|
||||
|
||||
def get_parser(self) -> "Parser":
|
||||
from tree_sitter_language_pack import get_parser
|
||||
|
||||
return get_parser("java")
|
||||
|
||||
def get_chunk_query(self) -> str:
|
||||
return CHUNK_QUERY
|
||||
|
||||
|
@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tree_sitter import Language
|
||||
from tree_sitter import Language, Parser
|
||||
|
||||
|
||||
CHUNK_QUERY = """
|
||||
@ -20,10 +20,15 @@ class KotlinSegmenter(TreeSitterSegmenter):
|
||||
"""Code segmenter for Kotlin."""
|
||||
|
||||
def get_language(self) -> "Language":
|
||||
from tree_sitter_languages import get_language
|
||||
from tree_sitter_language_pack import get_language
|
||||
|
||||
return get_language("kotlin")
|
||||
|
||||
def get_parser(self) -> "Parser":
|
||||
from tree_sitter_language_pack import get_parser
|
||||
|
||||
return get_parser("kotlin")
|
||||
|
||||
def get_chunk_query(self) -> str:
|
||||
return CHUNK_QUERY
|
||||
|
||||
|
@ -130,7 +130,7 @@ class LanguageParser(BaseBlobParser):
|
||||
- TypeScript: "ts" (*)
|
||||
|
||||
Items marked with (*) require the packages `tree_sitter` and
|
||||
`tree_sitter_languages`. It is straightforward to add support for additional
|
||||
`tree-sitter-language-pack`. It is straightforward to add support for additional
|
||||
languages using `tree_sitter`, although this currently requires modifying LangChain.
|
||||
|
||||
The language used for parsing can be configured, along with the minimum number of
|
||||
|
@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tree_sitter import Language
|
||||
from tree_sitter import Language, Parser
|
||||
|
||||
|
||||
CHUNK_QUERY = """
|
||||
@ -22,10 +22,15 @@ class LuaSegmenter(TreeSitterSegmenter):
|
||||
"""Code segmenter for Lua."""
|
||||
|
||||
def get_language(self) -> "Language":
|
||||
from tree_sitter_languages import get_language
|
||||
from tree_sitter_language_pack import get_language
|
||||
|
||||
return get_language("lua")
|
||||
|
||||
def get_parser(self) -> "Parser":
|
||||
from tree_sitter_language_pack import get_parser
|
||||
|
||||
return get_parser("lua")
|
||||
|
||||
def get_chunk_query(self) -> str:
|
||||
return CHUNK_QUERY
|
||||
|
||||
|
@ -5,12 +5,12 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tree_sitter import Language
|
||||
from tree_sitter import Language, Parser
|
||||
|
||||
|
||||
CHUNK_QUERY = """
|
||||
[
|
||||
(function_definition) @subroutine
|
||||
(subroutine_declaration_statement) @subroutine
|
||||
]
|
||||
""".strip()
|
||||
|
||||
@ -19,10 +19,15 @@ class PerlSegmenter(TreeSitterSegmenter):
|
||||
"""Code segmenter for Perl."""
|
||||
|
||||
def get_language(self) -> "Language":
|
||||
from tree_sitter_languages import get_language
|
||||
from tree_sitter_language_pack import get_language
|
||||
|
||||
return get_language("perl")
|
||||
|
||||
def get_parser(self) -> "Parser":
|
||||
from tree_sitter_language_pack import get_parser
|
||||
|
||||
return get_parser("perl")
|
||||
|
||||
def get_chunk_query(self) -> str:
|
||||
return CHUNK_QUERY
|
||||
|
||||
|
@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tree_sitter import Language
|
||||
from tree_sitter import Language, Parser
|
||||
|
||||
|
||||
CHUNK_QUERY = """
|
||||
@ -24,10 +24,15 @@ class PHPSegmenter(TreeSitterSegmenter):
|
||||
"""Code segmenter for PHP."""
|
||||
|
||||
def get_language(self) -> "Language":
|
||||
from tree_sitter_languages import get_language
|
||||
from tree_sitter_language_pack import get_language
|
||||
|
||||
return get_language("php")
|
||||
|
||||
def get_parser(self) -> "Parser":
|
||||
from tree_sitter_language_pack import get_parser
|
||||
|
||||
return get_parser("php")
|
||||
|
||||
def get_chunk_query(self) -> str:
|
||||
return CHUNK_QUERY
|
||||
|
||||
|
@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tree_sitter import Language
|
||||
from tree_sitter import Language, Parser
|
||||
|
||||
|
||||
CHUNK_QUERY = """
|
||||
@ -21,10 +21,15 @@ class RubySegmenter(TreeSitterSegmenter):
|
||||
"""Code segmenter for Ruby."""
|
||||
|
||||
def get_language(self) -> "Language":
|
||||
from tree_sitter_languages import get_language
|
||||
from tree_sitter_language_pack import get_language
|
||||
|
||||
return get_language("ruby")
|
||||
|
||||
def get_parser(self) -> "Parser":
|
||||
from tree_sitter_language_pack import get_parser
|
||||
|
||||
return get_parser("ruby")
|
||||
|
||||
def get_chunk_query(self) -> str:
|
||||
return CHUNK_QUERY
|
||||
|
||||
|
@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tree_sitter import Language
|
||||
from tree_sitter import Language, Parser
|
||||
|
||||
|
||||
CHUNK_QUERY = """
|
||||
@ -23,10 +23,15 @@ class RustSegmenter(TreeSitterSegmenter):
|
||||
"""Code segmenter for Rust."""
|
||||
|
||||
def get_language(self) -> "Language":
|
||||
from tree_sitter_languages import get_language
|
||||
from tree_sitter_language_pack import get_language
|
||||
|
||||
return get_language("rust")
|
||||
|
||||
def get_parser(self) -> "Parser":
|
||||
from tree_sitter_language_pack import get_parser
|
||||
|
||||
return get_parser("rust")
|
||||
|
||||
def get_chunk_query(self) -> str:
|
||||
return CHUNK_QUERY
|
||||
|
||||
|
@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tree_sitter import Language
|
||||
from tree_sitter import Language, Parser
|
||||
|
||||
|
||||
CHUNK_QUERY = """
|
||||
@ -22,10 +22,15 @@ class ScalaSegmenter(TreeSitterSegmenter):
|
||||
"""Code segmenter for Scala."""
|
||||
|
||||
def get_language(self) -> "Language":
|
||||
from tree_sitter_languages import get_language
|
||||
from tree_sitter_language_pack import get_language
|
||||
|
||||
return get_language("scala")
|
||||
|
||||
def get_parser(self) -> "Parser":
|
||||
from tree_sitter_language_pack import get_parser
|
||||
|
||||
return get_parser("scala")
|
||||
|
||||
def get_chunk_query(self) -> str:
|
||||
return CHUNK_QUERY
|
||||
|
||||
|
@ -5,15 +5,20 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tree_sitter import Language
|
||||
from tree_sitter import Language, Parser
|
||||
|
||||
# CHUNK_QUERY = """
|
||||
# [
|
||||
# (create_table) @create
|
||||
# (_select_statement) @select
|
||||
# (insert) @insert
|
||||
# (update) @update
|
||||
# (_delete_statement) @delete
|
||||
# ]
|
||||
# """
|
||||
CHUNK_QUERY = """
|
||||
[
|
||||
(create_table_statement) @create
|
||||
(select_statement) @select
|
||||
(insert_statement) @insert
|
||||
(update_statement) @update
|
||||
(delete_statement) @delete
|
||||
(statement) @statement
|
||||
]
|
||||
"""
|
||||
|
||||
@ -28,10 +33,15 @@ class SQLSegmenter(TreeSitterSegmenter):
|
||||
|
||||
def get_language(self) -> "Language":
|
||||
"""Return the SQL language grammar for Tree-sitter."""
|
||||
from tree_sitter_languages import get_language
|
||||
from tree_sitter_language_pack import get_language
|
||||
|
||||
return get_language("sql")
|
||||
|
||||
def get_parser(self) -> "Parser":
|
||||
from tree_sitter_language_pack import get_parser
|
||||
|
||||
return get_parser("sql")
|
||||
|
||||
def get_chunk_query(self) -> str:
|
||||
"""Return the Tree-sitter query for SQL segmentation."""
|
||||
return CHUNK_QUERY
|
||||
|
@ -6,7 +6,7 @@ from langchain_community.document_loaders.parsers.language.code_segmenter import
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tree_sitter import Language, Parser
|
||||
from tree_sitter import Language, Node, Parser
|
||||
|
||||
|
||||
class TreeSitterSegmenter(CodeSegmenter):
|
||||
@ -18,12 +18,12 @@ class TreeSitterSegmenter(CodeSegmenter):
|
||||
|
||||
try:
|
||||
import tree_sitter # noqa: F401
|
||||
import tree_sitter_languages # noqa: F401
|
||||
import tree_sitter_language_pack # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import tree_sitter/tree_sitter_languages Python packages. "
|
||||
"Please install them with "
|
||||
"`pip install tree-sitter tree-sitter-languages`."
|
||||
"Could not import tree_sitter/tree_sitter_language_pack "
|
||||
"Python packages. Please install them with "
|
||||
"`pip install tree-sitter tree-sitter-language-pack`."
|
||||
)
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
@ -35,48 +35,50 @@ class TreeSitterSegmenter(CodeSegmenter):
|
||||
|
||||
return len(error_query.captures(tree.root_node)) == 0
|
||||
|
||||
def extract_functions_classes(self) -> List[str]:
|
||||
def _get_top_level_nodes(self) -> List["Node"]:
|
||||
language = self.get_language()
|
||||
query = language.query(self.get_chunk_query())
|
||||
|
||||
parser = self.get_parser()
|
||||
tree = parser.parse(bytes(self.code, encoding="UTF-8"))
|
||||
captures = query.captures(tree.root_node)
|
||||
|
||||
processed_lines = set()
|
||||
chunks = []
|
||||
|
||||
for node, name in captures:
|
||||
start_line = node.start_point[0]
|
||||
end_line = node.end_point[0]
|
||||
lines = list(range(start_line, end_line + 1))
|
||||
|
||||
if any(line in processed_lines for line in lines):
|
||||
top_level_nodes = {}
|
||||
for node_type, nodes in captures.items():
|
||||
for node in nodes:
|
||||
cursor = node.parent
|
||||
is_child = False
|
||||
while cursor is not None:
|
||||
if cursor.id in top_level_nodes:
|
||||
is_child = True
|
||||
break
|
||||
cursor = cursor.parent
|
||||
if is_child:
|
||||
continue
|
||||
top_level_nodes[node.id] = node
|
||||
|
||||
processed_lines.update(lines)
|
||||
chunk_text = node.text.decode("UTF-8")
|
||||
chunks.append(chunk_text)
|
||||
children = node.children
|
||||
for child in children:
|
||||
if child.id in top_level_nodes:
|
||||
del top_level_nodes[child.id]
|
||||
children.extend(child.children)
|
||||
top_level_nodes_list = list(top_level_nodes.values())
|
||||
top_level_nodes_list.sort(key=lambda n: n.start_point[0])
|
||||
return top_level_nodes_list
|
||||
|
||||
return chunks
|
||||
def extract_functions_classes(self) -> List[str]:
|
||||
top_level_nodes = self._get_top_level_nodes()
|
||||
return [
|
||||
node.text.decode("UTF-8")
|
||||
for node in top_level_nodes
|
||||
if node.text is not None
|
||||
]
|
||||
|
||||
def simplify_code(self) -> str:
|
||||
language = self.get_language()
|
||||
query = language.query(self.get_chunk_query())
|
||||
|
||||
parser = self.get_parser()
|
||||
tree = parser.parse(bytes(self.code, encoding="UTF-8"))
|
||||
processed_lines = set()
|
||||
|
||||
simplified_lines = self.source_lines[:]
|
||||
for node, name in query.captures(tree.root_node):
|
||||
top_level_nodes = self._get_top_level_nodes()
|
||||
for node in top_level_nodes:
|
||||
start_line = node.start_point[0]
|
||||
end_line = node.end_point[0]
|
||||
|
||||
lines = list(range(start_line, end_line + 1))
|
||||
if any(line in processed_lines for line in lines):
|
||||
continue
|
||||
|
||||
simplified_lines[start_line] = self.make_line_comment(
|
||||
f"Code for: {self.source_lines[start_line]}"
|
||||
)
|
||||
@ -84,16 +86,11 @@ class TreeSitterSegmenter(CodeSegmenter):
|
||||
for line_num in range(start_line + 1, end_line + 1):
|
||||
simplified_lines[line_num] = None # type: ignore
|
||||
|
||||
processed_lines.update(lines)
|
||||
|
||||
return "\n".join(line for line in simplified_lines if line is not None)
|
||||
|
||||
@abstractmethod
|
||||
def get_parser(self) -> "Parser":
|
||||
from tree_sitter import Parser
|
||||
|
||||
parser = Parser()
|
||||
parser.set_language(self.get_language())
|
||||
return parser
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def get_language(self) -> "Language":
|
||||
|
@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tree_sitter import Language
|
||||
from tree_sitter import Language, Parser
|
||||
|
||||
|
||||
CHUNK_QUERY = """
|
||||
@ -22,10 +22,15 @@ class TypeScriptSegmenter(TreeSitterSegmenter):
|
||||
"""Code segmenter for TypeScript."""
|
||||
|
||||
def get_language(self) -> "Language":
|
||||
from tree_sitter_languages import get_language
|
||||
from tree_sitter_language_pack import get_language
|
||||
|
||||
return get_language("typescript")
|
||||
|
||||
def get_parser(self) -> "Parser":
|
||||
from tree_sitter_language_pack import get_parser
|
||||
|
||||
return get_parser("typescript")
|
||||
|
||||
def get_chunk_query(self) -> str:
|
||||
return CHUNK_QUERY
|
||||
|
||||
|
@ -5,7 +5,7 @@ import pytest
|
||||
from langchain_community.document_loaders.parsers.language.c import CSegmenter
|
||||
|
||||
|
||||
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
|
||||
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
|
||||
class TestCSegmenter(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.example_code = """int main() {
|
||||
|
@ -5,7 +5,7 @@ import pytest
|
||||
from langchain_community.document_loaders.parsers.language.cpp import CPPSegmenter
|
||||
|
||||
|
||||
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
|
||||
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
|
||||
class TestCPPSegmenter(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.example_code = """int foo() {
|
||||
@ -55,9 +55,9 @@ auto T::bar() const -> int {
|
||||
def test_extract_functions_classes(self) -> None:
|
||||
segmenter = CPPSegmenter(self.example_code)
|
||||
extracted_code = segmenter.extract_functions_classes()
|
||||
self.assertEqual(extracted_code, self.expected_extracted_code)
|
||||
self.assertEqual(self.expected_extracted_code, extracted_code)
|
||||
|
||||
def test_simplify_code(self) -> None:
|
||||
segmenter = CPPSegmenter(self.example_code)
|
||||
simplified_code = segmenter.simplify_code()
|
||||
self.assertEqual(simplified_code, self.expected_simplified_code)
|
||||
self.assertEqual(self.expected_simplified_code, simplified_code)
|
||||
|
@ -5,7 +5,7 @@ import pytest
|
||||
from langchain_community.document_loaders.parsers.language.csharp import CSharpSegmenter
|
||||
|
||||
|
||||
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
|
||||
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
|
||||
class TestCSharpSegmenter(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.example_code = """namespace World
|
||||
|
@ -5,7 +5,7 @@ import pytest
|
||||
from langchain_community.document_loaders.parsers.language.elixir import ElixirSegmenter
|
||||
|
||||
|
||||
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
|
||||
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
|
||||
class TestElixirSegmenter(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.example_code = """@doc "some comment"
|
||||
|
@ -5,7 +5,7 @@ import pytest
|
||||
from langchain_community.document_loaders.parsers.language.go import GoSegmenter
|
||||
|
||||
|
||||
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
|
||||
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
|
||||
class TestGoSegmenter(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.example_code = """func foo(a int) int {
|
||||
|
@ -5,7 +5,7 @@ import pytest
|
||||
from langchain_community.document_loaders.parsers.language.java import JavaSegmenter
|
||||
|
||||
|
||||
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
|
||||
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
|
||||
class TestJavaSegmenter(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.example_code = """class Hello
|
||||
|
@ -5,7 +5,7 @@ import pytest
|
||||
from langchain_community.document_loaders.parsers.language.kotlin import KotlinSegmenter
|
||||
|
||||
|
||||
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
|
||||
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
|
||||
class TestKotlinSegmenter(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.example_code = """fun foo(a: Int): Int {
|
||||
|
@ -5,7 +5,7 @@ import pytest
|
||||
from langchain_community.document_loaders.parsers.language.lua import LuaSegmenter
|
||||
|
||||
|
||||
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
|
||||
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
|
||||
class TestLuaSegmenter(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.example_code = """function F()
|
||||
|
@ -5,7 +5,7 @@ import pytest
|
||||
from langchain_community.document_loaders.parsers.language.perl import PerlSegmenter
|
||||
|
||||
|
||||
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
|
||||
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
|
||||
class TestPerlSegmenter(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.example_code = """sub Hello {
|
||||
|
@ -5,7 +5,7 @@ import pytest
|
||||
from langchain_community.document_loaders.parsers.language.php import PHPSegmenter
|
||||
|
||||
|
||||
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
|
||||
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
|
||||
class TestPHPSegmenter(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.example_code = """<?php
|
||||
|
@ -5,7 +5,7 @@ import pytest
|
||||
from langchain_community.document_loaders.parsers.language.ruby import RubySegmenter
|
||||
|
||||
|
||||
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
|
||||
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
|
||||
class TestRubySegmenter(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.example_code = """def foo
|
||||
|
@ -5,7 +5,7 @@ import pytest
|
||||
from langchain_community.document_loaders.parsers.language.rust import RustSegmenter
|
||||
|
||||
|
||||
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
|
||||
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
|
||||
class TestRustSegmenter(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.example_code = """fn foo() -> i32 {
|
||||
|
@ -5,7 +5,7 @@ import pytest
|
||||
from langchain_community.document_loaders.parsers.language.scala import ScalaSegmenter
|
||||
|
||||
|
||||
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
|
||||
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
|
||||
class TestScalaSegmenter(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.example_code = """def foo() {
|
||||
|
@ -5,7 +5,7 @@ import pytest
|
||||
from langchain_community.document_loaders.parsers.language.sql import SQLSegmenter
|
||||
|
||||
|
||||
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
|
||||
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
|
||||
class TestSQLSegmenter(unittest.TestCase):
|
||||
"""Unit tests for the SQLSegmenter class."""
|
||||
|
||||
|
@ -7,7 +7,7 @@ from langchain_community.document_loaders.parsers.language.typescript import (
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
|
||||
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
|
||||
class TestTypeScriptSegmenter(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.example_code = """function foo(): number
|
||||
|
Loading…
Reference in New Issue
Block a user