Mirror of https://github.com/hwchase17/langchain.git (synced 2025-09-15 22:44:36 +00:00)

Bagatur/revert revert nuclia (#8833)
libs/langchain/langchain/document_loaders/nuclia.py (new file)
@@ -0,0 +1,32 @@
"""Extract text from any file type."""
import json
import uuid
from typing import List

from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader
from langchain.tools.nuclia.tool import NucliaUnderstandingAPI


class NucliaLoader(BaseLoader):
    """Extract text from any file type."""

    def __init__(self, path: str, nuclia_tool: NucliaUnderstandingAPI):
        self.nua = nuclia_tool
        self.id = str(uuid.uuid4())
        self.nua.run({"action": "push", "id": self.id, "path": path, "text": None})

    def load(self) -> List[Document]:
        """Load documents."""
        data = self.nua.run(
            {"action": "pull", "id": self.id, "path": None, "text": None}
        )
        if not data:
            return []
        obj = json.loads(data)
        text = obj["extracted_text"][0]["body"]["text"]
        metadata = {
            "file": obj["file_extracted_data"][0],
            "metadata": obj["field_metadata"][0],
        }
        return [Document(page_content=text, metadata=metadata)]
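For context, a minimal usage sketch of the new loader (the file path below is a placeholder and NUCLIA_NUA_KEY must be set in the environment; note that load() returns an empty list until Nuclia has finished processing, so callers may need to retry):

    from langchain.document_loaders.nuclia import NucliaLoader
    from langchain.tools.nuclia.tool import NucliaUnderstandingAPI

    nua = NucliaUnderstandingAPI(enable_ml=False)
    loader = NucliaLoader("data/report.pdf", nua)  # placeholder path; pushes the file on construction
    docs = loader.load()  # [] until processing completes; otherwise a single Document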
libs/langchain/langchain/document_transformers/__init__.py
@@ -27,6 +27,7 @@ from langchain.document_transformers.embeddings_redundant_filter import (
)
from langchain.document_transformers.html2text import Html2TextTransformer
from langchain.document_transformers.long_context_reorder import LongContextReorder
from langchain.document_transformers.nuclia_text_transform import NucliaTextTransformer
from langchain.document_transformers.openai_functions import OpenAIMetadataTagger

__all__ = [
@@ -37,6 +38,7 @@ __all__ = [
    "EmbeddingsRedundantFilter",
    "get_stateful_documents",
    "LongContextReorder",
    "NucliaTextTransformer",
    "OpenAIMetadataTagger",
    "Html2TextTransformer",
]
libs/langchain/langchain/document_transformers/nuclia_text_transform.py (new file)
@@ -0,0 +1,47 @@
import asyncio
import json
import uuid
from typing import Any, Sequence

from langchain.schema.document import BaseDocumentTransformer, Document
from langchain.tools.nuclia.tool import NucliaUnderstandingAPI


class NucliaTextTransformer(BaseDocumentTransformer):
    """
    The Nuclia Understanding API splits text into paragraphs and sentences,
    identifies entities, provides a summary of the text, and generates
    embeddings for all sentences.
    """

    def __init__(self, nua: NucliaUnderstandingAPI):
        self.nua = nua

    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        raise NotImplementedError

    async def atransform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        tasks = [
            self.nua.arun(
                {
                    "action": "push",
                    "id": str(uuid.uuid4()),
                    "text": doc.page_content,
                    "path": None,
                }
            )
            for doc in documents
        ]
        results = await asyncio.gather(*tasks)
        for doc, result in zip(documents, results):
            obj = json.loads(result)
            metadata = {
                "file": obj["file_extracted_data"][0],
                "metadata": obj["field_metadata"][0],
            }
            doc.metadata["nuclia"] = metadata
        return documents
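A minimal async usage sketch for the transformer (the document content is a placeholder; NUCLIA_NUA_KEY must be set, and atransform_documents blocks until every pushed document has been processed):

    import asyncio

    from langchain.document_transformers.nuclia_text_transform import NucliaTextTransformer
    from langchain.schema.document import Document
    from langchain.tools.nuclia.tool import NucliaUnderstandingAPI

    async def main() -> None:
        nua = NucliaUnderstandingAPI(enable_ml=False)
        transformer = NucliaTextTransformer(nua)
        docs = [Document(page_content="Hello, my name is Alice", metadata={})]
        transformed = await transformer.atransform_documents(docs)
        print(transformed[0].metadata["nuclia"])  # {"file": ..., "metadata": ...}

    asyncio.run(main())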
libs/langchain/langchain/tools/nuclia/__init__.py (new file)
@@ -0,0 +1,3 @@
from langchain.tools.nuclia.tool import NucliaUnderstandingAPI

__all__ = ["NucliaUnderstandingAPI"]
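This package-level re-export means the tool can also be imported as `from langchain.tools.nuclia import NucliaUnderstandingAPI`.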
libs/langchain/langchain/tools/nuclia/tool.py (new file)
@@ -0,0 +1,229 @@
"""Tool for the Nuclia Understanding API.

Installation:

```bash
pip install --upgrade protobuf
pip install nucliadb-protos
```
"""

import asyncio
import base64
import logging
import mimetypes
import os
from typing import Any, Dict, Optional, Type, Union

import requests
from pydantic import BaseModel, Field

from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain.tools.base import BaseTool

logger = logging.getLogger(__name__)


class NUASchema(BaseModel):
    action: str = Field(
        ...,
        description="Action to perform. Either `push` or `pull`.",
    )
    id: str = Field(
        ...,
        description="ID of the file to push or pull.",
    )
    path: Optional[str] = Field(
        ...,
        description="Path to the file to push (needed only for `push` action).",
    )
    text: Optional[str] = Field(
        ...,
        description="Text content to process (needed only for `push` action).",
    )


class NucliaUnderstandingAPI(BaseTool):
    """Tool to process files with the Nuclia Understanding API."""

    name = "nuclia_understanding_api"
    description = (
        "A wrapper around Nuclia Understanding API endpoints. "
        "Useful for when you need to extract text from any kind of file. "
    )
    args_schema: Type[BaseModel] = NUASchema
    _results: Dict[str, Any] = {}
    _config: Dict[str, Any] = {}

    def __init__(self, enable_ml: bool = False) -> None:
        zone = os.environ.get("NUCLIA_ZONE", "europe-1")
        self._config["BACKEND"] = f"https://{zone}.nuclia.cloud/api/v1"
        key = os.environ.get("NUCLIA_NUA_KEY")
        if not key:
            raise ValueError("NUCLIA_NUA_KEY environment variable not set")
        else:
            self._config["NUA_KEY"] = key
        self._config["enable_ml"] = enable_ml
        super().__init__()

    def _run(
        self,
        action: str,
        id: str,
        path: Optional[str],
        text: Optional[str],
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool."""
        if action == "push":
            self._check_params(path, text)
            if path:
                return self._pushFile(id, path)
            if text:
                return self._pushText(id, text)
        elif action == "pull":
            return self._pull(id)
        return ""

    async def _arun(
        self,
        action: str,
        id: str,
        path: Optional[str] = None,
        text: Optional[str] = None,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool asynchronously."""
        self._check_params(path, text)
        if path:
            self._pushFile(id, path)
        if text:
            self._pushText(id, text)
        data = None
        while True:
            data = self._pull(id)
            if data:
                break
            await asyncio.sleep(15)
        return data

    def _pushText(self, id: str, text: str) -> str:
        field = {
            "textfield": {"text": {"body": text, "format": 0}},
            "processing_options": {"ml_text": self._config["enable_ml"]},
        }
        return self._pushField(id, field)

    def _pushFile(self, id: str, content_path: str) -> str:
        with open(content_path, "rb") as source_file:
            response = requests.post(
                self._config["BACKEND"] + "/processing/upload",
                headers={
                    "content-type": mimetypes.guess_type(content_path)[0]
                    or "application/octet-stream",
                    "x-stf-nuakey": "Bearer " + self._config["NUA_KEY"],
                },
                data=source_file.read(),
            )
            if response.status_code != 200:
                logger.info(
                    f"Error uploading {content_path}: "
                    f"{response.status_code} {response.text}"
                )
                return ""
            else:
                field = {
                    "filefield": {"file": f"{response.text}"},
                    "processing_options": {"ml_text": self._config["enable_ml"]},
                }
                return self._pushField(id, field)

    def _pushField(self, id: str, field: Any) -> str:
        logger.info(f"Pushing {id} in queue")
        response = requests.post(
            self._config["BACKEND"] + "/processing/push",
            headers={
                "content-type": "application/json",
                "x-stf-nuakey": "Bearer " + self._config["NUA_KEY"],
            },
            json=field,
        )
        if response.status_code != 200:
            logger.info(
                f"Error pushing field {id}: {response.status_code} {response.text}"
            )
            raise ValueError("Error pushing field")
        else:
            uuid = response.json()["uuid"]
            logger.info(f"Field {id} pushed in queue, uuid: {uuid}")
            self._results[id] = {"uuid": uuid, "status": "pending"}
            return uuid

    def _pull(self, id: str) -> str:
        self._pull_queue()
        result = self._results.get(id, None)
        if not result:
            logger.info(f"{id} not in queue")
            return ""
        elif result["status"] == "pending":
            logger.info(f'Waiting for {result["uuid"]} to be processed')
            return ""
        else:
            return result["data"]

    def _pull_queue(self) -> None:
        try:
            from nucliadb_protos.writer_pb2 import BrokerMessage
        except ImportError as e:
            raise ImportError(
                "nucliadb-protos is not installed. "
                "Run `pip install nucliadb-protos` to install."
            ) from e
        try:
            from google.protobuf.json_format import MessageToJson
        except ImportError as e:
            raise ImportError(
                "Unable to import google.protobuf, please install with "
                "`pip install protobuf`."
            ) from e

        res = requests.get(
            self._config["BACKEND"] + "/processing/pull",
            headers={
                "x-stf-nuakey": "Bearer " + self._config["NUA_KEY"],
            },
        ).json()
        if res["status"] == "empty":
            logger.info("Queue empty")
        elif res["status"] == "ok":
            payload = res["payload"]
            pb = BrokerMessage()
            pb.ParseFromString(base64.b64decode(payload))
            uuid = pb.uuid
            logger.info(f"Pulled {uuid} from queue")
            matching_id = self._find_matching_id(uuid)
            if not matching_id:
                logger.info(f"No matching id for {uuid}")
            else:
                self._results[matching_id]["status"] = "done"
                data = MessageToJson(
                    pb,
                    preserving_proto_field_name=True,
                    including_default_value_fields=True,
                )
                self._results[matching_id]["data"] = data

    def _find_matching_id(self, uuid: str) -> Union[str, None]:
        for id, result in self._results.items():
            if result["uuid"] == uuid:
                return id
        return None

    def _check_params(self, path: Optional[str], text: Optional[str]) -> None:
        if not path and not text:
            raise ValueError("File path or text is required")
        if path and text:
            raise ValueError("Cannot process both file and text on a single run")
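A minimal synchronous push/pull sketch for the tool itself (the file path is a placeholder, and the 15-second polling interval is an assumption that mirrors what _arun uses internally):

    import time

    from langchain.tools.nuclia.tool import NucliaUnderstandingAPI

    nua = NucliaUnderstandingAPI(enable_ml=False)  # requires NUCLIA_NUA_KEY in the environment
    nua.run({"action": "push", "id": "1", "path": "data/report.pdf", "text": None})
    data = ""
    while not data:  # pull returns "" while the file is still being processed
        time.sleep(15)
        data = nua.run({"action": "pull", "id": "1", "path": None, "text": None})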
Unit tests for NucliaLoader (new file):
@@ -0,0 +1,45 @@
import json
import os
from typing import Any
from unittest import mock

from langchain.document_loaders.nuclia import NucliaLoader
from langchain.tools.nuclia.tool import NucliaUnderstandingAPI


def fakerun(**args: Any) -> Any:
    def run(self: Any, **args: Any) -> str:
        data = {
            "extracted_text": [{"body": {"text": "Hello World"}}],
            "file_extracted_data": [{"language": "en"}],
            "field_metadata": [
                {
                    "metadata": {
                        "metadata": {
                            "paragraphs": [
                                {"end": 66, "sentences": [{"start": 1, "end": 67}]}
                            ]
                        }
                    }
                }
            ],
        }
        return json.dumps(data)

    return run


@mock.patch.dict(os.environ, {"NUCLIA_NUA_KEY": "_a_key_"})
def test_nuclia_loader() -> None:
    with mock.patch(
        "langchain.tools.nuclia.tool.NucliaUnderstandingAPI._run", new_callable=fakerun
    ):
        nua = NucliaUnderstandingAPI(enable_ml=False)
        loader = NucliaLoader("/whatever/file.mp3", nua)
        docs = loader.load()
        assert len(docs) == 1
        assert docs[0].page_content == "Hello World"
        assert docs[0].metadata["file"]["language"] == "en"
        assert (
            len(docs[0].metadata["metadata"]["metadata"]["metadata"]["paragraphs"]) == 1
        )
Unit tests for NucliaTextTransformer (new file):
@@ -0,0 +1,62 @@
import asyncio
import json
from typing import Any
from unittest import mock

import pytest

from langchain.document_transformers.nuclia_text_transform import NucliaTextTransformer
from langchain.schema.document import Document
from langchain.tools.nuclia.tool import NucliaUnderstandingAPI


def fakerun(**args: Any) -> Any:
    async def run(self: Any, **args: Any) -> str:
        await asyncio.sleep(0.1)
        data = {
            "extracted_text": [{"body": {"text": "Hello World"}}],
            "file_extracted_data": [{"language": "en"}],
            "field_metadata": [
                {
                    "metadata": {
                        "metadata": {
                            "paragraphs": [
                                {"end": 66, "sentences": [{"start": 1, "end": 67}]}
                            ]
                        }
                    }
                }
            ],
        }
        return json.dumps(data)

    return run


@pytest.mark.asyncio
async def test_nuclia_transformer() -> None:
    with mock.patch(
        "langchain.tools.nuclia.tool.NucliaUnderstandingAPI._arun", new_callable=fakerun
    ):
        with mock.patch("os.environ.get", return_value="_a_key_"):
            nua = NucliaUnderstandingAPI(enable_ml=False)
            documents = [
                Document(page_content="Hello, my name is Alice", metadata={}),
                Document(page_content="Hello, my name is Bob", metadata={}),
            ]
            nuclia_transformer = NucliaTextTransformer(nua)
            transformed_documents = await nuclia_transformer.atransform_documents(
                documents
            )
            assert len(transformed_documents) == 2
            assert (
                transformed_documents[0].metadata["nuclia"]["file"]["language"] == "en"
            )
            assert (
                len(
                    transformed_documents[1].metadata["nuclia"]["metadata"]["metadata"][
                        "metadata"
                    ]["paragraphs"]
                )
                == 1
            )
Unit tests for NucliaUnderstandingAPI (new file):
@@ -0,0 +1,110 @@
import base64
import json
import os
from pathlib import Path
from typing import Any
from unittest import mock

import pytest

from langchain.tools.nuclia.tool import NucliaUnderstandingAPI

README_PATH = Path(__file__).parents[4] / "README.md"


class FakeUploadResponse:
    status_code = 200
    text = "fake_uuid"


class FakePushResponse:
    status_code = 200

    def json(self) -> Any:
        return {"uuid": "fake_uuid"}


class FakePullResponse:
    status_code = 200

    def json(self) -> Any:
        return {
            "status": "ok",
            "payload": base64.b64encode(bytes('{"some": "data"}', "utf-8")),
        }


def FakeParseFromString(**args: Any) -> Any:
    def ParseFromString(self: Any, data: str) -> None:
        self.uuid = "fake_uuid"

    return ParseFromString


def fakepost(**kwargs: Any) -> Any:
    def fn(url: str, **kwargs: Any) -> Any:
        if url.endswith("/processing/upload"):
            return FakeUploadResponse()
        elif url.endswith("/processing/push"):
            return FakePushResponse()
        else:
            raise Exception("Invalid POST URL")

    return fn


def fakeget(**kwargs: Any) -> Any:
    def fn(url: str, **kwargs: Any) -> Any:
        if url.endswith("/processing/pull"):
            return FakePullResponse()
        else:
            raise Exception("Invalid GET URL")

    return fn


@mock.patch.dict(os.environ, {"NUCLIA_NUA_KEY": "_a_key_"})
@pytest.mark.requires("nucliadb_protos")
def test_nuclia_tool() -> None:
    with mock.patch(
        "nucliadb_protos.writer_pb2.BrokerMessage.ParseFromString",
        new_callable=FakeParseFromString,
    ):
        with mock.patch("requests.post", new_callable=fakepost):
            with mock.patch("requests.get", new_callable=fakeget):
                nua = NucliaUnderstandingAPI(enable_ml=False)
                uuid = nua.run(
                    {
                        "action": "push",
                        "id": "1",
                        "path": str(README_PATH),
                        "text": None,
                    }
                )
                assert uuid == "fake_uuid"
                data = nua.run(
                    {"action": "pull", "id": "1", "path": None, "text": None}
                )
                assert json.loads(data)["uuid"] == "fake_uuid"


@pytest.mark.asyncio
@pytest.mark.requires("nucliadb_protos")
async def test_async_call() -> None:
    with mock.patch(
        "nucliadb_protos.writer_pb2.BrokerMessage.ParseFromString",
        new_callable=FakeParseFromString,
    ):
        with mock.patch("requests.post", new_callable=fakepost):
            with mock.patch("requests.get", new_callable=fakeget):
                with mock.patch("os.environ.get", return_value="_a_key_"):
                    nua = NucliaUnderstandingAPI(enable_ml=False)
                    data = await nua.arun(
                        {
                            "action": "push",
                            "id": "1",
                            "path": str(README_PATH),
                            "text": None,
                        }
                    )
                    assert json.loads(data)["uuid"] == "fake_uuid"