Framework for supporting more languages in LanguageParser (#13318)

## Description

I am submitting this for a school project as part of a team of 5. Other
team members are @LeilaChr, @maazh10, @Megabear137, @jelalalamy. This PR
also has contributions from community members @Harrolee and @Mario928.

Initial context is in the issue we opened (#11229).

This pull request adds:

- Generic framework for expanding the languages that `LanguageParser`
can handle, using the
[tree-sitter](https://github.com/tree-sitter/py-tree-sitter#py-tree-sitter)
parsing library and existing language-specific parsers written for it
- Support for the following additional languages in `LanguageParser`:
  - C
  - C++
  - C#
  - Go
  - Java (contributed by @Mario928,
    https://github.com/ThatsJustCheesy/langchain/pull/2)
  - Kotlin
  - Lua
  - Perl
  - Ruby
  - Rust
  - Scala
  - TypeScript (contributed by @Harrolee,
    https://github.com/ThatsJustCheesy/langchain/pull/1)

Here is the [design
document](https://docs.google.com/document/d/17dB14cKCWAaiTeSeBtxHpoVPGKrsPye8W0o_WClz2kk)
if curious, but no need to read it.

## Issues

- Closes #11229
- Closes #10996
- Closes #8405

## Dependencies

`tree_sitter` and `tree_sitter_languages` on PyPI. We have tried to add
these as optional dependencies.

## Documentation

We have updated the list of supported languages, and also added a
section to `source_code.ipynb` detailing how to add support for
additional languages using our framework.

## Maintainer

- @hwchase17 (previously reviewed
https://github.com/langchain-ai/langchain/pull/6486)

Thanks!!

## Git commits

We will gladly squash any or all of our commits (especially merge
commits) if necessary. Let us know if this is desirable, or if you will
be squash-merging anyway.

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
- **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->

---------

Co-authored-by: Maaz Hashmi <mhashmi373@gmail.com>
Co-authored-by: LeilaChr <87657694+LeilaChr@users.noreply.github.com>
Co-authored-by: Jeremy La <jeremylai511@gmail.com>
Co-authored-by: Megabear137 <zubair.alnoor27@gmail.com>
Co-authored-by: Lee Harrold <lhharrold@sep.com>
Co-authored-by: Mario928 <88029051+Mario928@users.noreply.github.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
Ian Gregory
2024-02-13 11:45:49 -05:00
committed by GitHub
parent 729c6d6827
commit e5472b5eb8
29 changed files with 1464 additions and 13 deletions

View File

@@ -0,0 +1,36 @@
from typing import TYPE_CHECKING
from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter import ( # noqa: E501
TreeSitterSegmenter,
)
if TYPE_CHECKING:
from tree_sitter import Language
# Tree-sitter query whose captures become the chunks for C sources: struct,
# enum and union specifiers that have a body, plus function definitions.
CHUNK_QUERY = """
[
(struct_specifier
body: (field_declaration_list)) @struct
(enum_specifier
body: (enumerator_list)) @enum
(union_specifier
body: (field_declaration_list)) @union
(function_definition) @function
]
""".strip()


class CSegmenter(TreeSitterSegmenter):
    """Code segmenter for C."""

    def get_language(self) -> "Language":
        """Return the tree-sitter grammar registered under the name ``"c"``."""
        # Imported inside the method so that merely importing this module
        # does not require `tree_sitter_languages` to be installed.
        import tree_sitter_languages

        return tree_sitter_languages.get_language("c")

    def get_chunk_query(self) -> str:
        """Return the query selecting the constructs to split out as chunks."""
        return CHUNK_QUERY

    def make_line_comment(self, text: str) -> str:
        """Render ``text`` as a ``//`` line comment."""
        return f"// {text}"

View File

@@ -0,0 +1,36 @@
from typing import TYPE_CHECKING
from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter import ( # noqa: E501
TreeSitterSegmenter,
)
if TYPE_CHECKING:
from tree_sitter import Language
# Tree-sitter query whose captures become the chunks for C++ sources: class,
# struct and union specifiers that have a body, plus function definitions.
CHUNK_QUERY = """
[
(class_specifier
body: (field_declaration_list)) @class
(struct_specifier
body: (field_declaration_list)) @struct
(union_specifier
body: (field_declaration_list)) @union
(function_definition) @function
]
""".strip()


class CPPSegmenter(TreeSitterSegmenter):
    """Code segmenter for C++."""

    def get_language(self) -> "Language":
        """Return the tree-sitter grammar registered under the name ``"cpp"``."""
        # Imported inside the method so that merely importing this module
        # does not require `tree_sitter_languages` to be installed.
        import tree_sitter_languages

        return tree_sitter_languages.get_language("cpp")

    def get_chunk_query(self) -> str:
        """Return the query selecting the constructs to split out as chunks."""
        return CHUNK_QUERY

    def make_line_comment(self, text: str) -> str:
        """Render ``text`` as a ``//`` line comment."""
        return f"// {text}"

View File

@@ -0,0 +1,36 @@
from typing import TYPE_CHECKING
from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter import ( # noqa: E501
TreeSitterSegmenter,
)
if TYPE_CHECKING:
from tree_sitter import Language
# Tree-sitter query whose captures become the chunks for C# sources:
# namespace, class, method, interface, enum, struct and record declarations.
CHUNK_QUERY = """
[
(namespace_declaration) @namespace
(class_declaration) @class
(method_declaration) @method
(interface_declaration) @interface
(enum_declaration) @enum
(struct_declaration) @struct
(record_declaration) @record
]
""".strip()


class CSharpSegmenter(TreeSitterSegmenter):
    """Code segmenter for C#."""

    def get_language(self) -> "Language":
        """Return the tree-sitter grammar registered under ``"c_sharp"``."""
        # Imported inside the method so that merely importing this module
        # does not require `tree_sitter_languages` to be installed.
        import tree_sitter_languages

        return tree_sitter_languages.get_language("c_sharp")

    def get_chunk_query(self) -> str:
        """Return the query selecting the constructs to split out as chunks."""
        return CHUNK_QUERY

    def make_line_comment(self, text: str) -> str:
        """Render ``text`` as a ``//`` line comment."""
        return f"// {text}"

View File

@@ -0,0 +1,31 @@
from typing import TYPE_CHECKING
from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter import ( # noqa: E501
TreeSitterSegmenter,
)
if TYPE_CHECKING:
from tree_sitter import Language
# Tree-sitter query whose captures become the chunks for Go sources:
# function declarations and type declarations.
CHUNK_QUERY = """
[
(function_declaration) @function
(type_declaration) @type
]
""".strip()


class GoSegmenter(TreeSitterSegmenter):
    """Code segmenter for Go."""

    def get_language(self) -> "Language":
        """Return the tree-sitter grammar registered under the name ``"go"``."""
        # Imported inside the method so that merely importing this module
        # does not require `tree_sitter_languages` to be installed.
        import tree_sitter_languages

        return tree_sitter_languages.get_language("go")

    def get_chunk_query(self) -> str:
        """Return the query selecting the constructs to split out as chunks."""
        return CHUNK_QUERY

    def make_line_comment(self, text: str) -> str:
        """Render ``text`` as a ``//`` line comment."""
        return f"// {text}"

View File

@@ -0,0 +1,32 @@
from typing import TYPE_CHECKING
from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter import ( # noqa: E501
TreeSitterSegmenter,
)
if TYPE_CHECKING:
from tree_sitter import Language
# Tree-sitter query whose captures become the chunks for Java sources:
# class, interface and enum declarations.
CHUNK_QUERY = """
[
(class_declaration) @class
(interface_declaration) @interface
(enum_declaration) @enum
]
""".strip()


class JavaSegmenter(TreeSitterSegmenter):
    """Code segmenter for Java."""

    def get_language(self) -> "Language":
        """Return the tree-sitter grammar registered under the name ``"java"``."""
        # Imported inside the method so that merely importing this module
        # does not require `tree_sitter_languages` to be installed.
        import tree_sitter_languages

        return tree_sitter_languages.get_language("java")

    def get_chunk_query(self) -> str:
        """Return the query selecting the constructs to split out as chunks."""
        return CHUNK_QUERY

    def make_line_comment(self, text: str) -> str:
        """Render ``text`` as a ``//`` line comment."""
        return f"// {text}"

View File

@@ -0,0 +1,31 @@
from typing import TYPE_CHECKING
from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter import ( # noqa: E501
TreeSitterSegmenter,
)
if TYPE_CHECKING:
from tree_sitter import Language
# Tree-sitter query whose captures become the chunks for Kotlin sources:
# function declarations and class declarations.
CHUNK_QUERY = """
[
(function_declaration) @function
(class_declaration) @class
]
""".strip()


class KotlinSegmenter(TreeSitterSegmenter):
    """Code segmenter for Kotlin."""

    def get_language(self) -> "Language":
        """Return the tree-sitter grammar registered under ``"kotlin"``."""
        # Imported inside the method so that merely importing this module
        # does not require `tree_sitter_languages` to be installed.
        import tree_sitter_languages

        return tree_sitter_languages.get_language("kotlin")

    def get_chunk_query(self) -> str:
        """Return the query selecting the constructs to split out as chunks."""
        return CHUNK_QUERY

    def make_line_comment(self, text: str) -> str:
        """Render ``text`` as a ``//`` line comment."""
        return f"// {text}"

View File

@@ -6,28 +6,66 @@ from langchain_core.documents import Document
from langchain_community.document_loaders.base import BaseBlobParser
from langchain_community.document_loaders.blob_loaders import Blob
from langchain_community.document_loaders.parsers.language.c import CSegmenter
from langchain_community.document_loaders.parsers.language.cobol import CobolSegmenter
from langchain_community.document_loaders.parsers.language.cpp import CPPSegmenter
from langchain_community.document_loaders.parsers.language.csharp import CSharpSegmenter
from langchain_community.document_loaders.parsers.language.go import GoSegmenter
from langchain_community.document_loaders.parsers.language.java import JavaSegmenter
from langchain_community.document_loaders.parsers.language.javascript import (
JavaScriptSegmenter,
)
from langchain_community.document_loaders.parsers.language.kotlin import KotlinSegmenter
from langchain_community.document_loaders.parsers.language.lua import LuaSegmenter
from langchain_community.document_loaders.parsers.language.perl import PerlSegmenter
from langchain_community.document_loaders.parsers.language.python import PythonSegmenter
from langchain_community.document_loaders.parsers.language.ruby import RubySegmenter
from langchain_community.document_loaders.parsers.language.rust import RustSegmenter
from langchain_community.document_loaders.parsers.language.scala import ScalaSegmenter
from langchain_community.document_loaders.parsers.language.typescript import (
TypeScriptSegmenter,
)
if TYPE_CHECKING:
from langchain.text_splitter import Language
from langchain.langchain.text_splitter import Language
try:
from langchain.text_splitter import Language
from langchain.langchain.text_splitter import Language
LANGUAGE_EXTENSIONS: Dict[str, str] = {
"py": Language.PYTHON,
"js": Language.JS,
"cobol": Language.COBOL,
"c": Language.C,
"cpp": Language.CPP,
"cs": Language.CSHARP,
"rb": Language.RUBY,
"scala": Language.SCALA,
"rs": Language.RUST,
"go": Language.GO,
"kt": Language.KOTLIN,
"lua": Language.LUA,
"pl": Language.PERL,
"ts": Language.TS,
"java": Language.JAVA,
}
LANGUAGE_SEGMENTERS: Dict[str, Any] = {
Language.PYTHON: PythonSegmenter,
Language.JS: JavaScriptSegmenter,
Language.COBOL: CobolSegmenter,
Language.C: CSegmenter,
Language.CPP: CPPSegmenter,
Language.CSHARP: CSharpSegmenter,
Language.RUBY: RubySegmenter,
Language.RUST: RustSegmenter,
Language.SCALA: ScalaSegmenter,
Language.GO: GoSegmenter,
Language.KOTLIN: KotlinSegmenter,
Language.LUA: LuaSegmenter,
Language.PERL: PerlSegmenter,
Language.TS: TypeScriptSegmenter,
Language.JAVA: JavaSegmenter,
}
except ImportError:
LANGUAGE_EXTENSIONS = {}
@@ -43,11 +81,34 @@ class LanguageParser(BaseBlobParser):
This approach can potentially improve the accuracy of QA models over source code.
Currently, the supported languages for code parsing are Python and JavaScript.
The supported languages for code parsing are:
- C (*)
- C++ (*)
- C# (*)
- COBOL
- Go (*)
- Java (*)
- JavaScript (requires package `esprima`)
- Kotlin (*)
- Lua (*)
- Perl (*)
- Python
- Ruby (*)
- Rust (*)
- Scala (*)
- TypeScript (*)
Items marked with (*) require the packages `tree_sitter` and
`tree_sitter_languages`. It is straightforward to add support for additional
languages using `tree_sitter`, although this currently requires modifying LangChain.
The language used for parsing can be configured, along with the minimum number of
lines required to activate the splitting based on syntax.
If a language is not explicitly specified, `LanguageParser` will infer one from
filename extensions, if present.
Examples:
.. code-block:: python

View File

@@ -0,0 +1,33 @@
from typing import TYPE_CHECKING
from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter import ( # noqa: E501
TreeSitterSegmenter,
)
if TYPE_CHECKING:
from tree_sitter import Language
# Tree-sitter query whose captures become the chunks for Lua sources:
# function definition statements (plain and `local`) whose name is a bare
# identifier.
CHUNK_QUERY = """
[
(function_definition_statement
name: (identifier)) @function
(local_function_definition_statement
name: (identifier)) @function
]
""".strip()


class LuaSegmenter(TreeSitterSegmenter):
    """Code segmenter for Lua."""

    def get_language(self) -> "Language":
        """Return the tree-sitter grammar registered under the name ``"lua"``."""
        # Imported inside the method so that merely importing this module
        # does not require `tree_sitter_languages` to be installed.
        import tree_sitter_languages

        return tree_sitter_languages.get_language("lua")

    def get_chunk_query(self) -> str:
        """Return the query selecting the constructs to split out as chunks."""
        return CHUNK_QUERY

    def make_line_comment(self, text: str) -> str:
        """Render ``text`` as a ``--`` line comment."""
        return f"-- {text}"

View File

@@ -0,0 +1,30 @@
from typing import TYPE_CHECKING
from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter import ( # noqa: E501
TreeSitterSegmenter,
)
if TYPE_CHECKING:
from tree_sitter import Language
# Tree-sitter query whose captures become the chunks for Perl sources:
# function definitions, captured under the name `subroutine`.
CHUNK_QUERY = """
[
(function_definition) @subroutine
]
""".strip()


class PerlSegmenter(TreeSitterSegmenter):
    """Code segmenter for Perl."""

    def get_language(self) -> "Language":
        """Return the tree-sitter grammar registered under the name ``"perl"``."""
        # Imported inside the method so that merely importing this module
        # does not require `tree_sitter_languages` to be installed.
        import tree_sitter_languages

        return tree_sitter_languages.get_language("perl")

    def get_chunk_query(self) -> str:
        """Return the query selecting the constructs to split out as chunks."""
        return CHUNK_QUERY

    def make_line_comment(self, text: str) -> str:
        """Render ``text`` as a ``#`` line comment."""
        return f"# {text}"

View File

@@ -0,0 +1,32 @@
from typing import TYPE_CHECKING
from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter import ( # noqa: E501
TreeSitterSegmenter,
)
if TYPE_CHECKING:
from tree_sitter import Language
# Tree-sitter query whose captures become the chunks for Ruby sources:
# method, module and class nodes.
CHUNK_QUERY = """
[
(method) @method
(module) @module
(class) @class
]
""".strip()


class RubySegmenter(TreeSitterSegmenter):
    """Code segmenter for Ruby."""

    def get_language(self) -> "Language":
        """Return the tree-sitter grammar registered under the name ``"ruby"``."""
        # Imported inside the method so that merely importing this module
        # does not require `tree_sitter_languages` to be installed.
        import tree_sitter_languages

        return tree_sitter_languages.get_language("ruby")

    def get_chunk_query(self) -> str:
        """Return the query selecting the constructs to split out as chunks."""
        return CHUNK_QUERY

    def make_line_comment(self, text: str) -> str:
        """Render ``text`` as a ``#`` line comment."""
        return f"# {text}"

View File

@@ -0,0 +1,34 @@
from typing import TYPE_CHECKING
from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter import ( # noqa: E501
TreeSitterSegmenter,
)
if TYPE_CHECKING:
from tree_sitter import Language
# Tree-sitter query whose captures become the chunks for Rust sources:
# function items that have an identifier name and a block body, plus struct
# items and trait items.
CHUNK_QUERY = """
[
(function_item
name: (identifier)
body: (block)) @function
(struct_item) @struct
(trait_item) @trait
]
""".strip()


class RustSegmenter(TreeSitterSegmenter):
    """Code segmenter for Rust."""

    def get_language(self) -> "Language":
        """Return the tree-sitter grammar registered under the name ``"rust"``."""
        # Imported inside the method so that merely importing this module
        # does not require `tree_sitter_languages` to be installed.
        import tree_sitter_languages

        return tree_sitter_languages.get_language("rust")

    def get_chunk_query(self) -> str:
        """Return the query selecting the constructs to split out as chunks."""
        return CHUNK_QUERY

    def make_line_comment(self, text: str) -> str:
        """Render ``text`` as a ``//`` line comment."""
        return f"// {text}"

View File

@@ -0,0 +1,33 @@
from typing import TYPE_CHECKING
from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter import ( # noqa: E501
TreeSitterSegmenter,
)
if TYPE_CHECKING:
from tree_sitter import Language
# Tree-sitter query whose captures become the chunks for Scala sources:
# class, function, object and trait definitions.
CHUNK_QUERY = """
[
(class_definition) @class
(function_definition) @function
(object_definition) @object
(trait_definition) @trait
]
""".strip()


class ScalaSegmenter(TreeSitterSegmenter):
    """Code segmenter for Scala."""

    def get_language(self) -> "Language":
        """Return the tree-sitter grammar registered under ``"scala"``."""
        # Imported inside the method so that merely importing this module
        # does not require `tree_sitter_languages` to be installed.
        import tree_sitter_languages

        return tree_sitter_languages.get_language("scala")

    def get_chunk_query(self) -> str:
        """Return the query selecting the constructs to split out as chunks."""
        return CHUNK_QUERY

    def make_line_comment(self, text: str) -> str:
        """Render ``text`` as a ``//`` line comment."""
        return f"// {text}"

View File

@@ -0,0 +1,108 @@
from abc import abstractmethod
from typing import TYPE_CHECKING, List
from langchain_community.document_loaders.parsers.language.code_segmenter import (
CodeSegmenter,
)
if TYPE_CHECKING:
from tree_sitter import Language, Parser
class TreeSitterSegmenter(CodeSegmenter):
    """Abstract class for `CodeSegmenter`s that use the tree-sitter library.

    Subclasses provide three things: the grammar (`get_language`), a
    tree-sitter query whose captures mark the chunks to split out
    (`get_chunk_query`), and the language's line-comment syntax
    (`make_line_comment`).
    """

    def __init__(self, code: str):
        """Store the source and verify the tree-sitter packages are installed.

        Args:
            code: Source text to segment.

        Raises:
            ImportError: If `tree_sitter` or `tree_sitter_languages` cannot
                be imported.
        """
        super().__init__(code)
        self.source_lines = self._split_source_lines(self.code)

        try:
            import tree_sitter  # noqa: F401
            import tree_sitter_languages  # noqa: F401
        except ImportError as e:
            # Chain the original error so the failing module stays visible.
            raise ImportError(
                "Could not import tree_sitter/tree_sitter_languages Python packages. "
                "Please install them with "
                "`pip install tree-sitter tree-sitter-languages`."
            ) from e

    @staticmethod
    def _split_source_lines(code: str) -> List[str]:
        """Split `code` into lines indexed the same way as tree-sitter rows.

        Tree-sitter starts a new row only at ``"\\n"``. `str.splitlines` also
        breaks on form feeds, lone carriage returns, ``"\\u2028"`` and other
        boundaries, which would shift every later line index relative to the
        rows reported by the parser. Splitting on ``"\\n"`` keeps them aligned.

        For LF and CRLF input the result is identical to `str.splitlines`:
        a trailing newline does not produce a final empty line, and the
        ``"\\r"`` of a ``"\\r\\n"`` terminator is removed.
        """
        lines = code.split("\n")
        if lines[-1] == "":
            lines.pop()
        return [line[:-1] if line.endswith("\r") else line for line in lines]

    def _parse(self):  # type: ignore[no-untyped-def]
        """Parse the stored source (UTF-8 encoded) and return the syntax tree."""
        return self.get_parser().parse(bytes(self.code, encoding="UTF-8"))

    def _chunk_captures(self):  # type: ignore[no-untyped-def]
        """Run the chunk query over the parsed source and return its captures."""
        query = self.get_language().query(self.get_chunk_query())
        return query.captures(self._parse().root_node)

    def is_valid(self) -> bool:
        """Return True when the parse tree contains no ``ERROR`` nodes."""
        error_query = self.get_language().query("(ERROR) @error")
        return len(error_query.captures(self._parse().root_node)) == 0

    def extract_functions_classes(self) -> List[str]:
        """Return the source text of every captured chunk.

        A capture is skipped when any of its lines is already covered by an
        earlier capture, so nested matches (for example a method inside an
        already-captured class) are not emitted twice.
        """
        processed_lines = set()
        chunks = []

        for node, _name in self._chunk_captures():
            lines = range(node.start_point[0], node.end_point[0] + 1)
            if not processed_lines.isdisjoint(lines):
                continue

            processed_lines.update(lines)
            chunks.append(node.text.decode("UTF-8"))

        return chunks

    def simplify_code(self) -> str:
        """Return the source with each captured chunk collapsed to one comment.

        The first line of a chunk is replaced by a line comment reading
        ``Code for: <original first line>``; its remaining lines are dropped.
        Captures overlapping an already-processed chunk are ignored.
        """
        processed_lines = set()
        dropped_lines = set()
        simplified_lines = self.source_lines[:]

        for node, _name in self._chunk_captures():
            start_line = node.start_point[0]
            end_line = node.end_point[0]

            lines = range(start_line, end_line + 1)
            if not processed_lines.isdisjoint(lines):
                continue

            simplified_lines[start_line] = self.make_line_comment(
                f"Code for: {self.source_lines[start_line]}"
            )
            # Track removed lines by index rather than overwriting entries
            # with a sentinel, so the list stays a plain list of strings.
            dropped_lines.update(range(start_line + 1, end_line + 1))
            processed_lines.update(lines)

        return "\n".join(
            line
            for index, line in enumerate(simplified_lines)
            if index not in dropped_lines
        )

    def get_parser(self) -> "Parser":
        """Create a tree-sitter parser configured for this segmenter's language."""
        from tree_sitter import Parser

        parser = Parser()
        parser.set_language(self.get_language())
        return parser

    @abstractmethod
    def get_language(self) -> "Language":
        """Return the tree-sitter `Language` used to parse the source."""
        raise NotImplementedError()  # pragma: no cover

    @abstractmethod
    def get_chunk_query(self) -> str:
        """Return the tree-sitter query whose captures are the chunks."""
        raise NotImplementedError()  # pragma: no cover

    @abstractmethod
    def make_line_comment(self, text: str) -> str:
        """Return `text` formatted as a line comment in the target language."""
        raise NotImplementedError()  # pragma: no cover

View File

@@ -0,0 +1,33 @@
from typing import TYPE_CHECKING
from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter import ( # noqa: E501
TreeSitterSegmenter,
)
if TYPE_CHECKING:
from tree_sitter import Language
# Tree-sitter query whose captures become the chunks for TypeScript sources:
# function, class, interface and enum declarations.
CHUNK_QUERY = """
[
(function_declaration) @function
(class_declaration) @class
(interface_declaration) @interface
(enum_declaration) @enum
]
""".strip()


class TypeScriptSegmenter(TreeSitterSegmenter):
    """Code segmenter for TypeScript."""

    def get_language(self) -> "Language":
        """Return the tree-sitter grammar registered under ``"typescript"``."""
        # Imported inside the method so that merely importing this module
        # does not require `tree_sitter_languages` to be installed.
        import tree_sitter_languages

        return tree_sitter_languages.get_language("typescript")

    def get_chunk_query(self) -> str:
        """Return the query selecting the constructs to split out as chunks."""
        return CHUNK_QUERY

    def make_line_comment(self, text: str) -> str:
        """Render ``text`` as a ``//`` line comment."""
        return f"// {text}"