From f273c50d620604cd8eb36c492f462a523b6bce6a Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Thu, 26 Jan 2023 21:11:31 -0800 Subject: [PATCH] add loading chains from hub (#757) --- langchain/chains/loading.py | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/langchain/chains/loading.py b/langchain/chains/loading.py index f36321a3fb4..60a2e16bad8 100644 --- a/langchain/chains/loading.py +++ b/langchain/chains/loading.py @@ -1,8 +1,11 @@ """Functionality for loading chains.""" import json +import os +import tempfile from pathlib import Path from typing import Union +import requests import yaml from langchain.chains.base import Chain @@ -10,6 +13,8 @@ from langchain.chains.llm import LLMChain from langchain.llms.loading import load_llm, load_llm_from_config from langchain.prompts.loading import load_prompt, load_prompt_from_config +URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/chains/" + def _load_llm_chain(config: dict) -> LLMChain: """Load LLM chain from config dict.""" @@ -48,7 +53,16 @@ def load_chain_from_config(config: dict) -> Chain: return chain_loader(config) -def load_chain(file: Union[str, Path]) -> Chain: +def load_chain(path: Union[str, Path]) -> Chain: + """Unified method for loading a chain from LangChainHub or local fs.""" + if isinstance(path, str) and path.startswith("lc://chains"): + path = os.path.relpath(path, "lc://chains/") + return _load_from_hub(path) + else: + return _load_chain_from_file(path) + + +def _load_chain_from_file(file: Union[str, Path]) -> Chain: """Load chain from file.""" # Convert file to Path object. if isinstance(file, str): @@ -66,3 +80,19 @@ def load_chain(file: Union[str, Path]) -> Chain: raise ValueError("File type must be json or yaml") # Load the chain from the config now. return load_chain_from_config(config) + + +def _load_from_hub(path: str) -> Chain: + """Load chain from hub.""" + suffix = path.split(".")[-1] + if suffix not in {"json", "yaml"}: + raise ValueError("Unsupported file type.") + full_url = URL_BASE + path + r = requests.get(full_url) + if r.status_code != 200: + raise ValueError(f"Could not find file at {full_url}") + with tempfile.TemporaryDirectory() as tmpdirname: + file = tmpdirname + "/chain." + suffix + with open(file, "wb") as f: + f.write(r.content) + return _load_chain_from_file(file)