add loading chains from hub (#757)

This commit is contained in:
Harrison Chase 2023-01-26 21:11:31 -08:00 committed by GitHub
parent 1b89a438cf
commit f273c50d62
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,8 +1,11 @@
"""Functionality for loading chains.""" """Functionality for loading chains."""
import json import json
import os
import tempfile
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
import requests
import yaml import yaml
from langchain.chains.base import Chain from langchain.chains.base import Chain
@ -10,6 +13,8 @@ from langchain.chains.llm import LLMChain
from langchain.llms.loading import load_llm, load_llm_from_config from langchain.llms.loading import load_llm, load_llm_from_config
from langchain.prompts.loading import load_prompt, load_prompt_from_config from langchain.prompts.loading import load_prompt, load_prompt_from_config
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/chains/"
def _load_llm_chain(config: dict) -> LLMChain: def _load_llm_chain(config: dict) -> LLMChain:
"""Load LLM chain from config dict.""" """Load LLM chain from config dict."""
@ -48,7 +53,16 @@ def load_chain_from_config(config: dict) -> Chain:
return chain_loader(config) return chain_loader(config)
def load_chain(file: Union[str, Path]) -> Chain: def load_chain(path: Union[str, Path]) -> Chain:
"""Unified method for loading a chain from LangChainHub or local fs."""
if isinstance(path, str) and path.startswith("lc://chains"):
path = os.path.relpath(path, "lc://chains/")
return _load_from_hub(path)
else:
return _load_chain_from_file(path)
def _load_chain_from_file(file: Union[str, Path]) -> Chain:
"""Load chain from file.""" """Load chain from file."""
# Convert file to Path object. # Convert file to Path object.
if isinstance(file, str): if isinstance(file, str):
@ -66,3 +80,19 @@ def load_chain(file: Union[str, Path]) -> Chain:
raise ValueError("File type must be json or yaml") raise ValueError("File type must be json or yaml")
# Load the chain from the config now. # Load the chain from the config now.
return load_chain_from_config(config) return load_chain_from_config(config)
def _load_from_hub(path: str) -> Chain:
"""Load chain from hub."""
suffix = path.split(".")[-1]
if suffix not in {"json", "yaml"}:
raise ValueError("Unsupported file type.")
full_url = URL_BASE + path
r = requests.get(full_url)
if r.status_code != 200:
raise ValueError(f"Could not find file at {full_url}")
with tempfile.TemporaryDirectory() as tmpdirname:
file = tmpdirname + "/chain." + suffix
with open(file, "wb") as f:
f.write(r.content)
return _load_chain_from_file(file)