diff --git a/private_gpt/components/ingest/ingest_component.py b/private_gpt/components/ingest/ingest_component.py index 5ed03959..122b8957 100644 --- a/private_gpt/components/ingest/ingest_component.py +++ b/private_gpt/components/ingest/ingest_component.py @@ -40,11 +40,11 @@ class BaseIngestComponent(abc.ABC): self.transformations = transformations @abc.abstractmethod - def ingest(self, file_name: str, file_data: Path) -> list[Document]: + def ingest(self, file_name: str, file_data: Path, file_metadata : dict | None = None) -> list[Document]: pass @abc.abstractmethod - def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]: + def bulk_ingest(self, files: list[tuple[str, Path]], metadata : dict | None = None) -> list[Document]: pass @abc.abstractmethod @@ -117,20 +117,20 @@ class SimpleIngestComponent(BaseIngestComponentWithIndex): ) -> None: super().__init__(storage_context, embed_model, transformations, *args, **kwargs) - def ingest(self, file_name: str, file_data: Path) -> list[Document]: + def ingest(self, file_name: str, file_data: Path, file_metadata : dict | None = None) -> list[Document]: logger.info("Ingesting file_name=%s", file_name) - documents = IngestionHelper.transform_file_into_documents(file_name, file_data) + documents = IngestionHelper.transform_file_into_documents(file_name, file_data, file_metadata) logger.info( "Transformed file=%s into count=%s documents", file_name, len(documents) ) logger.debug("Saving the documents in the index and doc store") return self._save_docs(documents) - def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]: + def bulk_ingest(self, files: list[tuple[str, Path]], metadata : dict | None = None) -> list[Document]: saved_documents = [] for file_name, file_data in files: documents = IngestionHelper.transform_file_into_documents( - file_name, file_data + file_name, file_data, metadata ) saved_documents.extend(self._save_docs(documents)) return saved_documents @@ -175,20 +175,20 @@ class 
BatchIngestComponent(BaseIngestComponentWithIndex):
             processes=self.count_workers
         )
 
-    def ingest(self, file_name: str, file_data: Path) -> list[Document]:
+    def ingest(self, file_name: str, file_data: Path, file_metadata: dict | None = None) -> list[Document]:
         logger.info("Ingesting file_name=%s", file_name)
-        documents = IngestionHelper.transform_file_into_documents(file_name, file_data)
+        documents = IngestionHelper.transform_file_into_documents(file_name, file_data, file_metadata)
         logger.info(
             "Transformed file=%s into count=%s documents", file_name, len(documents)
         )
         logger.debug("Saving the documents in the index and doc store")
         return self._save_docs(documents)
 
-    def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
+    def bulk_ingest(self, files: list[tuple[str, Path]], metadata: dict | None = None) -> list[Document]:
         documents = list(
             itertools.chain.from_iterable(
                 self._file_to_documents_work_pool.starmap(
-                    IngestionHelper.transform_file_into_documents, files
+                    IngestionHelper.transform_file_into_documents, [(*f, metadata) for f in files]
                 )
             )
         )
@@ -257,12 +257,12 @@ class ParallelizedIngestComponent(BaseIngestComponentWithIndex):
             processes=self.count_workers
         )
 
