mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-29 14:37:21 +00:00
updated code parser to use tree sitter lang pack
This commit is contained in:
parent
d7d0bca2bc
commit
6f790d5c22
@ -30,7 +30,7 @@
|
|||||||
"- Scala (*)\n",
|
"- Scala (*)\n",
|
||||||
"- TypeScript (*)\n",
|
"- TypeScript (*)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Items marked with (*) require the packages `tree_sitter` and `tree_sitter_languages`.\n",
|
"Items marked with (*) require the packages `tree_sitter` and `tree-sitter-language-pack`.\n",
|
||||||
"It is straightforward to add support for additional languages using `tree_sitter`,\n",
|
"It is straightforward to add support for additional languages using `tree_sitter`,\n",
|
||||||
"although this currently requires modifying LangChain.\n",
|
"although this currently requires modifying LangChain.\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -47,9 +47,7 @@
|
|||||||
"id": "7fa47b2e",
|
"id": "7fa47b2e",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": "%pip install -qU esprima esprima tree_sitter tree-sitter-language-pack"
|
||||||
"%pip install -qU esprima esprima tree_sitter tree_sitter_languages"
|
|
||||||
]
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
@ -91,8 +91,8 @@ tidb-vector>=0.0.3,<1.0.0
|
|||||||
timescale-vector==0.0.1
|
timescale-vector==0.0.1
|
||||||
tqdm>=4.48.0
|
tqdm>=4.48.0
|
||||||
tiktoken>=0.8.0
|
tiktoken>=0.8.0
|
||||||
tree-sitter>=0.20.2,<0.21
|
tree-sitter>=0.23.2,<1
|
||||||
tree-sitter-languages>=1.8.0,<2
|
tree-sitter-language-pack>=0.6.1,<1
|
||||||
upstash-redis>=1.1.0,<2
|
upstash-redis>=1.1.0,<2
|
||||||
upstash-ratelimit>=1.1.0,<2
|
upstash-ratelimit>=1.1.0,<2
|
||||||
vdms>=0.0.20
|
vdms>=0.0.20
|
||||||
|
@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tree_sitter import Language
|
from tree_sitter import Language, Parser
|
||||||
|
|
||||||
|
|
||||||
CHUNK_QUERY = """
|
CHUNK_QUERY = """
|
||||||
@ -25,10 +25,15 @@ class CSegmenter(TreeSitterSegmenter):
|
|||||||
"""Code segmenter for C."""
|
"""Code segmenter for C."""
|
||||||
|
|
||||||
def get_language(self) -> "Language":
|
def get_language(self) -> "Language":
|
||||||
from tree_sitter_languages import get_language
|
from tree_sitter_language_pack import get_language
|
||||||
|
|
||||||
return get_language("c")
|
return get_language("c")
|
||||||
|
|
||||||
|
def get_parser(self) -> "Parser":
|
||||||
|
from tree_sitter_language_pack import get_parser
|
||||||
|
|
||||||
|
return get_parser("c")
|
||||||
|
|
||||||
def get_chunk_query(self) -> str:
|
def get_chunk_query(self) -> str:
|
||||||
return CHUNK_QUERY
|
return CHUNK_QUERY
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tree_sitter import Language
|
from tree_sitter import Language, Parser
|
||||||
|
|
||||||
|
|
||||||
CHUNK_QUERY = """
|
CHUNK_QUERY = """
|
||||||
@ -25,10 +25,15 @@ class CPPSegmenter(TreeSitterSegmenter):
|
|||||||
"""Code segmenter for C++."""
|
"""Code segmenter for C++."""
|
||||||
|
|
||||||
def get_language(self) -> "Language":
|
def get_language(self) -> "Language":
|
||||||
from tree_sitter_languages import get_language
|
from tree_sitter_language_pack import get_language
|
||||||
|
|
||||||
return get_language("cpp")
|
return get_language("cpp")
|
||||||
|
|
||||||
|
def get_parser(self) -> "Parser":
|
||||||
|
from tree_sitter_language_pack import get_parser
|
||||||
|
|
||||||
|
return get_parser("cpp")
|
||||||
|
|
||||||
def get_chunk_query(self) -> str:
|
def get_chunk_query(self) -> str:
|
||||||
return CHUNK_QUERY
|
return CHUNK_QUERY
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tree_sitter import Language
|
from tree_sitter import Language, Parser
|
||||||
|
|
||||||
|
|
||||||
CHUNK_QUERY = """
|
CHUNK_QUERY = """
|
||||||
@ -25,9 +25,14 @@ class CSharpSegmenter(TreeSitterSegmenter):
|
|||||||
"""Code segmenter for C#."""
|
"""Code segmenter for C#."""
|
||||||
|
|
||||||
def get_language(self) -> "Language":
|
def get_language(self) -> "Language":
|
||||||
from tree_sitter_languages import get_language
|
from tree_sitter_language_pack import get_language
|
||||||
|
|
||||||
return get_language("c_sharp")
|
return get_language("csharp")
|
||||||
|
|
||||||
|
def get_parser(self) -> "Parser":
|
||||||
|
from tree_sitter_language_pack import get_parser
|
||||||
|
|
||||||
|
return get_parser("csharp")
|
||||||
|
|
||||||
def get_chunk_query(self) -> str:
|
def get_chunk_query(self) -> str:
|
||||||
return CHUNK_QUERY
|
return CHUNK_QUERY
|
||||||
|
@ -5,17 +5,49 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tree_sitter import Language
|
from tree_sitter import Language, Parser
|
||||||
|
|
||||||
|
|
||||||
CHUNK_QUERY = """
|
CHUNK_QUERY = """
|
||||||
[
|
[
|
||||||
(call target: ((identifier) @_identifier
|
(unary_operator
|
||||||
(#any-of? @_identifier "defmodule" "defprotocol" "defimpl"))) @module
|
operator: "@"
|
||||||
(call target: ((identifier) @_identifier
|
operand: (call
|
||||||
(#any-of? @_identifier "def" "defmacro" "defmacrop" "defp"))) @function
|
target: (identifier)
|
||||||
(unary_operator operator: "@" operand: (call target: ((identifier) @_identifier
|
(arguments
|
||||||
(#any-of? @_identifier "moduledoc" "typedoc""doc")))) @comment
|
[
|
||||||
|
(string)
|
||||||
|
(charlist)
|
||||||
|
(sigil
|
||||||
|
quoted_start: _
|
||||||
|
quoted_end: _
|
||||||
|
)
|
||||||
|
(boolean)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
) @comment
|
||||||
|
|
||||||
|
(call
|
||||||
|
target: (identifier)
|
||||||
|
(arguments (alias))
|
||||||
|
) @module
|
||||||
|
|
||||||
|
(call
|
||||||
|
target: (identifier)
|
||||||
|
(arguments
|
||||||
|
[
|
||||||
|
; zero-arity functions with no parentheses
|
||||||
|
(identifier)
|
||||||
|
; regular function clause
|
||||||
|
(call target: (identifier))
|
||||||
|
; function clause with a guard clause
|
||||||
|
(binary_operator
|
||||||
|
left: (call target: (identifier))
|
||||||
|
operator: "when"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
) @function
|
||||||
]
|
]
|
||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
@ -24,10 +56,15 @@ class ElixirSegmenter(TreeSitterSegmenter):
|
|||||||
"""Code segmenter for Elixir."""
|
"""Code segmenter for Elixir."""
|
||||||
|
|
||||||
def get_language(self) -> "Language":
|
def get_language(self) -> "Language":
|
||||||
from tree_sitter_languages import get_language
|
from tree_sitter_language_pack import get_language
|
||||||
|
|
||||||
return get_language("elixir")
|
return get_language("elixir")
|
||||||
|
|
||||||
|
def get_parser(self) -> "Parser":
|
||||||
|
from tree_sitter_language_pack import get_parser
|
||||||
|
|
||||||
|
return get_parser("elixir")
|
||||||
|
|
||||||
def get_chunk_query(self) -> str:
|
def get_chunk_query(self) -> str:
|
||||||
return CHUNK_QUERY
|
return CHUNK_QUERY
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tree_sitter import Language
|
from tree_sitter import Language, Parser
|
||||||
|
|
||||||
|
|
||||||
CHUNK_QUERY = """
|
CHUNK_QUERY = """
|
||||||
@ -20,10 +20,15 @@ class GoSegmenter(TreeSitterSegmenter):
|
|||||||
"""Code segmenter for Go."""
|
"""Code segmenter for Go."""
|
||||||
|
|
||||||
def get_language(self) -> "Language":
|
def get_language(self) -> "Language":
|
||||||
from tree_sitter_languages import get_language
|
from tree_sitter_language_pack import get_language
|
||||||
|
|
||||||
return get_language("go")
|
return get_language("go")
|
||||||
|
|
||||||
|
def get_parser(self) -> "Parser":
|
||||||
|
from tree_sitter_language_pack import get_parser
|
||||||
|
|
||||||
|
return get_parser("go")
|
||||||
|
|
||||||
def get_chunk_query(self) -> str:
|
def get_chunk_query(self) -> str:
|
||||||
return CHUNK_QUERY
|
return CHUNK_QUERY
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tree_sitter import Language
|
from tree_sitter import Language, Parser
|
||||||
|
|
||||||
|
|
||||||
CHUNK_QUERY = """
|
CHUNK_QUERY = """
|
||||||
@ -21,10 +21,15 @@ class JavaSegmenter(TreeSitterSegmenter):
|
|||||||
"""Code segmenter for Java."""
|
"""Code segmenter for Java."""
|
||||||
|
|
||||||
def get_language(self) -> "Language":
|
def get_language(self) -> "Language":
|
||||||
from tree_sitter_languages import get_language
|
from tree_sitter_language_pack import get_language
|
||||||
|
|
||||||
return get_language("java")
|
return get_language("java")
|
||||||
|
|
||||||
|
def get_parser(self) -> "Parser":
|
||||||
|
from tree_sitter_language_pack import get_parser
|
||||||
|
|
||||||
|
return get_parser("java")
|
||||||
|
|
||||||
def get_chunk_query(self) -> str:
|
def get_chunk_query(self) -> str:
|
||||||
return CHUNK_QUERY
|
return CHUNK_QUERY
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tree_sitter import Language
|
from tree_sitter import Language, Parser
|
||||||
|
|
||||||
|
|
||||||
CHUNK_QUERY = """
|
CHUNK_QUERY = """
|
||||||
@ -20,10 +20,15 @@ class KotlinSegmenter(TreeSitterSegmenter):
|
|||||||
"""Code segmenter for Kotlin."""
|
"""Code segmenter for Kotlin."""
|
||||||
|
|
||||||
def get_language(self) -> "Language":
|
def get_language(self) -> "Language":
|
||||||
from tree_sitter_languages import get_language
|
from tree_sitter_language_pack import get_language
|
||||||
|
|
||||||
return get_language("kotlin")
|
return get_language("kotlin")
|
||||||
|
|
||||||
|
def get_parser(self) -> "Parser":
|
||||||
|
from tree_sitter_language_pack import get_parser
|
||||||
|
|
||||||
|
return get_parser("kotlin")
|
||||||
|
|
||||||
def get_chunk_query(self) -> str:
|
def get_chunk_query(self) -> str:
|
||||||
return CHUNK_QUERY
|
return CHUNK_QUERY
|
||||||
|
|
||||||
|
@ -130,7 +130,7 @@ class LanguageParser(BaseBlobParser):
|
|||||||
- TypeScript: "ts" (*)
|
- TypeScript: "ts" (*)
|
||||||
|
|
||||||
Items marked with (*) require the packages `tree_sitter` and
|
Items marked with (*) require the packages `tree_sitter` and
|
||||||
`tree_sitter_languages`. It is straightforward to add support for additional
|
`tree-sitter-language-pack`. It is straightforward to add support for additional
|
||||||
languages using `tree_sitter`, although this currently requires modifying LangChain.
|
languages using `tree_sitter`, although this currently requires modifying LangChain.
|
||||||
|
|
||||||
The language used for parsing can be configured, along with the minimum number of
|
The language used for parsing can be configured, along with the minimum number of
|
||||||
|
@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tree_sitter import Language
|
from tree_sitter import Language, Parser
|
||||||
|
|
||||||
|
|
||||||
CHUNK_QUERY = """
|
CHUNK_QUERY = """
|
||||||
@ -22,10 +22,15 @@ class LuaSegmenter(TreeSitterSegmenter):
|
|||||||
"""Code segmenter for Lua."""
|
"""Code segmenter for Lua."""
|
||||||
|
|
||||||
def get_language(self) -> "Language":
|
def get_language(self) -> "Language":
|
||||||
from tree_sitter_languages import get_language
|
from tree_sitter_language_pack import get_language
|
||||||
|
|
||||||
return get_language("lua")
|
return get_language("lua")
|
||||||
|
|
||||||
|
def get_parser(self) -> "Parser":
|
||||||
|
from tree_sitter_language_pack import get_parser
|
||||||
|
|
||||||
|
return get_parser("lua")
|
||||||
|
|
||||||
def get_chunk_query(self) -> str:
|
def get_chunk_query(self) -> str:
|
||||||
return CHUNK_QUERY
|
return CHUNK_QUERY
|
||||||
|
|
||||||
|
@ -5,12 +5,12 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tree_sitter import Language
|
from tree_sitter import Language, Parser
|
||||||
|
|
||||||
|
|
||||||
CHUNK_QUERY = """
|
CHUNK_QUERY = """
|
||||||
[
|
[
|
||||||
(function_definition) @subroutine
|
(subroutine_declaration_statement) @subroutine
|
||||||
]
|
]
|
||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
@ -19,10 +19,15 @@ class PerlSegmenter(TreeSitterSegmenter):
|
|||||||
"""Code segmenter for Perl."""
|
"""Code segmenter for Perl."""
|
||||||
|
|
||||||
def get_language(self) -> "Language":
|
def get_language(self) -> "Language":
|
||||||
from tree_sitter_languages import get_language
|
from tree_sitter_language_pack import get_language
|
||||||
|
|
||||||
return get_language("perl")
|
return get_language("perl")
|
||||||
|
|
||||||
|
def get_parser(self) -> "Parser":
|
||||||
|
from tree_sitter_language_pack import get_parser
|
||||||
|
|
||||||
|
return get_parser("perl")
|
||||||
|
|
||||||
def get_chunk_query(self) -> str:
|
def get_chunk_query(self) -> str:
|
||||||
return CHUNK_QUERY
|
return CHUNK_QUERY
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tree_sitter import Language
|
from tree_sitter import Language, Parser
|
||||||
|
|
||||||
|
|
||||||
CHUNK_QUERY = """
|
CHUNK_QUERY = """
|
||||||
@ -24,10 +24,15 @@ class PHPSegmenter(TreeSitterSegmenter):
|
|||||||
"""Code segmenter for PHP."""
|
"""Code segmenter for PHP."""
|
||||||
|
|
||||||
def get_language(self) -> "Language":
|
def get_language(self) -> "Language":
|
||||||
from tree_sitter_languages import get_language
|
from tree_sitter_language_pack import get_language
|
||||||
|
|
||||||
return get_language("php")
|
return get_language("php")
|
||||||
|
|
||||||
|
def get_parser(self) -> "Parser":
|
||||||
|
from tree_sitter_language_pack import get_parser
|
||||||
|
|
||||||
|
return get_parser("php")
|
||||||
|
|
||||||
def get_chunk_query(self) -> str:
|
def get_chunk_query(self) -> str:
|
||||||
return CHUNK_QUERY
|
return CHUNK_QUERY
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tree_sitter import Language
|
from tree_sitter import Language, Parser
|
||||||
|
|
||||||
|
|
||||||
CHUNK_QUERY = """
|
CHUNK_QUERY = """
|
||||||
@ -21,10 +21,15 @@ class RubySegmenter(TreeSitterSegmenter):
|
|||||||
"""Code segmenter for Ruby."""
|
"""Code segmenter for Ruby."""
|
||||||
|
|
||||||
def get_language(self) -> "Language":
|
def get_language(self) -> "Language":
|
||||||
from tree_sitter_languages import get_language
|
from tree_sitter_language_pack import get_language
|
||||||
|
|
||||||
return get_language("ruby")
|
return get_language("ruby")
|
||||||
|
|
||||||
|
def get_parser(self) -> "Parser":
|
||||||
|
from tree_sitter_language_pack import get_parser
|
||||||
|
|
||||||
|
return get_parser("ruby")
|
||||||
|
|
||||||
def get_chunk_query(self) -> str:
|
def get_chunk_query(self) -> str:
|
||||||
return CHUNK_QUERY
|
return CHUNK_QUERY
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tree_sitter import Language
|
from tree_sitter import Language, Parser
|
||||||
|
|
||||||
|
|
||||||
CHUNK_QUERY = """
|
CHUNK_QUERY = """
|
||||||
@ -23,10 +23,15 @@ class RustSegmenter(TreeSitterSegmenter):
|
|||||||
"""Code segmenter for Rust."""
|
"""Code segmenter for Rust."""
|
||||||
|
|
||||||
def get_language(self) -> "Language":
|
def get_language(self) -> "Language":
|
||||||
from tree_sitter_languages import get_language
|
from tree_sitter_language_pack import get_language
|
||||||
|
|
||||||
return get_language("rust")
|
return get_language("rust")
|
||||||
|
|
||||||
|
def get_parser(self) -> "Parser":
|
||||||
|
from tree_sitter_language_pack import get_parser
|
||||||
|
|
||||||
|
return get_parser("rust")
|
||||||
|
|
||||||
def get_chunk_query(self) -> str:
|
def get_chunk_query(self) -> str:
|
||||||
return CHUNK_QUERY
|
return CHUNK_QUERY
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tree_sitter import Language
|
from tree_sitter import Language, Parser
|
||||||
|
|
||||||
|
|
||||||
CHUNK_QUERY = """
|
CHUNK_QUERY = """
|
||||||
@ -22,10 +22,15 @@ class ScalaSegmenter(TreeSitterSegmenter):
|
|||||||
"""Code segmenter for Scala."""
|
"""Code segmenter for Scala."""
|
||||||
|
|
||||||
def get_language(self) -> "Language":
|
def get_language(self) -> "Language":
|
||||||
from tree_sitter_languages import get_language
|
from tree_sitter_language_pack import get_language
|
||||||
|
|
||||||
return get_language("scala")
|
return get_language("scala")
|
||||||
|
|
||||||
|
def get_parser(self) -> "Parser":
|
||||||
|
from tree_sitter_language_pack import get_parser
|
||||||
|
|
||||||
|
return get_parser("scala")
|
||||||
|
|
||||||
def get_chunk_query(self) -> str:
|
def get_chunk_query(self) -> str:
|
||||||
return CHUNK_QUERY
|
return CHUNK_QUERY
|
||||||
|
|
||||||
|
@ -5,15 +5,20 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tree_sitter import Language
|
from tree_sitter import Language, Parser
|
||||||
|
|
||||||
|
# CHUNK_QUERY = """
|
||||||
|
# [
|
||||||
|
# (create_table) @create
|
||||||
|
# (_select_statement) @select
|
||||||
|
# (insert) @insert
|
||||||
|
# (update) @update
|
||||||
|
# (_delete_statement) @delete
|
||||||
|
# ]
|
||||||
|
# """
|
||||||
CHUNK_QUERY = """
|
CHUNK_QUERY = """
|
||||||
[
|
[
|
||||||
(create_table_statement) @create
|
(statement) @statement
|
||||||
(select_statement) @select
|
|
||||||
(insert_statement) @insert
|
|
||||||
(update_statement) @update
|
|
||||||
(delete_statement) @delete
|
|
||||||
]
|
]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -28,10 +33,15 @@ class SQLSegmenter(TreeSitterSegmenter):
|
|||||||
|
|
||||||
def get_language(self) -> "Language":
|
def get_language(self) -> "Language":
|
||||||
"""Return the SQL language grammar for Tree-sitter."""
|
"""Return the SQL language grammar for Tree-sitter."""
|
||||||
from tree_sitter_languages import get_language
|
from tree_sitter_language_pack import get_language
|
||||||
|
|
||||||
return get_language("sql")
|
return get_language("sql")
|
||||||
|
|
||||||
|
def get_parser(self) -> "Parser":
|
||||||
|
from tree_sitter_language_pack import get_parser
|
||||||
|
|
||||||
|
return get_parser("sql")
|
||||||
|
|
||||||
def get_chunk_query(self) -> str:
|
def get_chunk_query(self) -> str:
|
||||||
"""Return the Tree-sitter query for SQL segmentation."""
|
"""Return the Tree-sitter query for SQL segmentation."""
|
||||||
return CHUNK_QUERY
|
return CHUNK_QUERY
|
||||||
|
@ -6,7 +6,7 @@ from langchain_community.document_loaders.parsers.language.code_segmenter import
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tree_sitter import Language, Parser
|
from tree_sitter import Language, Node, Parser
|
||||||
|
|
||||||
|
|
||||||
class TreeSitterSegmenter(CodeSegmenter):
|
class TreeSitterSegmenter(CodeSegmenter):
|
||||||
@ -18,12 +18,12 @@ class TreeSitterSegmenter(CodeSegmenter):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
import tree_sitter # noqa: F401
|
import tree_sitter # noqa: F401
|
||||||
import tree_sitter_languages # noqa: F401
|
import tree_sitter_language_pack # noqa: F401
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Could not import tree_sitter/tree_sitter_languages Python packages. "
|
"Could not import tree_sitter/tree_sitter_language_pack "
|
||||||
"Please install them with "
|
"Python packages. Please install them with "
|
||||||
"`pip install tree-sitter tree-sitter-languages`."
|
"`pip install tree-sitter tree-sitter-language-pack`."
|
||||||
)
|
)
|
||||||
|
|
||||||
def is_valid(self) -> bool:
|
def is_valid(self) -> bool:
|
||||||
@ -35,48 +35,50 @@ class TreeSitterSegmenter(CodeSegmenter):
|
|||||||
|
|
||||||
return len(error_query.captures(tree.root_node)) == 0
|
return len(error_query.captures(tree.root_node)) == 0
|
||||||
|
|
||||||
def extract_functions_classes(self) -> List[str]:
|
def _get_top_level_nodes(self) -> List["Node"]:
|
||||||
language = self.get_language()
|
language = self.get_language()
|
||||||
query = language.query(self.get_chunk_query())
|
query = language.query(self.get_chunk_query())
|
||||||
|
|
||||||
parser = self.get_parser()
|
parser = self.get_parser()
|
||||||
tree = parser.parse(bytes(self.code, encoding="UTF-8"))
|
tree = parser.parse(bytes(self.code, encoding="UTF-8"))
|
||||||
captures = query.captures(tree.root_node)
|
captures = query.captures(tree.root_node)
|
||||||
|
top_level_nodes = {}
|
||||||
processed_lines = set()
|
for node_type, nodes in captures.items():
|
||||||
chunks = []
|
for node in nodes:
|
||||||
|
cursor = node.parent
|
||||||
for node, name in captures:
|
is_child = False
|
||||||
start_line = node.start_point[0]
|
while cursor is not None:
|
||||||
end_line = node.end_point[0]
|
if cursor.id in top_level_nodes:
|
||||||
lines = list(range(start_line, end_line + 1))
|
is_child = True
|
||||||
|
break
|
||||||
if any(line in processed_lines for line in lines):
|
cursor = cursor.parent
|
||||||
|
if is_child:
|
||||||
continue
|
continue
|
||||||
|
top_level_nodes[node.id] = node
|
||||||
|
|
||||||
processed_lines.update(lines)
|
children = node.children
|
||||||
chunk_text = node.text.decode("UTF-8")
|
for child in children:
|
||||||
chunks.append(chunk_text)
|
if child.id in top_level_nodes:
|
||||||
|
del top_level_nodes[child.id]
|
||||||
|
children.extend(child.children)
|
||||||
|
top_level_nodes_list = list(top_level_nodes.values())
|
||||||
|
top_level_nodes_list.sort(key=lambda n: n.start_point[0])
|
||||||
|
return top_level_nodes_list
|
||||||
|
|
||||||
return chunks
|
def extract_functions_classes(self) -> List[str]:
|
||||||
|
top_level_nodes = self._get_top_level_nodes()
|
||||||
|
return [
|
||||||
|
node.text.decode("UTF-8")
|
||||||
|
for node in top_level_nodes
|
||||||
|
if node.text is not None
|
||||||
|
]
|
||||||
|
|
||||||
def simplify_code(self) -> str:
|
def simplify_code(self) -> str:
|
||||||
language = self.get_language()
|
|
||||||
query = language.query(self.get_chunk_query())
|
|
||||||
|
|
||||||
parser = self.get_parser()
|
|
||||||
tree = parser.parse(bytes(self.code, encoding="UTF-8"))
|
|
||||||
processed_lines = set()
|
|
||||||
|
|
||||||
simplified_lines = self.source_lines[:]
|
simplified_lines = self.source_lines[:]
|
||||||
for node, name in query.captures(tree.root_node):
|
top_level_nodes = self._get_top_level_nodes()
|
||||||
|
for node in top_level_nodes:
|
||||||
start_line = node.start_point[0]
|
start_line = node.start_point[0]
|
||||||
end_line = node.end_point[0]
|
end_line = node.end_point[0]
|
||||||
|
|
||||||
lines = list(range(start_line, end_line + 1))
|
|
||||||
if any(line in processed_lines for line in lines):
|
|
||||||
continue
|
|
||||||
|
|
||||||
simplified_lines[start_line] = self.make_line_comment(
|
simplified_lines[start_line] = self.make_line_comment(
|
||||||
f"Code for: {self.source_lines[start_line]}"
|
f"Code for: {self.source_lines[start_line]}"
|
||||||
)
|
)
|
||||||
@ -84,16 +86,11 @@ class TreeSitterSegmenter(CodeSegmenter):
|
|||||||
for line_num in range(start_line + 1, end_line + 1):
|
for line_num in range(start_line + 1, end_line + 1):
|
||||||
simplified_lines[line_num] = None # type: ignore
|
simplified_lines[line_num] = None # type: ignore
|
||||||
|
|
||||||
processed_lines.update(lines)
|
|
||||||
|
|
||||||
return "\n".join(line for line in simplified_lines if line is not None)
|
return "\n".join(line for line in simplified_lines if line is not None)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
def get_parser(self) -> "Parser":
|
def get_parser(self) -> "Parser":
|
||||||
from tree_sitter import Parser
|
raise NotImplementedError()
|
||||||
|
|
||||||
parser = Parser()
|
|
||||||
parser.set_language(self.get_language())
|
|
||||||
return parser
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_language(self) -> "Language":
|
def get_language(self) -> "Language":
|
||||||
|
@ -5,7 +5,7 @@ from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tree_sitter import Language
|
from tree_sitter import Language, Parser
|
||||||
|
|
||||||
|
|
||||||
CHUNK_QUERY = """
|
CHUNK_QUERY = """
|
||||||
@ -22,10 +22,15 @@ class TypeScriptSegmenter(TreeSitterSegmenter):
|
|||||||
"""Code segmenter for TypeScript."""
|
"""Code segmenter for TypeScript."""
|
||||||
|
|
||||||
def get_language(self) -> "Language":
|
def get_language(self) -> "Language":
|
||||||
from tree_sitter_languages import get_language
|
from tree_sitter_language_pack import get_language
|
||||||
|
|
||||||
return get_language("typescript")
|
return get_language("typescript")
|
||||||
|
|
||||||
|
def get_parser(self) -> "Parser":
|
||||||
|
from tree_sitter_language_pack import get_parser
|
||||||
|
|
||||||
|
return get_parser("typescript")
|
||||||
|
|
||||||
def get_chunk_query(self) -> str:
|
def get_chunk_query(self) -> str:
|
||||||
return CHUNK_QUERY
|
return CHUNK_QUERY
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ import pytest
|
|||||||
from langchain_community.document_loaders.parsers.language.c import CSegmenter
|
from langchain_community.document_loaders.parsers.language.c import CSegmenter
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
|
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
|
||||||
class TestCSegmenter(unittest.TestCase):
|
class TestCSegmenter(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.example_code = """int main() {
|
self.example_code = """int main() {
|
||||||
|
@ -5,7 +5,7 @@ import pytest
|
|||||||
from langchain_community.document_loaders.parsers.language.cpp import CPPSegmenter
|
from langchain_community.document_loaders.parsers.language.cpp import CPPSegmenter
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
|
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
|
||||||
class TestCPPSegmenter(unittest.TestCase):
|
class TestCPPSegmenter(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.example_code = """int foo() {
|
self.example_code = """int foo() {
|
||||||
@ -55,9 +55,9 @@ auto T::bar() const -> int {
|
|||||||
def test_extract_functions_classes(self) -> None:
|
def test_extract_functions_classes(self) -> None:
|
||||||
segmenter = CPPSegmenter(self.example_code)
|
segmenter = CPPSegmenter(self.example_code)
|
||||||
extracted_code = segmenter.extract_functions_classes()
|
extracted_code = segmenter.extract_functions_classes()
|
||||||
self.assertEqual(extracted_code, self.expected_extracted_code)
|
self.assertEqual(self.expected_extracted_code, extracted_code)
|
||||||
|
|
||||||
def test_simplify_code(self) -> None:
|
def test_simplify_code(self) -> None:
|
||||||
segmenter = CPPSegmenter(self.example_code)
|
segmenter = CPPSegmenter(self.example_code)
|
||||||
simplified_code = segmenter.simplify_code()
|
simplified_code = segmenter.simplify_code()
|
||||||
self.assertEqual(simplified_code, self.expected_simplified_code)
|
self.assertEqual(self.expected_simplified_code, simplified_code)
|
||||||
|
@ -5,7 +5,7 @@ import pytest
|
|||||||
from langchain_community.document_loaders.parsers.language.csharp import CSharpSegmenter
|
from langchain_community.document_loaders.parsers.language.csharp import CSharpSegmenter
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
|
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
|
||||||
class TestCSharpSegmenter(unittest.TestCase):
|
class TestCSharpSegmenter(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.example_code = """namespace World
|
self.example_code = """namespace World
|
||||||
|
@ -5,7 +5,7 @@ import pytest
|
|||||||
from langchain_community.document_loaders.parsers.language.elixir import ElixirSegmenter
|
from langchain_community.document_loaders.parsers.language.elixir import ElixirSegmenter
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
|
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
|
||||||
class TestElixirSegmenter(unittest.TestCase):
|
class TestElixirSegmenter(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.example_code = """@doc "some comment"
|
self.example_code = """@doc "some comment"
|
||||||
|
@ -5,7 +5,7 @@ import pytest
|
|||||||
from langchain_community.document_loaders.parsers.language.go import GoSegmenter
|
from langchain_community.document_loaders.parsers.language.go import GoSegmenter
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
|
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
|
||||||
class TestGoSegmenter(unittest.TestCase):
|
class TestGoSegmenter(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.example_code = """func foo(a int) int {
|
self.example_code = """func foo(a int) int {
|
||||||
|
@ -5,7 +5,7 @@ import pytest
|
|||||||
from langchain_community.document_loaders.parsers.language.java import JavaSegmenter
|
from langchain_community.document_loaders.parsers.language.java import JavaSegmenter
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
|
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
|
||||||
class TestJavaSegmenter(unittest.TestCase):
|
class TestJavaSegmenter(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.example_code = """class Hello
|
self.example_code = """class Hello
|
||||||
|
@ -5,7 +5,7 @@ import pytest
|
|||||||
from langchain_community.document_loaders.parsers.language.kotlin import KotlinSegmenter
|
from langchain_community.document_loaders.parsers.language.kotlin import KotlinSegmenter
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
|
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
|
||||||
class TestKotlinSegmenter(unittest.TestCase):
|
class TestKotlinSegmenter(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.example_code = """fun foo(a: Int): Int {
|
self.example_code = """fun foo(a: Int): Int {
|
||||||
|
@ -5,7 +5,7 @@ import pytest
|
|||||||
from langchain_community.document_loaders.parsers.language.lua import LuaSegmenter
|
from langchain_community.document_loaders.parsers.language.lua import LuaSegmenter
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
|
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
|
||||||
class TestLuaSegmenter(unittest.TestCase):
|
class TestLuaSegmenter(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.example_code = """function F()
|
self.example_code = """function F()
|
||||||
|
@ -5,7 +5,7 @@ import pytest
|
|||||||
from langchain_community.document_loaders.parsers.language.perl import PerlSegmenter
|
from langchain_community.document_loaders.parsers.language.perl import PerlSegmenter
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
|
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
|
||||||
class TestPerlSegmenter(unittest.TestCase):
|
class TestPerlSegmenter(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.example_code = """sub Hello {
|
self.example_code = """sub Hello {
|
||||||
|
@ -5,7 +5,7 @@ import pytest
|
|||||||
from langchain_community.document_loaders.parsers.language.php import PHPSegmenter
|
from langchain_community.document_loaders.parsers.language.php import PHPSegmenter
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
|
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
|
||||||
class TestPHPSegmenter(unittest.TestCase):
|
class TestPHPSegmenter(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.example_code = """<?php
|
self.example_code = """<?php
|
||||||
|
@ -5,7 +5,7 @@ import pytest
|
|||||||
from langchain_community.document_loaders.parsers.language.ruby import RubySegmenter
|
from langchain_community.document_loaders.parsers.language.ruby import RubySegmenter
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
|
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
|
||||||
class TestRubySegmenter(unittest.TestCase):
|
class TestRubySegmenter(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.example_code = """def foo
|
self.example_code = """def foo
|
||||||
|
@ -5,7 +5,7 @@ import pytest
|
|||||||
from langchain_community.document_loaders.parsers.language.rust import RustSegmenter
|
from langchain_community.document_loaders.parsers.language.rust import RustSegmenter
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
|
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
|
||||||
class TestRustSegmenter(unittest.TestCase):
|
class TestRustSegmenter(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.example_code = """fn foo() -> i32 {
|
self.example_code = """fn foo() -> i32 {
|
||||||
|
@ -5,7 +5,7 @@ import pytest
|
|||||||
from langchain_community.document_loaders.parsers.language.scala import ScalaSegmenter
|
from langchain_community.document_loaders.parsers.language.scala import ScalaSegmenter
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
|
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
|
||||||
class TestScalaSegmenter(unittest.TestCase):
|
class TestScalaSegmenter(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.example_code = """def foo() {
|
self.example_code = """def foo() {
|
||||||
|
@ -5,7 +5,7 @@ import pytest
|
|||||||
from langchain_community.document_loaders.parsers.language.sql import SQLSegmenter
|
from langchain_community.document_loaders.parsers.language.sql import SQLSegmenter
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
|
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
|
||||||
class TestSQLSegmenter(unittest.TestCase):
|
class TestSQLSegmenter(unittest.TestCase):
|
||||||
"""Unit tests for the SQLSegmenter class."""
|
"""Unit tests for the SQLSegmenter class."""
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ from langchain_community.document_loaders.parsers.language.typescript import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
|
@pytest.mark.requires("tree_sitter", "tree_sitter_language_pack")
|
||||||
class TestTypeScriptSegmenter(unittest.TestCase):
|
class TestTypeScriptSegmenter(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.example_code = """function foo(): number
|
self.example_code = """function foo(): number
|
||||||
|
Loading…
Reference in New Issue
Block a user