Compare commits

...

3 Commits

Author          SHA1        Message                                 Date
Harrison Chase  cfa7ee61f4  cr                                      2023-05-23 20:46:19 -07:00
Harrison Chase  a2a02f74a4  cr                                      2023-05-23 07:20:42 -07:00
Harrison Chase  9d8b62414d  exclude embeddings from serialization   2023-05-23 07:12:06 -07:00
7 changed files with 64 additions and 59 deletions

View File

@@ -7,7 +7,7 @@ from __future__ import annotations
 from typing import Any, Dict, List, Optional
 import numpy as np
-from pydantic import Extra
+from pydantic import Extra, Field
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import CallbackManagerForChainRun
@@ -23,7 +23,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
     Based on https://arxiv.org/abs/2212.10496
     """
-    base_embeddings: Embeddings
+    base_embeddings: Embeddings = Field(exclude=True)
     llm_chain: LLMChain
     class Config:
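
The point of `Field(exclude=True)` is that pydantic drops the field from `.dict()` and `.json()` output, so the embeddings object is never written when the chain is serialized. A minimal sketch of that behavior with a stand-in pydantic model, not the real chain:

    from pydantic import BaseModel, Field

    class ToyChain(BaseModel):
        base_embeddings: str = Field(exclude=True)  # omitted from serialized output
        llm_chain: str

    toy = ToyChain(base_embeddings="<embeddings object>", llm_chain="<llm chain>")
    print(toy.dict())  # {'llm_chain': '<llm chain>'} -- no base_embeddings key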

View File

@@ -30,13 +30,12 @@ class LLMBashChain(Chain):
"""
llm_chain: LLMChain
llm: Optional[BaseLanguageModel] = None
"""[Deprecated] LLM wrapper to use."""
input_key: str = "question" #: :meta private:
output_key: str = "answer" #: :meta private:
prompt: BasePromptTemplate = PROMPT
"""[Deprecated]"""
bash_process: BashProcess = Field(default_factory=BashProcess) #: :meta private:
bash_process: BashProcess = Field(
default_factory=BashProcess, exclude=True
) #: :meta private:
class Config:
"""Configuration for this pydantic object."""
@@ -51,9 +50,13 @@ class LLMBashChain(Chain):
"Directly instantiating an LLMBashChain with an llm is deprecated. "
"Please instantiate with llm_chain or using the from_llm class method."
)
if "llm_chain" not in values and values["llm"] is not None:
prompt = values.get("prompt", PROMPT)
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
llm = values.pop("llm")
if "llm_chain" not in values and llm is not None:
if "prompt" in values:
prompt = values.pop("prompt")
else:
prompt = PROMPT
values["llm_chain"] = LLMChain(llm=llm, prompt=prompt)
return values
@root_validator
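
With the reworked validator, the deprecated `llm` (and any `prompt`) are popped out of `values` before `llm_chain` is built, so they no longer linger on the instance and leak into serialization. A hedged sketch of the two construction paths, using FakeListLLM purely as a stand-in model:

    from langchain.chains.llm_bash.base import LLMBashChain
    from langchain.llms.fake import FakeListLLM  # stand-in LLM for illustration

    llm = FakeListLLM(responses=["```bash\necho hello\n```"])
    chain = LLMBashChain.from_llm(llm)  # preferred: builds llm_chain internally
    legacy = LLMBashChain(llm=llm)      # deprecated: warns, then routed through the validator above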

View File

@@ -1,11 +1,8 @@
 # flake8: noqa
 from __future__ import annotations
-import re
-from typing import List
+from langchain.output_parsers.bash import BashOutputParser
 from langchain.prompts.prompt import PromptTemplate
-from langchain.schema import BaseOutputParser, OutputParserException
 _PROMPT_TEMPLATE = """If someone asks you to perform a task, your job is to come up with a series of bash commands that will perform the task. There is no need to put "#!/bin/bash" in your answer. Make sure to reason step by step, using this format:
@@ -25,38 +22,6 @@ That is the format. Begin!
 Question: {question}"""
-class BashOutputParser(BaseOutputParser):
-    """Parser for bash output."""
-    def parse(self, text: str) -> List[str]:
-        if "```bash" in text:
-            return self.get_code_blocks(text)
-        else:
-            raise OutputParserException(
-                f"Failed to parse bash output. Got: {text}",
-            )
-    @staticmethod
-    def get_code_blocks(t: str) -> List[str]:
-        """Get multiple code blocks from the LLM result."""
-        code_blocks: List[str] = []
-        # Bash markdown code blocks
-        pattern = re.compile(r"```bash(.*?)(?:\n\s*)```", re.DOTALL)
-        for match in pattern.finditer(t):
-            matched = match.group(1).strip()
-            if matched:
-                code_blocks.extend(
-                    [line for line in matched.split("\n") if line.strip()]
-                )
-        return code_blocks
-    @property
-    def _type(self) -> str:
-        return "bash"
 PROMPT = PromptTemplate(
     input_variables=["question"],
     template=_PROMPT_TEMPLATE,

View File

@@ -138,19 +138,14 @@ def _load_map_reduce_documents_chain(
 def _load_llm_bash_chain(config: dict, **kwargs: Any) -> LLMBashChain:
-    if "llm" in config:
-        llm_config = config.pop("llm")
-        llm = load_llm_from_config(llm_config)
-    elif "llm_path" in config:
-        llm = load_llm(config.pop("llm_path"))
+    if "llm_chain" in config:
+        llm_chain_config = config.pop("llm_chain")
+        llm_chain = load_chain_from_config(llm_chain_config)
+    elif "llm_chain_path" in config:
+        llm_chain = load_chain(config.pop("llm_chain_path"))
     else:
-        raise ValueError("One of `llm` or `llm_path` must be present.")
-    if "prompt" in config:
-        prompt_config = config.pop("prompt")
-        prompt = load_prompt_from_config(prompt_config)
-    elif "prompt_path" in config:
-        prompt = load_prompt(config.pop("prompt_path"))
-    return LLMBashChain(llm=llm, prompt=prompt, **config)
+        raise ValueError("One of `llm_chain` or `llm_chain_config` must be present.")
+    return LLMBashChain(llm_chain=llm_chain, **config)
 def _load_llm_checker_chain(config: dict, **kwargs: Any) -> LLMCheckerChain:
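
The loader now expects saved configs keyed by `llm_chain` / `llm_chain_path` instead of the old `llm` + `prompt` pair. A sketch of the round trip under that assumption, continuing the `chain` built in the earlier sketch (file name illustrative; a serializable LLM such as an OpenAI model is assumed inside the chain for loading to succeed):

    from langchain.chains.loading import load_chain

    chain.save("llm_bash_chain.yaml")             # config now nests {"_type": "llm_bash", "llm_chain": {...}}
    restored = load_chain("llm_bash_chain.yaml")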

View File

@@ -0,0 +1,37 @@
+from __future__ import annotations
+import re
+from typing import List
+from langchain.schema import BaseOutputParser, OutputParserException
+class BashOutputParser(BaseOutputParser):
+    """Parser for bash output."""
+    def parse(self, text: str) -> List[str]:
+        if "```bash" in text:
+            return self.get_code_blocks(text)
+        else:
+            raise OutputParserException(
+                f"Failed to parse bash output. Got: {text}",
+            )
+    @staticmethod
+    def get_code_blocks(t: str) -> List[str]:
+        """Get multiple code blocks from the LLM result."""
+        code_blocks: List[str] = []
+        # Bash markdown code blocks
+        pattern = re.compile(r"```bash(.*?)(?:\n\s*)```", re.DOTALL)
+        for match in pattern.finditer(t):
+            matched = match.group(1).strip()
+            if matched:
+                code_blocks.extend(
+                    [line for line in matched.split("\n") if line.strip()]
+                )
+        return code_blocks
+    @property
+    def _type(self) -> str:
+        return "bash"

View File

@@ -7,10 +7,12 @@ from typing import Union
 import yaml
+from langchain.output_parsers.bash import BashOutputParser
 from langchain.output_parsers.regex import RegexParser
 from langchain.prompts.base import BasePromptTemplate
 from langchain.prompts.few_shot import FewShotPromptTemplate
 from langchain.prompts.prompt import PromptTemplate
+from langchain.schema import BaseOutputParser
 from langchain.utilities.loading import try_load_from_hub
 URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
@@ -78,7 +80,9 @@ def _load_output_parser(config: dict) -> dict:
         _config = config.pop("output_parser")
         output_parser_type = _config.pop("_type")
         if output_parser_type == "regex_parser":
-            output_parser = RegexParser(**_config)
+            output_parser: BaseOutputParser = RegexParser(**_config)
+        elif output_parser_type == "bash":
+            output_parser = BashOutputParser(**config)
         else:
             raise ValueError(f"Unsupported output parser {output_parser_type}")
         config["output_parser"] = output_parser

View File

@@ -4,7 +4,8 @@ import sys
 import pytest
 from langchain.chains.llm_bash.base import LLMBashChain
-from langchain.chains.llm_bash.prompt import _PROMPT_TEMPLATE, BashOutputParser
+from langchain.chains.llm_bash.prompt import _PROMPT_TEMPLATE
+from langchain.output_parsers.bash import BashOutputParser
 from langchain.schema import OutputParserException
 from tests.unit_tests.llms.fake_llm import FakeLLM