-    def ingest(self, file_name: str, file_data: Path) -> list[Document]:
+    def ingest(self, file_name: str, file_data: Path, file_metadata: dict | None = None) -> list[Document]:
         logger.info("Ingesting file_name=%s", file_name)
         # Running in a single (1) process to release the current
         # thread, and take a dedicated CPU core for computation
         documents = self._file_to_documents_work_pool.apply(
-            IngestionHelper.transform_file_into_documents, (file_name, file_data)
+            IngestionHelper.transform_file_into_documents, (file_name, file_data, file_metadata)
         )
         logger.info(
             "Transformed file=%s into count=%s documents", file_name, len(documents)
@@ -270,13 +270,13 @@ class ParallelizedIngestComponent(BaseIngestComponentWithIndex):
         logger.debug("Saving the documents in the index and doc store")
         return
self._save_docs(documents)
 
-    def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
+    def bulk_ingest(self, files: list[tuple[str, Path]], metadata: dict | None = None) -> list[Document]:
         # Lightweight threads, used for parallelize the
         # underlying IO calls made in the ingestion
         documents = list(
             itertools.chain.from_iterable(
-                self._ingest_work_pool.starmap(self.ingest, files)
+                self._ingest_work_pool.starmap(self.ingest, [(*f, metadata) for f in files])
             )
         )
         return documents
@@ -459,18 +459,18 @@ class PipelineIngestComponent(BaseIngestComponentWithIndex):
         self.node_q.put(("flush", None, None, None))
         self.node_q.join()
 
-    def ingest(self, file_name: str, file_data: Path) -> list[Document]:
-        documents = IngestionHelper.transform_file_into_documents(file_name, file_data)
+    def ingest(self, file_name: str, file_data: Path, file_metadata: dict | None = None) -> list[Document]:
+        documents = IngestionHelper.transform_file_into_documents(file_name, file_data, file_metadata)
         self.doc_q.put(("process", file_name, documents))
         self._flush()
         return documents
 
-    def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
+    def bulk_ingest(self, files: list[tuple[str, Path]], metadata: dict | None = None) -> list[Document]:
         docs = []
         for file_name, file_data in eta(files):
             try:
                 documents = IngestionHelper.transform_file_into_documents(
-                    file_name, file_data
+                    file_name, file_data, metadata
                 )
                 self.doc_q.put(("process", file_name, documents))
                 docs.extend(documents)
diff --git a/private_gpt/components/ingest/ingest_helper.py b/private_gpt/components/ingest/ingest_helper.py
index a1109070..aa841b08 100644
--- a/private_gpt/components/ingest/ingest_helper.py
+++ b/private_gpt/components/ingest/ingest_helper.py
@@ -69,11 +69,13 @@ class IngestionHelper:
     @staticmethod
     def transform_file_into_documents(
-        file_name: str, file_data: Path
+        file_name: str, file_data: Path, file_metadata: dict | None = None
     ) -> list[Document]:
         documents = 
IngestionHelper._load_file_to_documents(file_name, file_data)
         for document in documents:
+            document.metadata.update(file_metadata or {})
             document.metadata["file_name"] = file_name
+
         IngestionHelper._exclude_metadata(documents)
         return documents
 
diff --git a/private_gpt/server/ingest/ingest_router.py b/private_gpt/server/ingest/ingest_router.py
index 56adba46..5e8c844d 100644
--- a/private_gpt/server/ingest/ingest_router.py
+++ b/private_gpt/server/ingest/ingest_router.py
@@ -1,6 +1,8 @@
+import json
+
 from typing import Literal
 
-from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
+from fastapi import APIRouter, Depends, Form, HTTPException, Request, UploadFile
 from pydantic import BaseModel, Field
 
 from private_gpt.server.ingest.ingest_service import IngestService
@@ -20,6 +22,17 @@ class IngestTextBody(BaseModel):
             "Chinese martial arts."
         ]
     )
+    # Optional free-form metadata merged into every ingested document.
+    metadata: dict | None = Field(
+        None,
+        examples=[
+            {
+                "title": "Avatar: The Last Airbender",
+                "author": "Michael Dante DiMartino, Bryan Konietzko",
+                "year": "2005",
+            }
+        ],
+    )
 
 
 class IngestResponse(BaseModel):
@@ -38,7 +51,9 @@ def ingest(request: Request, file: UploadFile) -> IngestResponse:
 
 
 @ingest_router.post("/ingest/file", tags=["Ingestion"])
