mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-13 13:36:15 +00:00
feat (documents): add a source code loader based on AST manipulation (#6486)
#### Summary A new approach to loading source code is implemented: Each top-level function and class in the code is loaded into separate documents. Then, an additional document is created with the top-level code, but without the already loaded functions and classes. This could improve the accuracy of QA chains over source code. For instance, having this script: ``` class MyClass: def __init__(self, name): self.name = name def greet(self): print(f"Hello, {self.name}!") def main(): name = input("Enter your name: ") obj = MyClass(name) obj.greet() if __name__ == '__main__': main() ``` The loader will create three documents with this content: First document: ``` class MyClass: def __init__(self, name): self.name = name def greet(self): print(f"Hello, {self.name}!") ``` Second document: ``` def main(): name = input("Enter your name: ") obj = MyClass(name) obj.greet() ``` Third document: ``` # Code for: class MyClass: # Code for: def main(): if __name__ == '__main__': main() ``` A threshold parameter is added to control whether small scripts are split in this way or not. At this moment, only Python and JavaScript are supported. The appropriate parser is determined by examining the file extension. #### Tests This PR adds: - Unit tests - Integration tests #### Dependencies Only one dependency was added as optional (needed for the JavaScript parser). #### Documentation A notebook is added showing how the loader can be used. #### Who can review? @eyurtsev @hwchase17 --------- Co-authored-by: rlm <pexpresss31@gmail.com>
This commit is contained in:
committed by
GitHub
parent
da462d9dd4
commit
e494b0a09f
@@ -0,0 +1,133 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.document_loaders.generic import GenericLoader
|
||||
from langchain.document_loaders.parsers import LanguageParser
|
||||
from langchain.text_splitter import Language
|
||||
|
||||
|
||||
def test_language_loader_for_python() -> None:
|
||||
"""Test Python loader with parser enabled."""
|
||||
file_path = Path(__file__).parent.parent.parent / "examples"
|
||||
loader = GenericLoader.from_filesystem(
|
||||
file_path, glob="hello_world.py", parser=LanguageParser(parser_threshold=5)
|
||||
)
|
||||
docs = loader.load()
|
||||
|
||||
assert len(docs) == 2
|
||||
|
||||
metadata = docs[0].metadata
|
||||
assert metadata["source"] == str(file_path / "hello_world.py")
|
||||
assert metadata["content_type"] == "functions_classes"
|
||||
assert metadata["language"] == "python"
|
||||
metadata = docs[1].metadata
|
||||
assert metadata["source"] == str(file_path / "hello_world.py")
|
||||
assert metadata["content_type"] == "simplified_code"
|
||||
assert metadata["language"] == "python"
|
||||
|
||||
assert (
|
||||
docs[0].page_content
|
||||
== """def main():
|
||||
print("Hello World!")
|
||||
|
||||
return 0"""
|
||||
)
|
||||
assert (
|
||||
docs[1].page_content
|
||||
== """#!/usr/bin/env python3
|
||||
|
||||
import sys
|
||||
|
||||
|
||||
# Code for: def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())"""
|
||||
)
|
||||
|
||||
|
||||
def test_language_loader_for_python_with_parser_threshold() -> None:
|
||||
"""Test Python loader with parser enabled and below threshold."""
|
||||
file_path = Path(__file__).parent.parent.parent / "examples"
|
||||
loader = GenericLoader.from_filesystem(
|
||||
file_path,
|
||||
glob="hello_world.py",
|
||||
parser=LanguageParser(language=Language.PYTHON, parser_threshold=1000),
|
||||
)
|
||||
docs = loader.load()
|
||||
|
||||
assert len(docs) == 1
|
||||
|
||||
|
||||
def esprima_installed() -> bool:
|
||||
try:
|
||||
import esprima # noqa: F401
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"esprima not installed, skipping test {e}")
|
||||
return False
|
||||
|
||||
|
||||
@pytest.mark.skipif(not esprima_installed(), reason="requires esprima package")
|
||||
def test_language_loader_for_javascript() -> None:
|
||||
"""Test JavaScript loader with parser enabled."""
|
||||
file_path = Path(__file__).parent.parent.parent / "examples"
|
||||
loader = GenericLoader.from_filesystem(
|
||||
file_path, glob="hello_world.js", parser=LanguageParser(parser_threshold=5)
|
||||
)
|
||||
docs = loader.load()
|
||||
|
||||
assert len(docs) == 3
|
||||
|
||||
metadata = docs[0].metadata
|
||||
assert metadata["source"] == str(file_path / "hello_world.js")
|
||||
assert metadata["content_type"] == "functions_classes"
|
||||
assert metadata["language"] == "js"
|
||||
metadata = docs[1].metadata
|
||||
assert metadata["source"] == str(file_path / "hello_world.js")
|
||||
assert metadata["content_type"] == "functions_classes"
|
||||
assert metadata["language"] == "js"
|
||||
metadata = docs[2].metadata
|
||||
assert metadata["source"] == str(file_path / "hello_world.js")
|
||||
assert metadata["content_type"] == "simplified_code"
|
||||
assert metadata["language"] == "js"
|
||||
|
||||
assert (
|
||||
docs[0].page_content
|
||||
== """class HelloWorld {
|
||||
sayHello() {
|
||||
console.log("Hello World!");
|
||||
}
|
||||
}"""
|
||||
)
|
||||
assert (
|
||||
docs[1].page_content
|
||||
== """function main() {
|
||||
const hello = new HelloWorld();
|
||||
hello.sayHello();
|
||||
}"""
|
||||
)
|
||||
assert (
|
||||
docs[2].page_content
|
||||
== """// Code for: class HelloWorld {
|
||||
|
||||
// Code for: function main() {
|
||||
|
||||
main();"""
|
||||
)
|
||||
|
||||
|
||||
def test_language_loader_for_javascript_with_parser_threshold() -> None:
|
||||
"""Test JavaScript loader with parser enabled and below threshold."""
|
||||
file_path = Path(__file__).parent.parent.parent / "examples"
|
||||
loader = GenericLoader.from_filesystem(
|
||||
file_path,
|
||||
glob="hello_world.js",
|
||||
parser=LanguageParser(language=Language.JS, parser_threshold=1000),
|
||||
)
|
||||
docs = loader.load()
|
||||
|
||||
assert len(docs) == 1
|
12
tests/integration_tests/examples/hello_world.js
Normal file
12
tests/integration_tests/examples/hello_world.js
Normal file
@@ -0,0 +1,12 @@
|
||||
class HelloWorld {
|
||||
sayHello() {
|
||||
console.log("Hello World!");
|
||||
}
|
||||
}
|
||||
|
||||
function main() {
|
||||
const hello = new HelloWorld();
|
||||
hello.sayHello();
|
||||
}
|
||||
|
||||
main();
|
13
tests/integration_tests/examples/hello_world.py
Normal file
13
tests/integration_tests/examples/hello_world.py
Normal file
@@ -0,0 +1,13 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import sys
|
||||
|
||||
|
||||
def main():
|
||||
print("Hello World!")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
@@ -0,0 +1,46 @@
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.document_loaders.parsers.language.javascript import JavaScriptSegmenter
|
||||
|
||||
|
||||
@pytest.mark.requires("esprima")
|
||||
class TestJavaScriptSegmenter(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.example_code = """const os = require('os');
|
||||
|
||||
function hello(text) {
|
||||
console.log(text);
|
||||
}
|
||||
|
||||
class Simple {
|
||||
constructor() {
|
||||
this.a = 1;
|
||||
}
|
||||
}
|
||||
|
||||
hello("Hello!");"""
|
||||
|
||||
self.expected_simplified_code = """const os = require('os');
|
||||
|
||||
// Code for: function hello(text) {
|
||||
|
||||
// Code for: class Simple {
|
||||
|
||||
hello("Hello!");"""
|
||||
|
||||
self.expected_extracted_code = [
|
||||
"function hello(text) {\n console.log(text);\n}",
|
||||
"class Simple {\n constructor() {\n this.a = 1;\n }\n}",
|
||||
]
|
||||
|
||||
def test_extract_functions_classes(self) -> None:
|
||||
segmenter = JavaScriptSegmenter(self.example_code)
|
||||
extracted_code = segmenter.extract_functions_classes()
|
||||
self.assertEqual(extracted_code, self.expected_extracted_code)
|
||||
|
||||
def test_simplify_code(self) -> None:
|
||||
segmenter = JavaScriptSegmenter(self.example_code)
|
||||
simplified_code = segmenter.simplify_code()
|
||||
self.assertEqual(simplified_code, self.expected_simplified_code)
|
@@ -0,0 +1,40 @@
|
||||
import unittest
|
||||
|
||||
from langchain.document_loaders.parsers.language.python import PythonSegmenter
|
||||
|
||||
|
||||
class TestPythonSegmenter(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.example_code = """import os
|
||||
|
||||
def hello(text):
|
||||
print(text)
|
||||
|
||||
class Simple:
|
||||
def __init__(self):
|
||||
self.a = 1
|
||||
|
||||
hello("Hello!")"""
|
||||
|
||||
self.expected_simplified_code = """import os
|
||||
|
||||
# Code for: def hello(text):
|
||||
|
||||
# Code for: class Simple:
|
||||
|
||||
hello("Hello!")"""
|
||||
|
||||
self.expected_extracted_code = [
|
||||
"def hello(text):\n" " print(text)",
|
||||
"class Simple:\n" " def __init__(self):\n" " self.a = 1",
|
||||
]
|
||||
|
||||
def test_extract_functions_classes(self) -> None:
|
||||
segmenter = PythonSegmenter(self.example_code)
|
||||
extracted_code = segmenter.extract_functions_classes()
|
||||
self.assertEqual(extracted_code, self.expected_extracted_code)
|
||||
|
||||
def test_simplify_code(self) -> None:
|
||||
segmenter = PythonSegmenter(self.example_code)
|
||||
simplified_code = segmenter.simplify_code()
|
||||
self.assertEqual(simplified_code, self.expected_simplified_code)
|
@@ -5,6 +5,7 @@ def test_parsers_public_api_correct() -> None:
|
||||
"""Test public API of parsers for breaking changes."""
|
||||
assert set(__all__) == {
|
||||
"BS4HTMLParser",
|
||||
"LanguageParser",
|
||||
"OpenAIWhisperParser",
|
||||
"PyPDFParser",
|
||||
"PDFMinerParser",
|
||||
|
Reference in New Issue
Block a user