This commit is contained in:
Eugene Yurtsev
2024-02-22 11:26:29 -05:00
parent e81c7c6673
commit ba8e0101e1
3 changed files with 25 additions and 14 deletions

View File

@@ -117,7 +117,9 @@ from langchain_community.document_loaders.html_bs import BSHTMLLoader
from langchain_community.document_loaders.hugging_face_dataset import (
HuggingFaceDatasetLoader,
)
from langchain_community.document_loaders.hugging_face_model import HuggingFaceModelLoader
from langchain_community.document_loaders.hugging_face_model import (
HuggingFaceModelLoader,
)
from langchain_community.document_loaders.ifixit import IFixitLoader
from langchain_community.document_loaders.image import UnstructuredImageLoader
from langchain_community.document_loaders.image_captions import ImageCaptionLoader

View File

@@ -1,14 +1,16 @@
import requests
from typing import Iterator, List, Optional
import requests
from langchain_core.documents import Document
from langchain_community.document_loaders.base import BaseLoader
class HuggingFaceModelLoader(BaseLoader):
"""
Load model information from `Hugging Face Hub`, including README content.
This loader interfaces with the Hugging Face Models API to fetch and load model metadata and README files.
The API allows you to search and filter models based on specific criteria such as model tags, authors, and more.
API URL: https://huggingface.co/api/models
@@ -72,7 +74,10 @@ class HuggingFaceModelLoader(BaseLoader):
def fetch_models(self) -> List[dict]:
    """Fetch model information from Hugging Face Hub.

    Issues a single GET request against ``self.BASE_URL``, forwarding only
    the entries of ``self.params`` whose value is not ``None``.

    Returns:
        The decoded JSON body of the response (annotated as a list of
        dicts).

    Raises:
        requests.HTTPError: If ``raise_for_status`` rejects the response.
    """
    # Drop unset filters so they are not sent as query parameters.
    query = {k: v for k, v in self.params.items() if v is not None}
    # NOTE(review): no timeout is passed, so this call can block
    # indefinitely on an unresponsive server -- consider adding one.
    response = requests.get(self.BASE_URL, params=query)
    response.raise_for_status()
    return response.json()
@@ -93,7 +98,7 @@ class HuggingFaceModelLoader(BaseLoader):
for model in models:
model_id = model.get("modelId", "")
readme_content = self.fetch_readme_content(model_id)
yield Document(
page_content=readme_content,
metadata=model,
@@ -102,6 +107,3 @@ class HuggingFaceModelLoader(BaseLoader):
def load(self) -> List[Document]:
    """Eagerly collect every document produced by ``lazy_load`` into a list."""
    documents = [*self.lazy_load()]
    return documents

View File

@@ -1,7 +1,10 @@
import json
import pytest
import responses
from langchain_community.document_loaders import HuggingFaceModelLoader
import json
# Mocked model data to simulate an API response
MOCKED_MODELS_RESPONSE = [
{
@@ -23,12 +26,12 @@ MOCKED_MODELS_RESPONSE = [
"autotrain_compatible",
"endpoints_compatible",
"has_space",
"region:us"
"region:us",
],
"pipeline_tag": "text-generation",
"library_name": "transformers",
"createdAt": "2023-12-13T21:19:59.000Z",
"modelId": "microsoft/phi-2"
"modelId": "microsoft/phi-2",
},
# Add additional models as needed
]
@@ -39,26 +42,30 @@ MOCKED_README_CONTENT = {
"openai/gpt-3": "README content for openai/gpt-3",
}
def response_callback(request):
    """Serve canned responses for the mocked Hugging Face endpoints.

    Args:
        request: The intercepted request; only its ``url`` attribute is read.

    Returns:
        A ``(status, headers, body)`` tuple: the JSON-encoded mocked model
        list when the URL contains ``/api/models``, the mocked README text
        when it contains ``README.md`` (empty string for an unknown model),
        and a 404 "Not Found" for anything else.
    """
    url = request.url
    if "/api/models" in url:
        return (200, {}, json.dumps(MOCKED_MODELS_RESPONSE))
    if "README.md" in url:
        # URL shape is https://<host>/<org>/<name>/...: split once and take
        # the org and name segments to rebuild the model id.
        segments = url.split("/")
        model_id = f"{segments[3]}/{segments[4]}"
        content = MOCKED_README_CONTENT.get(model_id, "")
        return (200, {}, content)
    return (404, {}, "Not Found")
@responses.activate
def test_load_models_with_readme():
"""Tests loading models along with their README content."""
responses.add_callback(
responses.GET,
responses.GET,
"https://huggingface.co/api/models",
callback=response_callback,
content_type="application/json",
)
responses.add_callback(
responses.GET,
responses.GET,
"https://huggingface.co/microsoft/phi-2/raw/main/README.md", # Use a regex or update this placeholder
callback=response_callback,
content_type="text/plain",