mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-13 22:59:05 +00:00
Add COBOL parser and splitter (#11674)
- **Description:** Add COBOL parser and splitter - **Issue:** n/a - **Dependencies:** n/a - **Tag maintainer:** @baskaryan - **Twitter handle:** erhartford --------- Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
bb137fd6e7
commit
8c150ad7f6
@ -0,0 +1,96 @@
|
||||
import re
|
||||
from typing import Callable, List
|
||||
|
||||
from langchain.document_loaders.parsers.language.code_segmenter import CodeSegmenter
|
||||
|
||||
|
||||
class CobolSegmenter(CodeSegmenter):
|
||||
"""Code segmenter for `COBOL`."""
|
||||
|
||||
PARAGRAPH_PATTERN = re.compile(r"^[A-Z0-9\-]+(\s+.*)?\.$", re.IGNORECASE)
|
||||
DIVISION_PATTERN = re.compile(
|
||||
r"^\s*(IDENTIFICATION|DATA|PROCEDURE|ENVIRONMENT)\s+DIVISION.*$", re.IGNORECASE
|
||||
)
|
||||
SECTION_PATTERN = re.compile(r"^\s*[A-Z0-9\-]+\s+SECTION.$", re.IGNORECASE)
|
||||
|
||||
def __init__(self, code: str):
|
||||
super().__init__(code)
|
||||
self.source_lines: List[str] = self.code.splitlines()
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
# Identify presence of any division to validate COBOL code
|
||||
return any(self.DIVISION_PATTERN.match(line) for line in self.source_lines)
|
||||
|
||||
def _extract_code(self, start_idx: int, end_idx: int) -> str:
|
||||
return "\n".join(self.source_lines[start_idx:end_idx]).rstrip("\n")
|
||||
|
||||
def _is_relevant_code(self, line: str) -> bool:
|
||||
"""Check if a line is part of the procedure division or a relevant section."""
|
||||
if "PROCEDURE DIVISION" in line.upper():
|
||||
return True
|
||||
# Add additional conditions for relevant sections if needed
|
||||
return False
|
||||
|
||||
def _process_lines(self, func: Callable) -> List[str]:
|
||||
"""A generic function to process COBOL lines based on provided func."""
|
||||
elements: List[str] = []
|
||||
start_idx = None
|
||||
inside_relevant_section = False
|
||||
|
||||
for i, line in enumerate(self.source_lines):
|
||||
if self._is_relevant_code(line):
|
||||
inside_relevant_section = True
|
||||
|
||||
if inside_relevant_section and (
|
||||
self.PARAGRAPH_PATTERN.match(line.strip().split(" ")[0])
|
||||
or self.SECTION_PATTERN.match(line.strip())
|
||||
):
|
||||
if start_idx is not None:
|
||||
func(elements, start_idx, i)
|
||||
start_idx = i
|
||||
|
||||
# Handle the last element if exists
|
||||
if start_idx is not None:
|
||||
func(elements, start_idx, len(self.source_lines))
|
||||
|
||||
return elements
|
||||
|
||||
def extract_functions_classes(self) -> List[str]:
|
||||
def extract_func(elements: List[str], start_idx: int, end_idx: int) -> None:
|
||||
elements.append(self._extract_code(start_idx, end_idx))
|
||||
|
||||
return self._process_lines(extract_func)
|
||||
|
||||
def simplify_code(self) -> str:
|
||||
simplified_lines: List[str] = []
|
||||
inside_relevant_section = False
|
||||
omitted_code_added = (
|
||||
False # To track if "* OMITTED CODE *" has been added after the last header
|
||||
)
|
||||
|
||||
for line in self.source_lines:
|
||||
is_header = (
|
||||
"PROCEDURE DIVISION" in line
|
||||
or "DATA DIVISION" in line
|
||||
or "IDENTIFICATION DIVISION" in line
|
||||
or self.PARAGRAPH_PATTERN.match(line.strip().split(" ")[0])
|
||||
or self.SECTION_PATTERN.match(line.strip())
|
||||
)
|
||||
|
||||
if is_header:
|
||||
inside_relevant_section = True
|
||||
# Reset the flag since we're entering a new section/division or
|
||||
# paragraph
|
||||
omitted_code_added = False
|
||||
|
||||
if inside_relevant_section:
|
||||
if is_header:
|
||||
# Add header and reset the omitted code added flag
|
||||
simplified_lines.append(line)
|
||||
elif not omitted_code_added:
|
||||
# Add omitted code comment only if it hasn't been added directly
|
||||
# after the last header
|
||||
simplified_lines.append("* OMITTED CODE *")
|
||||
omitted_code_added = True
|
||||
|
||||
return "\n".join(simplified_lines)
|
@ -3,6 +3,7 @@ from typing import Any, Dict, Iterator, Optional
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.base import BaseBlobParser
|
||||
from langchain.document_loaders.blob_loaders import Blob
|
||||
from langchain.document_loaders.parsers.language.cobol import CobolSegmenter
|
||||
from langchain.document_loaders.parsers.language.javascript import JavaScriptSegmenter
|
||||
from langchain.document_loaders.parsers.language.python import PythonSegmenter
|
||||
from langchain.text_splitter import Language
|
||||
@ -10,11 +11,13 @@ from langchain.text_splitter import Language
|
||||
LANGUAGE_EXTENSIONS: Dict[str, str] = {
|
||||
"py": Language.PYTHON,
|
||||
"js": Language.JS,
|
||||
"cobol": Language.COBOL,
|
||||
}
|
||||
|
||||
LANGUAGE_SEGMENTERS: Dict[str, Any] = {
|
||||
Language.PYTHON: PythonSegmenter,
|
||||
Language.JS: JavaScriptSegmenter,
|
||||
Language.COBOL: CobolSegmenter,
|
||||
}
|
||||
|
||||
|
||||
|
@ -811,6 +811,7 @@ class Language(str, Enum):
|
||||
HTML = "html"
|
||||
SOL = "sol"
|
||||
CSHARP = "csharp"
|
||||
COBOL = "cobol"
|
||||
|
||||
|
||||
class RecursiveCharacterTextSplitter(TextSplitter):
|
||||
@ -1305,6 +1306,38 @@ class RecursiveCharacterTextSplitter(TextSplitter):
|
||||
" ",
|
||||
"",
|
||||
]
|
||||
elif language == Language.COBOL:
|
||||
return [
|
||||
# Split along divisions
|
||||
"\nIDENTIFICATION DIVISION.",
|
||||
"\nENVIRONMENT DIVISION.",
|
||||
"\nDATA DIVISION.",
|
||||
"\nPROCEDURE DIVISION.",
|
||||
# Split along sections within DATA DIVISION
|
||||
"\nWORKING-STORAGE SECTION.",
|
||||
"\nLINKAGE SECTION.",
|
||||
"\nFILE SECTION.",
|
||||
# Split along sections within PROCEDURE DIVISION
|
||||
"\nINPUT-OUTPUT SECTION.",
|
||||
# Split along paragraphs and common statements
|
||||
"\nOPEN ",
|
||||
"\nCLOSE ",
|
||||
"\nREAD ",
|
||||
"\nWRITE ",
|
||||
"\nIF ",
|
||||
"\nELSE ",
|
||||
"\nMOVE ",
|
||||
"\nPERFORM ",
|
||||
"\nUNTIL ",
|
||||
"\nVARYING ",
|
||||
"\nACCEPT ",
|
||||
"\nDISPLAY ",
|
||||
"\nSTOP RUN.",
|
||||
# Split by the normal type of lines
|
||||
"\n",
|
||||
" ",
|
||||
"",
|
||||
]
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
|
@ -0,0 +1,49 @@
|
||||
from langchain.document_loaders.parsers.language.cobol import CobolSegmenter
|
||||
|
||||
EXAMPLE_CODE = """
|
||||
IDENTIFICATION DIVISION.
|
||||
PROGRAM-ID. SampleProgram.
|
||||
DATA DIVISION.
|
||||
WORKING-STORAGE SECTION.
|
||||
01 SAMPLE-VAR PIC X(20) VALUE 'Sample Value'.
|
||||
|
||||
PROCEDURE DIVISION.
|
||||
A000-INITIALIZE-PARA.
|
||||
DISPLAY 'Initialization Paragraph'.
|
||||
MOVE 'New Value' TO SAMPLE-VAR.
|
||||
|
||||
A100-PROCESS-PARA.
|
||||
DISPLAY SAMPLE-VAR.
|
||||
STOP RUN.
|
||||
"""
|
||||
|
||||
|
||||
def test_extract_functions_classes() -> None:
|
||||
"""Test that functions and classes are extracted correctly."""
|
||||
segmenter = CobolSegmenter(EXAMPLE_CODE)
|
||||
extracted_code = segmenter.extract_functions_classes()
|
||||
assert extracted_code == [
|
||||
"A000-INITIALIZE-PARA.\n "
|
||||
"DISPLAY 'Initialization Paragraph'.\n "
|
||||
"MOVE 'New Value' TO SAMPLE-VAR.",
|
||||
"A100-PROCESS-PARA.\n DISPLAY SAMPLE-VAR.\n STOP RUN.",
|
||||
]
|
||||
|
||||
|
||||
def test_simplify_code() -> None:
|
||||
"""Test that code is simplified correctly."""
|
||||
expected_simplified_code = (
|
||||
"IDENTIFICATION DIVISION.\n"
|
||||
"PROGRAM-ID. SampleProgram.\n"
|
||||
"DATA DIVISION.\n"
|
||||
"WORKING-STORAGE SECTION.\n"
|
||||
"* OMITTED CODE *\n"
|
||||
"PROCEDURE DIVISION.\n"
|
||||
"A000-INITIALIZE-PARA.\n"
|
||||
"* OMITTED CODE *\n"
|
||||
"A100-PROCESS-PARA.\n"
|
||||
"* OMITTED CODE *\n"
|
||||
)
|
||||
segmenter = CobolSegmenter(EXAMPLE_CODE)
|
||||
simplified_code = segmenter.simplify_code()
|
||||
assert simplified_code.strip() == expected_simplified_code.strip()
|
@ -472,6 +472,41 @@ helloWorld();
|
||||
]
|
||||
|
||||
|
||||
def test_cobol_code_splitter() -> None:
|
||||
splitter = RecursiveCharacterTextSplitter.from_language(
|
||||
Language.COBOL, chunk_size=CHUNK_SIZE, chunk_overlap=0
|
||||
)
|
||||
code = """
|
||||
IDENTIFICATION DIVISION.
|
||||
PROGRAM-ID. HelloWorld.
|
||||
DATA DIVISION.
|
||||
WORKING-STORAGE SECTION.
|
||||
01 GREETING PIC X(12) VALUE 'Hello, World!'.
|
||||
PROCEDURE DIVISION.
|
||||
DISPLAY GREETING.
|
||||
STOP RUN.
|
||||
"""
|
||||
chunks = splitter.split_text(code)
|
||||
assert chunks == [
|
||||
"IDENTIFICATION",
|
||||
"DIVISION.",
|
||||
"PROGRAM-ID.",
|
||||
"HelloWorld.",
|
||||
"DATA DIVISION.",
|
||||
"WORKING-STORAGE",
|
||||
"SECTION.",
|
||||
"01 GREETING",
|
||||
"PIC X(12)",
|
||||
"VALUE 'Hello,",
|
||||
"World!'.",
|
||||
"PROCEDURE",
|
||||
"DIVISION.",
|
||||
"DISPLAY",
|
||||
"GREETING.",
|
||||
"STOP RUN.",
|
||||
]
|
||||
|
||||
|
||||
def test_typescript_code_splitter() -> None:
|
||||
splitter = RecursiveCharacterTextSplitter.from_language(
|
||||
Language.TS, chunk_size=CHUNK_SIZE, chunk_overlap=0
|
||||
|
Loading…
Reference in New Issue
Block a user