feat: Add optional metadata param to ingest routes

Nathan Lenas 2024-07-23 08:50:54 +02:00
parent b62669784b
commit d559d54e1a
6 changed files with 76 additions and 32 deletions

View File

@@ -40,11 +40,11 @@ class BaseIngestComponent(abc.ABC):
         self.transformations = transformations
 
     @abc.abstractmethod
-    def ingest(self, file_name: str, file_data: Path) -> list[Document]:
+    def ingest(self, file_name: str, file_data: Path, file_metadata: dict | None = None) -> list[Document]:
         pass
 
     @abc.abstractmethod
-    def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
+    def bulk_ingest(self, files: list[tuple[str, Path]], metadata: dict | None = None) -> list[Document]:
         pass
 
     @abc.abstractmethod
@@ -117,20 +117,20 @@ class SimpleIngestComponent(BaseIngestComponentWithIndex):
     ) -> None:
         super().__init__(storage_context, embed_model, transformations, *args, **kwargs)
 
-    def ingest(self, file_name: str, file_data: Path) -> list[Document]:
+    def ingest(self, file_name: str, file_data: Path, file_metadata: dict | None = None) -> list[Document]:
         logger.info("Ingesting file_name=%s", file_name)
-        documents = IngestionHelper.transform_file_into_documents(file_name, file_data)
+        documents = IngestionHelper.transform_file_into_documents(file_name, file_data, file_metadata)
         logger.info(
             "Transformed file=%s into count=%s documents", file_name, len(documents)
         )
         logger.debug("Saving the documents in the index and doc store")
         return self._save_docs(documents)
 
-    def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
+    def bulk_ingest(self, files: list[tuple[str, Path]], metadata: dict | None = None) -> list[Document]:
         saved_documents = []
         for file_name, file_data in files:
             documents = IngestionHelper.transform_file_into_documents(
-                file_name, file_data
+                file_name, file_data, metadata
             )
             saved_documents.extend(self._save_docs(documents))
         return saved_documents
@@ -175,20 +175,21 @@ class BatchIngestComponent(BaseIngestComponentWithIndex):
             processes=self.count_workers
         )
 
-    def ingest(self, file_name: str, file_data: Path) -> list[Document]:
+    def ingest(self, file_name: str, file_data: Path, file_metadata: dict | None = None) -> list[Document]:
         logger.info("Ingesting file_name=%s", file_name)
-        documents = IngestionHelper.transform_file_into_documents(file_name, file_data)
+        documents = IngestionHelper.transform_file_into_documents(file_name, file_data, file_metadata)
         logger.info(
             "Transformed file=%s into count=%s documents", file_name, len(documents)
         )
         logger.debug("Saving the documents in the index and doc store")
         return self._save_docs(documents)
 
-    def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
+    def bulk_ingest(self, files: list[tuple[str, Path]], metadata: dict | None = None) -> list[Document]:
         documents = list(
             itertools.chain.from_iterable(
                 self._file_to_documents_work_pool.starmap(
-                    IngestionHelper.transform_file_into_documents, files
+                    IngestionHelper.transform_file_into_documents,
+                    [(file_name, file_data, metadata) for file_name, file_data in files],
                 )
             )
         )
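A note on the hunk above: multiprocessing.Pool.starmap(func, iterable, chunksize) treats a third positional argument as a chunk size, so a shared metadata dict cannot simply be passed after files; it has to be folded into each per-file argument tuple, as the rewritten call does. A minimal, self-contained sketch of that pattern (the function and file names are illustrative, not from the codebase):

from multiprocessing import Pool


def transform(file_name: str, file_path: str, metadata: dict | None) -> str:
    # Stand-in for IngestionHelper.transform_file_into_documents.
    return f"{file_name} ingested with {metadata}"


if __name__ == "__main__":
    files = [("a.txt", "/tmp/a.txt"), ("b.txt", "/tmp/b.txt")]
    metadata = {"foo": "bar"}
    with Pool(processes=2) as pool:
        # Each tuple is unpacked into transform()'s positional arguments,
        # so the shared dict rides along with every file.
        results = pool.starmap(
            transform, [(name, path, metadata) for name, path in files]
        )
    print(results)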
@@ -257,12 +258,12 @@ class ParallelizedIngestComponent(BaseIngestComponentWithIndex):
             processes=self.count_workers
         )
 
-    def ingest(self, file_name: str, file_data: Path) -> list[Document]:
+    def ingest(self, file_name: str, file_data: Path, file_metadata: dict | None = None) -> list[Document]:
         logger.info("Ingesting file_name=%s", file_name)
         # Running in a single (1) process to release the current
         # thread, and take a dedicated CPU core for computation
         documents = self._file_to_documents_work_pool.apply(
-            IngestionHelper.transform_file_into_documents, (file_name, file_data)
+            IngestionHelper.transform_file_into_documents, (file_name, file_data, file_metadata)
         )
         logger.info(
             "Transformed file=%s into count=%s documents", file_name, len(documents)
@@ -270,13 +271,16 @@ class ParallelizedIngestComponent(BaseIngestComponentWithIndex):
         logger.debug("Saving the documents in the index and doc store")
         return self._save_docs(documents)
 
-    def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
+    def bulk_ingest(self, files: list[tuple[str, Path]], metadata: dict | None = None) -> list[Document]:
         # Lightweight threads, used to parallelize the
         # underlying IO calls made in the ingestion
         documents = list(
             itertools.chain.from_iterable(
-                self._ingest_work_pool.starmap(self.ingest, files)
+                self._ingest_work_pool.starmap(
+                    self.ingest,
+                    [(file_name, file_data, metadata) for file_name, file_data in files],
+                )
             )
         )
         return documents
@@ -459,18 +463,18 @@ class PipelineIngestComponent(BaseIngestComponentWithIndex):
         self.node_q.put(("flush", None, None, None))
         self.node_q.join()
 
-    def ingest(self, file_name: str, file_data: Path) -> list[Document]:
-        documents = IngestionHelper.transform_file_into_documents(file_name, file_data)
+    def ingest(self, file_name: str, file_data: Path, file_metadata: dict | None = None) -> list[Document]:
+        documents = IngestionHelper.transform_file_into_documents(file_name, file_data, file_metadata)
         self.doc_q.put(("process", file_name, documents))
         self._flush()
         return documents
 
-    def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
+    def bulk_ingest(self, files: list[tuple[str, Path]], metadata: dict | None = None) -> list[Document]:
         docs = []
         for file_name, file_data in tqdm(files):
             try:
                 documents = IngestionHelper.transform_file_into_documents(
-                    file_name, file_data
+                    file_name, file_data, metadata
                 )
                 self.doc_q.put(("process", file_name, documents))
                 docs.extend(documents)
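The PipelineIngestComponent above hands transformed documents to a background worker through a queue of ("process", file_name, documents) tuples, terminated by a "flush" marker. A toy sketch of that hand-off protocol, much simplified from the real component:

import queue
import threading

doc_q: queue.Queue = queue.Queue()


def worker() -> None:
    while True:
        command, file_name, documents = doc_q.get()
        if command == "process":
            # Stand-in for embedding and indexing the documents.
            print(f"embedding {len(documents)} documents from {file_name}")
        doc_q.task_done()
        if command == "flush":
            break


threading.Thread(target=worker, daemon=True).start()
doc_q.put(("process", "a.txt", ["chunk-1", "chunk-2"]))
doc_q.put(("flush", None, []))
doc_q.join()  # blocks until the worker has drained the queue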

View File

@@ -69,11 +69,13 @@ class IngestionHelper:
     @staticmethod
     def transform_file_into_documents(
-        file_name: str, file_data: Path
+        file_name: str, file_data: Path, file_metadata: dict | None = None
     ) -> list[Document]:
         documents = IngestionHelper._load_file_to_documents(file_name, file_data)
         for document in documents:
+            document.metadata.update(file_metadata or {})
             document.metadata["file_name"] = file_name
         IngestionHelper._exclude_metadata(documents)
         return documents
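Because document.metadata.update(file_metadata or {}) runs before file_name is assigned, caller-supplied metadata can never overwrite the trusted file name. A quick sketch of the resulting merge semantics (the values are hypothetical):

# Order matters: user metadata goes in first, the trusted file name wins last.
doc_metadata: dict = {}
file_metadata = {"foo": "bar", "file_name": "spoofed.txt"}  # caller-supplied

doc_metadata.update(file_metadata or {})
doc_metadata["file_name"] = "real.txt"

assert doc_metadata == {"foo": "bar", "file_name": "real.txt"}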

View File

@@ -1,6 +1,7 @@
+import json
 from typing import Literal
 
-from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
+from fastapi import APIRouter, Depends, Form, HTTPException, Request, UploadFile
 from pydantic import BaseModel, Field
 
 from private_gpt.server.ingest.ingest_service import IngestService
@@ -20,6 +21,15 @@ class IngestTextBody(BaseModel):
             "Chinese martial arts."
         ]
     )
+    metadata: dict | None = Field(None,
+        examples=[
+            {
+                "title": "Avatar: The Last Airbender",
+                "author": "Michael Dante DiMartino, Bryan Konietzko",
+                "year": "2005",
+            }
+        ]
+    )
 
 
 class IngestResponse(BaseModel):
@@ -38,7 +48,7 @@ def ingest(request: Request, file: UploadFile) -> IngestResponse:
 
 @ingest_router.post("/ingest/file", tags=["Ingestion"])
-def ingest_file(request: Request, file: UploadFile) -> IngestResponse:
+def ingest_file(request: Request, file: UploadFile, metadata: str = Form(None)) -> IngestResponse:
     """Ingests and processes a file, storing its chunks to be used as context.
 
     The context obtained from files is later used in
@@ -57,7 +67,9 @@ def ingest_file(request: Request, file: UploadFile) -> IngestResponse:
     service = request.state.injector.get(IngestService)
     if file.filename is None:
         raise HTTPException(400, "No file name provided")
-    ingested_documents = service.ingest_bin_data(file.filename, file.file)
+    metadata_dict = None if metadata is None else json.loads(metadata)
+    ingested_documents = service.ingest_bin_data(file.filename, file.file, metadata_dict)
     return IngestResponse(object="list", model="private-gpt", data=ingested_documents)
@@ -77,7 +89,7 @@ def ingest_text(request: Request, body: IngestTextBody) -> IngestResponse:
     service = request.state.injector.get(IngestService)
     if len(body.file_name) == 0:
         raise HTTPException(400, "No file name provided")
-    ingested_documents = service.ingest_text(body.file_name, body.text)
+    ingested_documents = service.ingest_text(body.file_name, body.text, body.metadata)
     return IngestResponse(object="list", model="private-gpt", data=ingested_documents)
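For reference, a client exercises the two routes like this; a sketch using the requests library, with the base URL and file names assumed rather than taken from the repo:

import json

import requests

BASE = "http://localhost:8001"  # assumed local private-gpt address

# /v1/ingest/text: metadata travels inline in the JSON body.
requests.post(
    f"{BASE}/v1/ingest/text",
    json={
        "file_name": "avatar.txt",
        "text": "A book about a boy who learns Chinese martial arts.",
        "metadata": {"year": "2005"},
    },
)

# /v1/ingest/file: metadata travels as an extra multipart form field,
# JSON-encoded next to the file part.
with open("avatar.txt", "rb") as f:
    requests.post(
        f"{BASE}/v1/ingest/file",
        files={"file": ("avatar.txt", f)},
        data={"metadata": json.dumps({"year": "2005"})},
    )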

View File

@@ -48,7 +48,7 @@ class IngestService:
             settings=settings(),
         )
 
-    def _ingest_data(self, file_name: str, file_data: AnyStr) -> list[IngestedDoc]:
+    def _ingest_data(self, file_name: str, file_data: AnyStr, file_metadata: dict | None = None) -> list[IngestedDoc]:
         logger.debug("Got file data of size=%s to ingest", len(file_data))
         # llama-index mainly supports reading from files, so
         # we have to create a tmp file to read for it to work
@@ -60,27 +60,27 @@ class IngestService:
                 path_to_tmp.write_bytes(file_data)
             else:
                 path_to_tmp.write_text(str(file_data))
-            return self.ingest_file(file_name, path_to_tmp)
+            return self.ingest_file(file_name, path_to_tmp, file_metadata)
         finally:
             tmp.close()
             path_to_tmp.unlink()
 
-    def ingest_file(self, file_name: str, file_data: Path) -> list[IngestedDoc]:
+    def ingest_file(self, file_name: str, file_data: Path, file_metadata: dict | None = None) -> list[IngestedDoc]:
         logger.info("Ingesting file_name=%s", file_name)
-        documents = self.ingest_component.ingest(file_name, file_data)
+        documents = self.ingest_component.ingest(file_name, file_data, file_metadata)
         logger.info("Finished ingestion file_name=%s", file_name)
         return [IngestedDoc.from_document(document) for document in documents]
 
-    def ingest_text(self, file_name: str, text: str) -> list[IngestedDoc]:
+    def ingest_text(self, file_name: str, text: str, metadata: dict | None = None) -> list[IngestedDoc]:
         logger.debug("Ingesting text data with file_name=%s", file_name)
-        return self._ingest_data(file_name, text)
+        return self._ingest_data(file_name, text, metadata)
 
     def ingest_bin_data(
-        self, file_name: str, raw_file_data: BinaryIO
+        self, file_name: str, raw_file_data: BinaryIO, file_metadata: dict | None = None
     ) -> list[IngestedDoc]:
         logger.debug("Ingesting binary data with file_name=%s", file_name)
         file_data = raw_file_data.read()
-        return self._ingest_data(file_name, file_data)
+        return self._ingest_data(file_name, file_data, file_metadata)
 
     def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[IngestedDoc]:
         logger.info("Ingesting file_names=%s", [f[0] for f in files])

View File

@@ -1,6 +1,7 @@
 from pathlib import Path
 
 import pytest
+import json
 from fastapi.testclient import TestClient
 
 from private_gpt.server.ingest.ingest_router import IngestResponse
@@ -17,6 +18,18 @@ class IngestHelper:
         assert response.status_code == 200
 
         ingest_result = IngestResponse.model_validate(response.json())
         return ingest_result
 
+    def ingest_file_with_metadata(self, path: Path, metadata: dict) -> IngestResponse:
+        files = {
+            "file": (path.name, path.open("rb")),
+            "metadata": (None, json.dumps(metadata)),
+        }
+
+        response = self.test_client.post("/v1/ingest/file", files=files)
+
+        assert response.status_code == 200
+
+        ingest_result = IngestResponse.model_validate(response.json())
+        return ingest_result
 
 
 @pytest.fixture()
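In the helper above, the (None, json.dumps(metadata)) tuple makes the test client send metadata as a bare form field (no filename, so it is not parsed as an upload), matching the metadata: str = Form(None) parameter on the route. A hypothetical use inside a test:

def test_metadata_roundtrip(ingest_helper: IngestHelper) -> None:
    # Assumes the ingest_helper fixture and a test.txt next to the test file.
    result = ingest_helper.ingest_file_with_metadata(
        Path(__file__).parent / "test.txt", {"foo": "bar"}
    )
    assert result.data[0].doc_metadata is not None
    assert result.data[0].doc_metadata["foo"] == "bar"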

View File

@@ -44,3 +44,22 @@ def test_ingest_plain_text(test_client: TestClient) -> None:
     assert response.status_code == 200
 
     ingest_result = IngestResponse.model_validate(response.json())
     assert len(ingest_result.data) == 1
+
+
+def test_ingest_text_with_metadata(test_client: TestClient) -> None:
+    response = test_client.post(
+        "/v1/ingest/text",
+        json={"file_name": "file_name", "text": "text", "metadata": {"foo": "bar"}},
+    )
+
+    assert response.status_code == 200
+    ingest_result = IngestResponse.model_validate(response.json())
+    assert len(ingest_result.data) == 1
+    assert ingest_result.data[0].doc_metadata == {"file_name": "file_name", "foo": "bar"}
+
+
+def test_ingest_accepts_txt_files_with_metadata(ingest_helper: IngestHelper) -> None:
+    path = Path(__file__).parents[0] / "test.txt"
+    ingest_result = ingest_helper.ingest_file_with_metadata(path, {"foo": "bar"})
+    assert len(ingest_result.data) == 1
+    assert ingest_result.data[0].doc_metadata == {"file_name": "test.txt", "foo": "bar"}