def lazy_load(self) -> Iterator[Document]:
    """Lazily yield one ``Document`` per row of the GeoDataFrame.

    The geometry stored in ``page_content_column`` becomes the page content,
    serialized as WKT. Every other column of the row is copied into the
    metadata, together with the derived keys ``crs``, ``geometry_type``,
    ``xmin``, ``ymin``, ``xmax`` and ``ymax``.

    Yields:
        A ``Document`` for each row, in row order. An empty frame yields
        nothing.
    """
    content_column = self.page_content_column

    # Read the CRS from the column that is actually serialized rather than
    # from the frame's active geometry column, which may be a different one.
    # It is shared by every row of the column, so look it up once.
    crs = self.data_frame[content_column].crs
    crs_str = crs.to_string() if crs is not None else None

    for _, row in self.data_frame.iterrows():
        # assumes each cell holds a non-empty shapely geometry -- TODO confirm
        geom = row[content_column]
        xmin, ymin, xmax, ymax = geom.bounds

        # Drop the geometry before adding the derived keys so the pop can
        # never remove one of them.
        metadata = row.to_dict()
        metadata.pop(content_column)
        metadata.update(
            crs=crs_str,
            # Per-row type: correct for mixed-geometry columns and avoids
            # indexing the first row of an empty frame.
            geometry_type=geom.geom_type,
            xmin=xmin,
            ymin=ymin,
            xmax=xmax,
            ymax=ymax,
        )

        # using WKT instead of str() to help GIS system interoperability
        yield Document(page_content=geom.wkt, metadata=metadata)
b/libs/langchain/tests/integration_tests/document_loaders/test_geodataframe.py index 2dbf2777478..4b0680b0f11 100644 --- a/libs/langchain/tests/integration_tests/document_loaders/test_geodataframe.py +++ b/libs/langchain/tests/integration_tests/document_loaders/test_geodataframe.py @@ -17,6 +17,7 @@ else: def sample_gdf() -> GeoDataFrame: import geopandas + # TODO: geopandas.datasets will be deprecated in 1.0 path_to_data = geopandas.datasets.get_path("nybb") gdf = geopandas.read_file(path_to_data) gdf["area"] = gdf.area