Bagatur/revert revert nuclia (#8833)

This commit is contained in:
Bagatur
2023-08-06 11:24:36 -07:00
committed by GitHub
parent 2f309a4ce6
commit d7b613a293
12 changed files with 951 additions and 0 deletions

View File

@@ -0,0 +1,33 @@
"""Extract text from any file type."""
import json
import uuid
from typing import List
from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader
from langchain.tools.nuclia.tool import NucliaUnderstandingAPI
class NucliaLoader(BaseLoader):
"""Extract text from any file type."""
def __init__(self, path: str, nuclia_tool: NucliaUnderstandingAPI):
self.nua = nuclia_tool
self.id = str(uuid.uuid4())
self.nua.run({"action": "push", "id": self.id, "path": path, "text": None})
def load(self) -> List[Document]:
"""Load documents."""
data = self.nua.run(
{"action": "pull", "id": self.id, "path": None, "text": None}
)
if not data:
return []
obj = json.loads(data)
text = obj["extracted_text"][0]["body"]["text"]
print(text)
metadata = {
"file": obj["file_extracted_data"][0],
"metadata": obj["field_metadata"][0],
}
return [Document(page_content=text, metadata=metadata)]

View File

@@ -27,6 +27,7 @@ from langchain.document_transformers.embeddings_redundant_filter import (
)
from langchain.document_transformers.html2text import Html2TextTransformer
from langchain.document_transformers.long_context_reorder import LongContextReorder
from langchain.document_transformers.nuclia_text_transform import NucliaTextTransformer
from langchain.document_transformers.openai_functions import OpenAIMetadataTagger
__all__ = [
@@ -37,6 +38,7 @@ __all__ = [
"EmbeddingsRedundantFilter",
"get_stateful_documents",
"LongContextReorder",
"NucliaTextTransformer",
"OpenAIMetadataTagger",
"Html2TextTransformer",
]

View File

@@ -0,0 +1,47 @@
import asyncio
import json
import uuid
from typing import Any, Sequence
from langchain.schema.document import BaseDocumentTransformer, Document
from langchain.tools.nuclia.tool import NucliaUnderstandingAPI
class NucliaTextTransformer(BaseDocumentTransformer):
"""
The Nuclia Understanding API splits into paragraphs and sentences,
identifies entities, provides a summary of the text and generates
embeddings for all the sentences.
"""
def __init__(self, nua: NucliaUnderstandingAPI):
self.nua = nua
def transform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
raise NotImplementedError
async def atransform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
tasks = [
self.nua.arun(
{
"action": "push",
"id": str(uuid.uuid4()),
"text": doc.page_content,
"path": None,
}
)
for doc in documents
]
results = await asyncio.gather(*tasks)
for doc, result in zip(documents, results):
obj = json.loads(result)
metadata = {
"file": obj["file_extracted_data"][0],
"metadata": obj["field_metadata"][0],
}
doc.metadata["nuclia"] = metadata
return documents

View File

@@ -0,0 +1,3 @@
from langchain.tools.nuclia.tool import NucliaUnderstandingAPI
__all__ = ["NucliaUnderstandingAPI"]

View File

@@ -0,0 +1,229 @@
"""Tool for the Nuclia Understanding API.
Installation:
```bash
pip install --upgrade protobuf
pip install nucliadb-protos
```
"""
import asyncio
import base64
import logging
import mimetypes
import os
from typing import Any, Dict, Optional, Type, Union
import requests
from pydantic import BaseModel, Field
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain.tools.base import BaseTool
logger = logging.getLogger(__name__)
class NUASchema(BaseModel):
action: str = Field(
...,
description="Action to perform. Either `push` or `pull`.",
)
id: str = Field(
...,
description="ID of the file to push or pull.",
)
path: Optional[str] = Field(
...,
description="Path to the file to push (needed only for `push` action).",
)
text: Optional[str] = Field(
...,
description="Text content to process (needed only for `push` action).",
)
class NucliaUnderstandingAPI(BaseTool):
"""Tool to process files with the Nuclia Understanding API."""
name = "nuclia_understanding_api"
description = (
"A wrapper around Nuclia Understanding API endpoints. "
"Useful for when you need to extract text from any kind of files. "
)
args_schema: Type[BaseModel] = NUASchema
_results: Dict[str, Any] = {}
_config: Dict[str, Any] = {}
def __init__(self, enable_ml: bool = False) -> None:
zone = os.environ.get("NUCLIA_ZONE", "europe-1")
self._config["BACKEND"] = f"https://{zone}.nuclia.cloud/api/v1"
key = os.environ.get("NUCLIA_NUA_KEY")
if not key:
raise ValueError("NUCLIA_NUA_KEY environment variable not set")
else:
self._config["NUA_KEY"] = key
self._config["enable_ml"] = enable_ml
super().__init__()
def _run(
self,
action: str,
id: str,
path: Optional[str],
text: Optional[str],
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Use the tool."""
if action == "push":
self._check_params(path, text)
if path:
return self._pushFile(id, path)
if text:
return self._pushText(id, text)
elif action == "pull":
return self._pull(id)
return ""
async def _arun(
self,
action: str,
id: str,
path: Optional[str] = None,
text: Optional[str] = None,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
"""Use the tool asynchronously."""
self._check_params(path, text)
if path:
self._pushFile(id, path)
if text:
self._pushText(id, text)
data = None
while True:
data = self._pull(id)
if data:
break
await asyncio.sleep(15)
return data
def _pushText(self, id: str, text: str) -> str:
field = {
"textfield": {"text": {"body": text, "format": 0}},
"processing_options": {"ml_text": self._config["enable_ml"]},
}
return self._pushField(id, field)
def _pushFile(self, id: str, content_path: str) -> str:
with open(content_path, "rb") as source_file:
response = requests.post(
self._config["BACKEND"] + "/processing/upload",
headers={
"content-type": mimetypes.guess_type(content_path)[0]
or "application/octet-stream",
"x-stf-nuakey": "Bearer " + self._config["NUA_KEY"],
},
data=source_file.read(),
)
if response.status_code != 200:
logger.info(
f"Error uploading {content_path}: "
f"{response.status_code} {response.text}"
)
return ""
else:
field = {
"filefield": {"file": f"{response.text}"},
"processing_options": {"ml_text": self._config["enable_ml"]},
}
return self._pushField(id, field)
def _pushField(self, id: str, field: Any) -> str:
logger.info(f"Pushing {id} in queue")
response = requests.post(
self._config["BACKEND"] + "/processing/push",
headers={
"content-type": "application/json",
"x-stf-nuakey": "Bearer " + self._config["NUA_KEY"],
},
json=field,
)
if response.status_code != 200:
logger.info(
f"Error pushing field {id}:" f"{response.status_code} {response.text}"
)
raise ValueError("Error pushing field")
else:
uuid = response.json()["uuid"]
logger.info(f"Field {id} pushed in queue, uuid: {uuid}")
self._results[id] = {"uuid": uuid, "status": "pending"}
return uuid
def _pull(self, id: str) -> str:
self._pull_queue()
result = self._results.get(id, None)
if not result:
logger.info(f"{id} not in queue")
return ""
elif result["status"] == "pending":
logger.info(f'Waiting for {result["uuid"]} to be processed')
return ""
else:
return result["data"]
def _pull_queue(self) -> None:
try:
from nucliadb_protos.writer_pb2 import BrokerMessage
except ImportError as e:
raise ImportError(
"nucliadb-protos is not installed. "
"Run `pip install nucliadb-protos` to install."
) from e
try:
from google.protobuf.json_format import MessageToJson
except ImportError as e:
raise ImportError(
"Unable to import google.protobuf, please install with "
"`pip install protobuf`."
) from e
res = requests.get(
self._config["BACKEND"] + "/processing/pull",
headers={
"x-stf-nuakey": "Bearer " + self._config["NUA_KEY"],
},
).json()
if res["status"] == "empty":
logger.info("Queue empty")
elif res["status"] == "ok":
payload = res["payload"]
pb = BrokerMessage()
pb.ParseFromString(base64.b64decode(payload))
uuid = pb.uuid
logger.info(f"Pulled {uuid} from queue")
matching_id = self._find_matching_id(uuid)
if not matching_id:
logger.info(f"No matching id for {uuid}")
else:
self._results[matching_id]["status"] = "done"
data = MessageToJson(
pb,
preserving_proto_field_name=True,
including_default_value_fields=True,
)
self._results[matching_id]["data"] = data
def _find_matching_id(self, uuid: str) -> Union[str, None]:
for id, result in self._results.items():
if result["uuid"] == uuid:
return id
return None
def _check_params(self, path: Optional[str], text: Optional[str]) -> None:
if not path and not text:
raise ValueError("File path or text is required")
if path and text:
raise ValueError("Cannot process both file and text on a single run")

View File

@@ -0,0 +1,45 @@
import json
import os
from typing import Any
from unittest import mock
from langchain.document_loaders.nuclia import NucliaLoader
from langchain.tools.nuclia.tool import NucliaUnderstandingAPI
def fakerun(**args: Any) -> Any:
def run(self: Any, **args: Any) -> str:
data = {
"extracted_text": [{"body": {"text": "Hello World"}}],
"file_extracted_data": [{"language": "en"}],
"field_metadata": [
{
"metadata": {
"metadata": {
"paragraphs": [
{"end": 66, "sentences": [{"start": 1, "end": 67}]}
]
}
}
}
],
}
return json.dumps(data)
return run
@mock.patch.dict(os.environ, {"NUCLIA_NUA_KEY": "_a_key_"})
def test_nuclia_loader() -> None:
with mock.patch(
"langchain.tools.nuclia.tool.NucliaUnderstandingAPI._run", new_callable=fakerun
):
nua = NucliaUnderstandingAPI(enable_ml=False)
loader = NucliaLoader("/whatever/file.mp3", nua)
docs = loader.load()
assert len(docs) == 1
assert docs[0].page_content == "Hello World"
assert docs[0].metadata["file"]["language"] == "en"
assert (
len(docs[0].metadata["metadata"]["metadata"]["metadata"]["paragraphs"]) == 1
)

View File

@@ -0,0 +1,62 @@
import asyncio
import json
from typing import Any
from unittest import mock
import pytest
from langchain.document_transformers.nuclia_text_transform import NucliaTextTransformer
from langchain.schema.document import Document
from langchain.tools.nuclia.tool import NucliaUnderstandingAPI
def fakerun(**args: Any) -> Any:
async def run(self: Any, **args: Any) -> str:
await asyncio.sleep(0.1)
data = {
"extracted_text": [{"body": {"text": "Hello World"}}],
"file_extracted_data": [{"language": "en"}],
"field_metadata": [
{
"metadata": {
"metadata": {
"paragraphs": [
{"end": 66, "sentences": [{"start": 1, "end": 67}]}
]
}
}
}
],
}
return json.dumps(data)
return run
@pytest.mark.asyncio
async def test_nuclia_loader() -> None:
with mock.patch(
"langchain.tools.nuclia.tool.NucliaUnderstandingAPI._arun", new_callable=fakerun
):
with mock.patch("os.environ.get", return_value="_a_key_"):
nua = NucliaUnderstandingAPI(enable_ml=False)
documents = [
Document(page_content="Hello, my name is Alice", metadata={}),
Document(page_content="Hello, my name is Bob", metadata={}),
]
nuclia_transformer = NucliaTextTransformer(nua)
transformed_documents = await nuclia_transformer.atransform_documents(
documents
)
assert len(transformed_documents) == 2
assert (
transformed_documents[0].metadata["nuclia"]["file"]["language"] == "en"
)
assert (
len(
transformed_documents[1].metadata["nuclia"]["metadata"]["metadata"][
"metadata"
]["paragraphs"]
)
== 1
)

View File

@@ -0,0 +1,110 @@
import base64
import json
import os
from pathlib import Path
from typing import Any
from unittest import mock
import pytest
from langchain.tools.nuclia.tool import NucliaUnderstandingAPI
README_PATH = Path(__file__).parents[4] / "README.md"
class FakeUploadResponse:
status_code = 200
text = "fake_uuid"
class FakePushResponse:
status_code = 200
def json(self) -> Any:
return {"uuid": "fake_uuid"}
class FakePullResponse:
status_code = 200
def json(self) -> Any:
return {
"status": "ok",
"payload": base64.b64encode(bytes('{"some": "data"}}', "utf-8")),
}
def FakeParseFromString(**args: Any) -> Any:
def ParseFromString(self: Any, data: str) -> None:
self.uuid = "fake_uuid"
return ParseFromString
def fakepost(**kwargs: Any) -> Any:
def fn(url: str, **kwargs: Any) -> Any:
if url.endswith("/processing/upload"):
return FakeUploadResponse()
elif url.endswith("/processing/push"):
return FakePushResponse()
else:
raise Exception("Invalid POST URL")
return fn
def fakeget(**kwargs: Any) -> Any:
def fn(url: str, **kwargs: Any) -> Any:
if url.endswith("/processing/pull"):
return FakePullResponse()
else:
raise Exception("Invalid GET URL")
return fn
@mock.patch.dict(os.environ, {"NUCLIA_NUA_KEY": "_a_key_"})
@pytest.mark.requires("nucliadb_protos")
def test_nuclia_tool() -> None:
with mock.patch(
"nucliadb_protos.writer_pb2.BrokerMessage.ParseFromString",
new_callable=FakeParseFromString,
):
with mock.patch("requests.post", new_callable=fakepost):
with mock.patch("requests.get", new_callable=fakeget):
nua = NucliaUnderstandingAPI(enable_ml=False)
uuid = nua.run(
{
"action": "push",
"id": "1",
"path": str(README_PATH),
"text": None,
}
)
assert uuid == "fake_uuid"
data = nua.run(
{"action": "pull", "id": "1", "path": None, "text": None}
)
assert json.loads(data)["uuid"] == "fake_uuid"
@pytest.mark.asyncio
@pytest.mark.requires("nucliadb_protos")
async def test_async_call() -> None:
with mock.patch(
"nucliadb_protos.writer_pb2.BrokerMessage.ParseFromString",
new_callable=FakeParseFromString,
):
with mock.patch("requests.post", new_callable=fakepost):
with mock.patch("requests.get", new_callable=fakeget):
with mock.patch("os.environ.get", return_value="_a_key_"):
nua = NucliaUnderstandingAPI(enable_ml=False)
data = await nua.arun(
{
"action": "push",
"id": "1",
"path": str(README_PATH),
"text": None,
}
)
assert json.loads(data)["uuid"] == "fake_uuid"