-def ingest_file(request: Request, file: UploadFile) -> IngestResponse:
+def ingest_file(
+    request: Request, file: UploadFile, metadata: str | None = Form(None)
+) -> IngestResponse:
     """Ingests and processes a file, storing its chunks to be used as context.
 
    The context obtained from files is later used in
@@ -57,7 +72,10 @@ def ingest_file(request: Request, file: UploadFile) -> IngestResponse:
     service = request.state.injector.get(IngestService)
     if file.filename is None:
         raise HTTPException(400, "No file name provided")
-    ingested_documents = service.ingest_bin_data(file.filename, file.file)
+
+    # SECURITY: parse client-supplied metadata with json.loads, never eval —
+    # eval on request data would execute arbitrary attacker code.
+    metadata_dict = None if metadata is None else json.loads(metadata)
+    ingested_documents = service.ingest_bin_data(file.filename, file.file, metadata_dict)
     return IngestResponse(object="list", model="private-gpt", data=ingested_documents)
 
 
@@ -77,7 +95,7 @@ def ingest_text(request: Request, body: IngestTextBody) -> IngestResponse:
     service = request.state.injector.get(IngestService)
     if len(body.file_name) == 0:
         raise HTTPException(400, "No file name provided")
-    ingested_documents = service.ingest_text(body.file_name, body.text)
+    ingested_documents = service.ingest_text(body.file_name, body.text, body.metadata)
     return IngestResponse(object="list", model="private-gpt", data=ingested_documents)
 
diff --git a/private_gpt/server/ingest/ingest_service.py b/private_gpt/server/ingest/ingest_service.py
index f9ae4728..9082432c 100644
--- a/private_gpt/server/ingest/ingest_service.py
+++ b/private_gpt/server/ingest/ingest_service.py
@@ -48,7 +48,7 @@ class IngestService:
             settings=settings(),
         )
 
-    def _ingest_data(self, file_name: str, file_data: AnyStr) -> list[IngestedDoc]:
+    def _ingest_data(self, file_name: str, file_data: AnyStr, file_metadata: dict | None = None) -> list[IngestedDoc]:
         logger.debug("Got file data of size=%s to ingest", len(file_data))
         # llama-index mainly supports reading from files, so
         # we have to create a tmp file to read for it to work
@@ -60,27 +60,27 @@ class IngestService:
                 path_to_tmp.write_bytes(file_data)
             else:
                 path_to_tmp.write_text(str(file_data))
-            return self.ingest_file(file_name, path_to_tmp)
+            return self.ingest_file(file_name, path_to_tmp, file_metadata)
         finally:
             tmp.close()
             path_to_tmp.unlink()
 
-    def 
ingest_file(self, file_name: str, file_data: Path) -> list[IngestedDoc]: + def ingest_file(self, file_name: str, file_data: Path, file_metadata : dict | None = None) -> list[IngestedDoc]: logger.info("Ingesting file_name=%s", file_name) - documents = self.ingest_component.ingest(file_name, file_data) + documents = self.ingest_component.ingest(file_name, file_data, file_metadata) logger.info("Finished ingestion file_name=%s", file_name) return [IngestedDoc.from_document(document) for document in documents] - def ingest_text(self, file_name: str, text: str) -> list[IngestedDoc]: + def ingest_text(self, file_name: str, text: str, metadata : dict | None = None) -> list[IngestedDoc]: logger.debug("Ingesting text data with file_name=%s", file_name) - return self._ingest_data(file_name, text) + return self._ingest_data(file_name, text, metadata) def ingest_bin_data( - self, file_name: str, raw_file_data: BinaryIO + self, file_name: str, raw_file_data: BinaryIO, file_metadata : dict | None = None ) -> list[IngestedDoc]: logger.debug("Ingesting binary data with file_name=%s", file_name) file_data = raw_file_data.read() - return self._ingest_data(file_name, file_data) + return self._ingest_data(file_name, file_data, file_metadata) def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[IngestedDoc]: logger.info("Ingesting file_names=%s", [f[0] for f in files]) diff --git a/tests/fixtures/ingest_helper.py b/tests/fixtures/ingest_helper.py index 25515f4e..64d4babc 100644 --- a/tests/fixtures/ingest_helper.py +++ b/tests/fixtures/ingest_helper.py @@ -1,6 +1,7 @@ from pathlib import Path import pytest +import json from fastapi.testclient import TestClient from private_gpt.server.ingest.ingest_router import IngestResponse @@ -17,6 +18,18 @@ class IngestHelper: assert response.status_code == 200 ingest_result = IngestResponse.model_validate(response.json()) return ingest_result + + def ingest_file_with_metadata(self, path: Path, metadata: dict) -> IngestResponse: + files = { 
+ "file": (path.name, path.open("rb")), + "metadata": (None, json.dumps(metadata)) + } + + response = self.test_client.post("/v1/ingest/file", files=files) + + assert response.status_code == 200 + ingest_result = IngestResponse.model_validate(response.json()) + return ingest_result @pytest.fixture() diff --git a/tests/server/ingest/test_ingest_routes.py b/tests/server/ingest/test_ingest_routes.py index 896410a1..15ad74e8 100644 --- a/tests/server/ingest/test_ingest_routes.py +++ b/tests/server/ingest/test_ingest_routes.py @@ -44,3 +44,21 @@ def test_ingest_plain_text(test_client: TestClient) -> None: assert response.status_code == 200 ingest_result = IngestResponse.model_validate(response.json()) assert len(ingest_result.data) == 1 + + +def test_ingest_text_with_metadata(test_client: TestClient): + response = test_client.post( + "/v1/ingest/text", json={"file_name": "file_name", "text": "text", "metadata": {"foo": "bar"}} + ) + assert response.status_code == 200 + ingest_result = IngestResponse.model_validate(response.json()) + assert len(ingest_result.data) == 1 + + assert ingest_result.data[0].doc_metadata == {"file_name" : "file_name", "foo": "bar"} + + +def test_ingest_accepts_txt_files(ingest_helper: IngestHelper) -> None: + path = Path(__file__).parents[0] / "test.txt" + ingest_result = ingest_helper.ingest_file_with_metadata(path, {"foo": "bar"}) + assert len(ingest_result.data) == 1 + assert ingest_result.data[0].doc_metadata == {"file_name": "test.txt", "foo": "bar"} \ No newline at end of file