Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-13 14:50:00 +00:00)

Commit 3fcc803880: Merge branch 'master' into harrison/chain_pipeline
.github/workflows/lint.yml (vendored, 2 changes)
@@ -1,6 +1,6 @@
 name: lint

-on: [push, pull_request_target]
+on: [push, pull_request]

 jobs:
   build:
.github/workflows/test.yml (vendored, 2 changes)
@@ -1,6 +1,6 @@
 name: test

-on: [push, pull_request_target]
+on: [push, pull_request]

 jobs:
   build:
@@ -37,7 +37,7 @@ This project was largely inspired by a few projects seen on Twitter for which we

 **[Self-ask-with-search](https://ofir.io/self-ask.pdf)**

-To recreate this paper, use the following code snippet or checkout the [example notebook](https://github.com/hwchase17/langchain/blob/master/examples/self_ask_with_search.ipynb).
+To recreate this paper, use the following code snippet or checkout the [example notebook](https://github.com/hwchase17/langchain/blob/master/docs/examples/demos/self_ask_with_search.ipynb).

 ```python
 from langchain import SelfAskWithSearchChain, OpenAI, SerpAPIChain
@@ -52,7 +52,7 @@ self_ask_with_search.run("What is the hometown of the reigning men's U.S. Open c

 **[LLM Math](https://twitter.com/amasad/status/1568824744367259648?s=20&t=-7wxpXBJinPgDuyHLouP1w)**

-To recreate this example, use the following code snippet or check out the [example notebook](https://github.com/hwchase17/langchain/blob/master/examples/llm_math.ipynb).
+To recreate this example, use the following code snippet or check out the [example notebook](https://github.com/hwchase17/langchain/blob/master/docs/examples/demos/llm_math.ipynb).

 ```python
 from langchain import OpenAI, LLMMathChain
@@ -65,7 +65,7 @@ llm_math.run("How many of the integers between 0 and 99 inclusive are divisible

 **Generic Prompting**

-You can also use this for simple prompting pipelines, as in the below example and this [example notebook](https://github.com/hwchase17/langchain/blob/master/examples/simple_prompts.ipynb).
+You can also use this for simple prompting pipelines, as in the below example and this [example notebook](https://github.com/hwchase17/langchain/blob/master/docs/examples/demos/simple_prompts.ipynb).

 ```python
 from langchain import Prompt, OpenAI, LLMChain
@@ -84,7 +84,7 @@ llm_chain.predict(question=question)

 **Embed & Search Documents**

-We support two vector databases to store and search embeddings -- FAISS and Elasticsearch. Here's a code snippet showing how to use FAISS to store embeddings and search for text similar to a query. Both database backends are featured in this [example notebook](https://github.com/hwchase17/langchain/blob/master/examples/embeddings.ipynb).
+We support two vector databases to store and search embeddings -- FAISS and Elasticsearch. Here's a code snippet showing how to use FAISS to store embeddings and search for text similar to a query. Both database backends are featured in this [example notebook](https://github.com/hwchase17/langchain/blob/master/docs/examples/integrations/embeddings.ipynb).

 ```python
 from langchain.embeddings.openai import OpenAIEmbeddings
@@ -42,7 +42,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"model_lab = ModelLaboratory(llms)"
+"model_lab = ModelLaboratory.from_llms(llms)"
 ]
 },
 {
@@ -60,19 +60,19 @@
 "\n",
 "\u001b[1mOpenAI\u001b[0m\n",
 "Params: {'model': 'text-davinci-002', 'temperature': 0.0, 'max_tokens': 256, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'n': 1, 'best_of': 1}\n",
-"\u001b[104m\n",
+"\u001b[36;1m\u001b[1;3m\n",
 "\n",
 "Flamingos are pink.\u001b[0m\n",
 "\n",
 "\u001b[1mCohere\u001b[0m\n",
 "Params: {'model': 'command-xlarge-20221108', 'max_tokens': 20, 'temperature': 0.0, 'k': 0, 'p': 1, 'frequency_penalty': 0, 'presence_penalty': 0}\n",
-"\u001b[103m\n",
+"\u001b[33;1m\u001b[1;3m\n",
 "\n",
 "Pink\u001b[0m\n",
 "\n",
 "\u001b[1mHuggingFaceHub\u001b[0m\n",
 "Params: {'repo_id': 'google/flan-t5-xl', 'temperature': 1}\n",
-"\u001b[101mpink\u001b[0m\n",
+"\u001b[38;5;200m\u001b[1;3mpink\u001b[0m\n",
 "\n"
 ]
 }
@@ -89,7 +89,7 @@
 "outputs": [],
 "source": [
 "prompt = Prompt(template=\"What is the capital of {state}?\", input_variables=[\"state\"])\n",
-"model_lab_with_prompt = ModelLaboratory(llms, prompt=prompt)"
+"model_lab_with_prompt = ModelLaboratory.from_llms(llms, prompt=prompt)"
 ]
 },
 {
@@ -107,19 +107,19 @@
 "\n",
 "\u001b[1mOpenAI\u001b[0m\n",
 "Params: {'model': 'text-davinci-002', 'temperature': 0.0, 'max_tokens': 256, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'n': 1, 'best_of': 1}\n",
-"\u001b[104m\n",
+"\u001b[36;1m\u001b[1;3m\n",
 "\n",
 "The capital of New York is Albany.\u001b[0m\n",
 "\n",
 "\u001b[1mCohere\u001b[0m\n",
 "Params: {'model': 'command-xlarge-20221108', 'max_tokens': 20, 'temperature': 0.0, 'k': 0, 'p': 1, 'frequency_penalty': 0, 'presence_penalty': 0}\n",
-"\u001b[103m\n",
+"\u001b[33;1m\u001b[1;3m\n",
 "\n",
 "The capital of New York is Albany.\u001b[0m\n",
 "\n",
 "\u001b[1mHuggingFaceHub\u001b[0m\n",
 "Params: {'repo_id': 'google/flan-t5-xl', 'temperature': 1}\n",
-"\u001b[101mst john s\u001b[0m\n",
+"\u001b[38;5;200m\u001b[1;3mst john s\u001b[0m\n",
 "\n"
 ]
 }
@@ -130,10 +130,103 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 7,
 "id": "54336dbf",
 "metadata": {},
 "outputs": [],
+"source": [
+"from langchain import SelfAskWithSearchChain, SerpAPIChain\n",
+"\n",
+"open_ai_llm = OpenAI(temperature=0)\n",
+"search = SerpAPIChain()\n",
+"self_ask_with_search_openai = SelfAskWithSearchChain(llm=open_ai_llm, search_chain=search, verbose=True)\n",
+"\n",
+"cohere_llm = Cohere(temperature=0, model=\"command-xlarge-20221108\")\n",
+"search = SerpAPIChain()\n",
+"self_ask_with_search_cohere = SelfAskWithSearchChain(llm=cohere_llm, search_chain=search, verbose=True)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 8,
+"id": "6a50a9f1",
+"metadata": {},
+"outputs": [],
+"source": [
+"chains = [self_ask_with_search_openai, self_ask_with_search_cohere]\n",
+"names = [str(open_ai_llm), str(cohere_llm)]"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 9,
+"id": "d3549e99",
+"metadata": {},
+"outputs": [],
+"source": [
+"model_lab = ModelLaboratory(chains, names=names)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 10,
+"id": "362f7f57",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"\u001b[1mInput:\u001b[0m\n",
+"What is the hometown of the reigning men's U.S. Open champion?\n",
+"\n",
+"\u001b[1mOpenAI\u001b[0m\n",
+"Params: {'model': 'text-davinci-002', 'temperature': 0.0, 'max_tokens': 256, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'n': 1, 'best_of': 1}\n",
+"\n",
+"\n",
+"\u001b[1m> Entering new chain...\u001b[0m\n",
+"What is the hometown of the reigning men's U.S. Open champion?\n",
+"Are follow up questions needed here:\u001b[32;1m\u001b[1;3m Yes.\n",
+"Follow up: Who is the reigning men's U.S. Open champion?\u001b[0m\n",
+"Intermediate answer: \u001b[33;1m\u001b[1;3mCarlos Alcaraz.\u001b[0m\u001b[32;1m\u001b[1;3m\n",
+"Follow up: Where is Carlos Alcaraz from?\u001b[0m\n",
+"Intermediate answer: \u001b[33;1m\u001b[1;3mEl Palmar, Spain.\u001b[0m\u001b[32;1m\u001b[1;3m\n",
+"So the final answer is: El Palmar, Spain\u001b[0m\n",
+"\u001b[1m> Finished chain.\u001b[0m\n",
+"\u001b[36;1m\u001b[1;3m\n",
+"So the final answer is: El Palmar, Spain\u001b[0m\n",
+"\n",
+"\u001b[1mCohere\u001b[0m\n",
+"Params: {'model': 'command-xlarge-20221108', 'max_tokens': 256, 'temperature': 0.0, 'k': 0, 'p': 1, 'frequency_penalty': 0, 'presence_penalty': 0}\n",
+"\n",
+"\n",
+"\u001b[1m> Entering new chain...\u001b[0m\n",
+"What is the hometown of the reigning men's U.S. Open champion?\n",
+"Are follow up questions needed here:\u001b[32;1m\u001b[1;3m Yes.\n",
+"Follow up: Who is the reigning men's U.S. Open champion?\u001b[0m\n",
+"Intermediate answer: \u001b[33;1m\u001b[1;3mCarlos Alcaraz.\u001b[0m\u001b[32;1m\u001b[1;3m\n",
+"So the final answer is:\n",
+"\n",
+"Carlos Alcaraz\u001b[0m\n",
+"\u001b[1m> Finished chain.\u001b[0m\n",
+"\u001b[33;1m\u001b[1;3m\n",
+"So the final answer is:\n",
+"\n",
+"Carlos Alcaraz\u001b[0m\n",
+"\n"
+]
+}
+],
+"source": [
+"model_lab.compare(\"What is the hometown of the reigning men's U.S. Open champion?\")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "94159131",
+"metadata": {},
+"outputs": [],
 "source": []
 }
 ],
@@ -1 +1 @@
-0.0.13
+0.0.16
@@ -9,7 +9,7 @@ class Chain(BaseModel, ABC):
     """Base interface that all chains should implement."""

     verbose: bool = False
-    """Whether to print out the code that was executed."""
+    """Whether to print out response text."""

     @property
     @abstractmethod
@@ -49,6 +49,10 @@ class Chain(BaseModel, ABC):
         self._validate_outputs(outputs)
         return {**inputs, **outputs}

+    def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
+        """Call the chain on all inputs in the list."""
+        return [self(inputs) for inputs in input_list]
+
     def run(self, text: str) -> str:
         """Run text in, text out (if applicable)."""
         if len(self.input_keys) != 1:
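The new `Chain.apply` batches one call per input dict and returns one output dict per input. A minimal usage sketch, not part of the commit itself; it assumes an `LLMChain` built as in the README examples above:

```python
from langchain import OpenAI, Prompt, LLMChain

# Build a single-input chain (same pattern as the README's Generic Prompting example).
prompt = Prompt(input_variables=["question"], template="Q: {question}\nA:")
llm_chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)

# `apply` simply loops `self(inputs)` over the list, so each result dict
# contains both the inputs and the chain's output key.
input_list = [
    {"question": "What color are flamingos?"},
    {"question": "What is the capital of New York?"},
]
results = llm_chain.apply(input_list)
```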
@@ -59,16 +59,13 @@ class MapReduceChain(Chain, BaseModel):

     def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
         # Split the larger text into smaller chunks.
-        docs = self.text_splitter.split_text(
-            inputs[self.input_key],
-        )
+        docs = self.text_splitter.split_text(inputs[self.input_key])
         # Now that we have the chunks, we send them to the LLM and track results.
         # This is the "map" part.
-        summaries = []
-        for d in docs:
-            inputs = {self.map_llm.prompt.input_variables[0]: d}
-            res = self.map_llm.predict(**inputs)
-            summaries.append(res)
+        input_list = [{self.map_llm.prompt.input_variables[0]: d} for d in docs]
+        summary_results = self.map_llm.apply(input_list)
+        summaries = [res[self.map_llm.output_key] for res in summary_results]

         # We then need to combine these individual parts into one.
         # This is the reduce part.
@ -28,14 +28,7 @@ class Crawler:
|
|||||||
"Could not import playwright python package. "
|
"Could not import playwright python package. "
|
||||||
"Please it install it with `pip install playwright`."
|
"Please it install it with `pip install playwright`."
|
||||||
)
|
)
|
||||||
self.browser = (
|
self.browser = sync_playwright().start().chromium.launch(headless=False)
|
||||||
sync_playwright()
|
|
||||||
.start()
|
|
||||||
.chromium.launch(
|
|
||||||
headless=False,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.page = self.browser.new_page()
|
self.page = self.browser.new_page()
|
||||||
self.page.set_viewport_size({"width": 1280, "height": 1080})
|
self.page.set_viewport_size({"width": 1280, "height": 1080})
|
||||||
|
|
||||||
|
@@ -109,8 +109,4 @@ Action 3: Finish[yes]""",
 ]
 SUFFIX = """\n\nQuestion: {input}"""

-PROMPT = Prompt.from_examples(
-    EXAMPLES,
-    SUFFIX,
-    ["input"],
-)
+PROMPT = Prompt.from_examples(EXAMPLES, SUFFIX, ["input"])
@@ -38,7 +38,4 @@ Intermediate Answer: New Zealand.
 So the final answer is: No

 Question: {input}"""
-PROMPT = Prompt(
-    input_variables=["input"],
-    template=_DEFAULT_TEMPLATE,
-)
+PROMPT = Prompt(input_variables=["input"], template=_DEFAULT_TEMPLATE)
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional
 from pydantic import BaseModel, Extra, root_validator

 from langchain.chains.base import Chain
+from langchain.utils import get_from_dict_or_env


 class HiddenPrints:
@@ -43,7 +44,7 @@ class SerpAPIChain(Chain, BaseModel):
     input_key: str = "search_query"  #: :meta private:
     output_key: str = "search_result"  #: :meta private:

-    serpapi_api_key: Optional[str] = os.environ.get("SERPAPI_API_KEY")
+    serpapi_api_key: Optional[str] = None

     class Config:
         """Configuration for this pydantic object."""
@@ -69,14 +70,10 @@ class SerpAPIChain(Chain, BaseModel):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
-        serpapi_api_key = values.get("serpapi_api_key")
-        if serpapi_api_key is None or serpapi_api_key == "":
-            raise ValueError(
-                "Did not find SerpAPI API key, please add an environment variable"
-                " `SERPAPI_API_KEY` which contains it, or pass `serpapi_api_key` "
-                "as a named parameter to the constructor."
-            )
+        serpapi_api_key = get_from_dict_or_env(
+            values, "serpapi_api_key", "SERPAPI_API_KEY"
+        )
+        values["serpapi_api_key"] = serpapi_api_key
         try:
             from serpapi import GoogleSearch
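With this change the key is resolved once, at pydantic validation time, rather than when the module is imported. A short sketch of both entry points (the key value is a placeholder, not a real credential):

```python
import os
from langchain import SerpAPIChain

# Option 1: rely on the environment variable that the root validator reads.
os.environ["SERPAPI_API_KEY"] = "placeholder-key"
search = SerpAPIChain()

# Option 2: pass the key explicitly; the validator writes it back into `values`.
search = SerpAPIChain(serpapi_api_key="placeholder-key")
```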
@@ -15,6 +15,5 @@ Only use the following tables:

 Question: {input}"""
 PROMPT = Prompt(
-    input_variables=["input", "table_info", "dialect"],
-    template=_DEFAULT_TEMPLATE,
+    input_variables=["input", "table_info", "dialect"], template=_DEFAULT_TEMPLATE
 )
@@ -27,6 +27,8 @@ class VectorDBQA(Chain, BaseModel):
     """LLM wrapper to use."""
     vectorstore: VectorStore
     """Vector Database to connect to."""
+    k: int = 4
+    """Number of documents to query for."""
     input_key: str = "query"  #: :meta private:
     output_key: str = "result"  #: :meta private:

@@ -55,7 +57,7 @@ class VectorDBQA(Chain, BaseModel):
     def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
         question = inputs[self.input_key]
         llm_chain = LLMChain(llm=self.llm, prompt=prompt)
-        docs = self.vectorstore.similarity_search(question)
+        docs = self.vectorstore.similarity_search(question, k=self.k)
         contexts = []
         for j, doc in enumerate(docs):
             contexts.append(f"Context {j}:\n{doc.page_content}")
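The new `k` field caps how many similar documents are retrieved per query instead of relying on the vector store's default. A sketch only: the module path is assumed from the hunk context, and `faiss_store` stands for an existing `VectorStore` (e.g. built with FAISS as in the README snippet above):

```python
from langchain import OpenAI
from langchain.chains.vector_db_qa import VectorDBQA  # import path assumed

# Retrieve 8 documents per question instead of the default of 4.
qa = VectorDBQA(llm=OpenAI(temperature=0), vectorstore=faiss_store, k=8)
result = qa.run("What did the author say about flamingos?")
```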
@@ -1,7 +1,7 @@
 """Interface for interacting with a document."""
 from typing import List

-from pydantic import BaseModel
+from pydantic import BaseModel, Field


 class Document(BaseModel):
@@ -10,6 +10,7 @@ class Document(BaseModel):
     page_content: str
     lookup_str: str = ""
     lookup_index = 0
+    metadata: dict = Field(default_factory=dict)

     @property
     def paragraphs(self) -> List[str]:
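`Field(default_factory=dict)` gives every `Document` its own empty dict rather than one shared mutable default. A minimal sketch of the new field:

```python
from langchain.docstore.document import Document

# `metadata` is optional and defaults to a fresh empty dict per instance.
doc = Document(
    page_content="Flamingos are pink.",
    metadata={"source": "notes.txt"},  # illustrative key, not prescribed by the schema
)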
@ -1,10 +1,10 @@
|
|||||||
"""Wrapper around Cohere embedding models."""
|
"""Wrapper around Cohere embedding models."""
|
||||||
import os
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Extra, root_validator
|
from pydantic import BaseModel, Extra, root_validator
|
||||||
|
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
|
from langchain.utils import get_from_dict_or_env
|
||||||
|
|
||||||
|
|
||||||
class CohereEmbeddings(BaseModel, Embeddings):
|
class CohereEmbeddings(BaseModel, Embeddings):
|
||||||
@ -25,7 +25,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
|||||||
model: str = "medium"
|
model: str = "medium"
|
||||||
"""Model name to use."""
|
"""Model name to use."""
|
||||||
|
|
||||||
cohere_api_key: Optional[str] = os.environ.get("COHERE_API_KEY")
|
cohere_api_key: Optional[str] = None
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
@ -35,14 +35,9 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
|||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate that api key and python package exists in environment."""
|
"""Validate that api key and python package exists in environment."""
|
||||||
cohere_api_key = values.get("cohere_api_key")
|
cohere_api_key = get_from_dict_or_env(
|
||||||
|
values, "cohere_api_key", "COHERE_API_KEY"
|
||||||
if cohere_api_key is None or cohere_api_key == "":
|
)
|
||||||
raise ValueError(
|
|
||||||
"Did not find Cohere API key, please add an environment variable"
|
|
||||||
" `COHERE_API_KEY` which contains it, or pass `cohere_api_key` as a"
|
|
||||||
" named parameter."
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
import cohere
|
import cohere
|
||||||
|
|
||||||
|
@@ -1,10 +1,10 @@
 """Wrapper around OpenAI embedding models."""
-import os
 from typing import Any, Dict, List, Optional

 from pydantic import BaseModel, Extra, root_validator

 from langchain.embeddings.base import Embeddings
+from langchain.utils import get_from_dict_or_env


 class OpenAIEmbeddings(BaseModel, Embeddings):
@@ -25,7 +25,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     model_name: str = "babbage"
     """Model name to use."""

-    openai_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
+    openai_api_key: Optional[str] = None

     class Config:
         """Configuration for this pydantic object."""
@@ -35,14 +35,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
-        openai_api_key = values.get("openai_api_key")
-        if openai_api_key is None or openai_api_key == "":
-            raise ValueError(
-                "Did not find OpenAI API key, please add an environment variable"
-                " `OPENAI_API_KEY` which contains it, or pass `openai_api_key` as a"
-                " named parameter."
-            )
+        openai_api_key = get_from_dict_or_env(
+            values, "openai_api_key", "OPENAI_API_KEY"
+        )
         try:
             import openai
@@ -1,14 +1,19 @@
 """Handle chained inputs."""
 from typing import Dict, List, Optional

-_COLOR_MAPPING = {"blue": 104, "yellow": 103, "red": 101, "green": 102}
+_TEXT_COLOR_MAPPING = {
+    "blue": "36;1",
+    "yellow": "33;1",
+    "pink": "38;5;200",
+    "green": "32;1",
+}


 def get_color_mapping(
     items: List[str], excluded_colors: Optional[List] = None
 ) -> Dict[str, str]:
     """Get mapping for items to a support color."""
-    colors = list(_COLOR_MAPPING.keys())
+    colors = list(_TEXT_COLOR_MAPPING.keys())
     if excluded_colors is not None:
         colors = [c for c in colors if c not in excluded_colors]
     color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)}
@@ -20,8 +25,8 @@ def print_text(text: str, color: Optional[str] = None, end: str = "") -> None:
     if color is None:
         print(text, end=end)
     else:
-        color_str = _COLOR_MAPPING[color]
-        print(f"\x1b[{color_str}m{text}\x1b[0m", end=end)
+        color_str = _TEXT_COLOR_MAPPING[color]
+        print(f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m", end=end)


 class ChainedInput:
@@ -29,14 +34,14 @@ class ChainedInput:

     def __init__(self, text: str, verbose: bool = False):
         """Initialize with verbose flag and initial text."""
-        self.verbose = verbose
-        if self.verbose:
+        self._verbose = verbose
+        if self._verbose:
             print_text(text, None)
         self._input = text

     def add(self, text: str, color: Optional[str] = None) -> None:
         """Add text to input, print if in verbose mode."""
-        if self.verbose:
+        if self._verbose:
             print_text(text, color)
         self._input += text
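The mapping switches from ANSI background codes (104, 103, ...) to foreground text codes ("36;1", "33;1", ...), and `print_text` now also applies the bold/italic prefix `1;3`. A small sketch of how the two helpers compose:

```python
from langchain.input import get_color_mapping, print_text

# Map two items onto the available color names, skipping "blue".
mapping = get_color_mapping(["0", "1"], excluded_colors=["blue"])
# mapping == {"0": "yellow", "1": "pink"} with the new palette.

# print_text looks the name up in _TEXT_COLOR_MAPPING and wraps the text
# in ANSI escapes, e.g. "\x1b[33;1m\x1b[1;3mhello\x1b[0m" for yellow.
print_text("hello", color=mapping["0"], end="\n")
```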
@@ -1,11 +1,11 @@
 """Wrapper around AI21 APIs."""
-import os
 from typing import Any, Dict, List, Mapping, Optional

 import requests
 from pydantic import BaseModel, Extra, root_validator

 from langchain.llms.base import LLM
+from langchain.utils import get_from_dict_or_env


 class AI21PenaltyData(BaseModel):
@@ -62,7 +62,7 @@ class AI21(BaseModel, LLM):
     logitBias: Optional[Dict[str, float]] = None
     """Adjust the probability of specific tokens being generated."""

-    ai21_api_key: Optional[str] = os.environ.get("AI21_API_KEY")
+    ai21_api_key: Optional[str] = None

     class Config:
         """Configuration for this pydantic object."""
@@ -72,14 +72,8 @@ class AI21(BaseModel, LLM):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key exists in environment."""
-        ai21_api_key = values.get("ai21_api_key")
-        if ai21_api_key is None or ai21_api_key == "":
-            raise ValueError(
-                "Did not find AI21 API key, please add an environment variable"
-                " `AI21_API_KEY` which contains it, or pass `ai21_api_key`"
-                " as a named parameter."
-            )
+        ai21_api_key = get_from_dict_or_env(values, "ai21_api_key", "AI21_API_KEY")
+        values["ai21_api_key"] = ai21_api_key
         return values

     @property
@@ -122,11 +116,7 @@ class AI21(BaseModel, LLM):
         response = requests.post(
             url=f"https://api.ai21.com/studio/v1/{self.model}/complete",
             headers={"Authorization": f"Bearer {self.ai21_api_key}"},
-            json={
-                "prompt": prompt,
-                "stopSequences": stop,
-                **self._default_params,
-            },
+            json={"prompt": prompt, "stopSequences": stop, **self._default_params},
         )
         if response.status_code != 200:
             optional_detail = response.json().get("error")
@@ -1,11 +1,11 @@
 """Wrapper around Cohere APIs."""
-import os
 from typing import Any, Dict, List, Mapping, Optional

 from pydantic import BaseModel, Extra, root_validator

 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
+from langchain.utils import get_from_dict_or_env


 class Cohere(LLM, BaseModel):
@@ -44,7 +44,7 @@ class Cohere(LLM, BaseModel):
     presence_penalty: int = 0
     """Penalizes repeated tokens."""

-    cohere_api_key: Optional[str] = os.environ.get("COHERE_API_KEY")
+    cohere_api_key: Optional[str] = None

     class Config:
         """Configuration for this pydantic object."""
@@ -54,14 +54,9 @@ class Cohere(LLM, BaseModel):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
-        cohere_api_key = values.get("cohere_api_key")
-        if cohere_api_key is None or cohere_api_key == "":
-            raise ValueError(
-                "Did not find Cohere API key, please add an environment variable"
-                " `COHERE_API_KEY` which contains it, or pass `cohere_api_key`"
-                " as a named parameter."
-            )
+        cohere_api_key = get_from_dict_or_env(
+            values, "cohere_api_key", "COHERE_API_KEY"
+        )
         try:
             import cohere
@@ -1,11 +1,11 @@
 """Wrapper around HuggingFace APIs."""
-import os
 from typing import Any, Dict, List, Mapping, Optional

 from pydantic import BaseModel, Extra, root_validator

 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
+from langchain.utils import get_from_dict_or_env

 DEFAULT_REPO_ID = "gpt2"
 VALID_TASKS = ("text2text-generation", "text-generation")
@@ -18,7 +18,7 @@ class HuggingFaceHub(LLM, BaseModel):
     environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
     it as a named parameter to the constructor.

-    Only supports task `text-generation` for now.
+    Only supports `text-generation` and `text2text-generation` for now.

     Example:
         .. code-block:: python
@@ -35,7 +35,7 @@ class HuggingFaceHub(LLM, BaseModel):
     model_kwargs: Optional[dict] = None
     """Key word arguments to pass to the model."""

-    huggingfacehub_api_token: Optional[str] = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
+    huggingfacehub_api_token: Optional[str] = None

     class Config:
         """Configuration for this pydantic object."""
@@ -45,13 +45,9 @@ class HuggingFaceHub(LLM, BaseModel):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
-        huggingfacehub_api_token = values.get("huggingfacehub_api_token")
-        if huggingfacehub_api_token is None or huggingfacehub_api_token == "":
-            raise ValueError(
-                "Did not find HuggingFace API token, please add an environment variable"
-                " `HUGGINGFACEHUB_API_TOKEN` which contains it, or pass"
-                " `huggingfacehub_api_token` as a named parameter."
-            )
+        huggingfacehub_api_token = get_from_dict_or_env(
+            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
+        )
         try:
             from huggingface_hub.inference_api import InferenceApi
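The docstring change reflects that `text2text-generation` models (like the flan-t5-xl used in the notebook above) now work alongside `text-generation`. A hedged sketch, assuming a valid `HUGGINGFACEHUB_API_TOKEN` in the environment and the module path implied by the hunk:

```python
from langchain.llms.huggingface_hub import HuggingFaceHub  # module path assumed

# google/flan-t5-xl is a text2text-generation model; extra generation
# parameters go through the `model_kwargs` field shown in the diff.
llm = HuggingFaceHub(repo_id="google/flan-t5-xl", model_kwargs={"temperature": 1})
print(llm("What color are flamingos?"))
```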
@@ -1,10 +1,10 @@
 """Wrapper around NLPCloud APIs."""
-import os
 from typing import Any, Dict, List, Mapping, Optional

 from pydantic import BaseModel, Extra, root_validator

 from langchain.llms.base import LLM
+from langchain.utils import get_from_dict_or_env


 class NLPCloud(LLM, BaseModel):
@@ -54,7 +54,7 @@ class NLPCloud(LLM, BaseModel):
     num_return_sequences: int = 1
     """How many completions to generate for each prompt."""

-    nlpcloud_api_key: Optional[str] = os.environ.get("NLPCLOUD_API_KEY")
+    nlpcloud_api_key: Optional[str] = None

     class Config:
         """Configuration for this pydantic object."""
@@ -64,14 +64,9 @@ class NLPCloud(LLM, BaseModel):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
-        nlpcloud_api_key = values.get("nlpcloud_api_key")
-        if nlpcloud_api_key is None or nlpcloud_api_key == "":
-            raise ValueError(
-                "Did not find NLPCloud API key, please add an environment variable"
-                " `NLPCLOUD_API_KEY` which contains it, or pass `nlpcloud_api_key`"
-                " as a named parameter."
-            )
+        nlpcloud_api_key = get_from_dict_or_env(
+            values, "nlpcloud_api_key", "NLPCLOUD_API_KEY"
+        )
         try:
             import nlpcloud
@@ -1,10 +1,10 @@
 """Wrapper around OpenAI APIs."""
-import os
 from typing import Any, Dict, List, Mapping, Optional

 from pydantic import BaseModel, Extra, root_validator

 from langchain.llms.base import LLM
+from langchain.utils import get_from_dict_or_env


 class OpenAI(LLM, BaseModel):
@@ -38,7 +38,7 @@ class OpenAI(LLM, BaseModel):
     best_of: int = 1
     """Generates best_of completions server-side and returns the "best"."""

-    openai_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
+    openai_api_key: Optional[str] = None

     class Config:
         """Configuration for this pydantic object."""
@@ -48,14 +48,9 @@ class OpenAI(LLM, BaseModel):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
-        openai_api_key = values.get("openai_api_key")
-        if openai_api_key is None or openai_api_key == "":
-            raise ValueError(
-                "Did not find OpenAI API key, please add an environment variable"
-                " `OPENAI_API_KEY` which contains it, or pass `openai_api_key`"
-                " as a named parameter."
-            )
+        openai_api_key = get_from_dict_or_env(
+            values, "openai_api_key", "OPENAI_API_KEY"
+        )
         try:
             import openai
@@ -1,6 +1,7 @@
 """Experiment with different models."""
-from typing import List, Optional
+from typing import List, Optional, Sequence

+from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.input import get_color_mapping, print_text
 from langchain.llms.base import LLM
@@ -10,7 +11,41 @@ from langchain.prompts.prompt import Prompt
 class ModelLaboratory:
     """Experiment with different models."""

-    def __init__(self, llms: List[LLM], prompt: Optional[Prompt] = None):
+    def __init__(self, chains: Sequence[Chain], names: Optional[List[str]] = None):
+        """Initialize with chains to experiment with.
+
+        Args:
+            chains: list of chains to experiment with.
+        """
+        if not isinstance(chains[0], Chain):
+            raise ValueError(
+                "ModelLaboratory should now be initialized with Chains. "
+                "If you want to initialize with LLMs, use the `from_llms` method "
+                "instead (`ModelLaboratory.from_llms(...)`)"
+            )
+        for chain in chains:
+            if len(chain.input_keys) != 1:
+                raise ValueError(
+                    "Currently only support chains with one input variable, "
+                    f"got {chain.input_keys}"
+                )
+            if len(chain.output_keys) != 1:
+                raise ValueError(
+                    "Currently only support chains with one output variable, "
+                    f"got {chain.output_keys}"
+                )
+        if names is not None:
+            if len(names) != len(chains):
+                raise ValueError("Length of chains does not match length of names.")
+        self.chains = chains
+        chain_range = [str(i) for i in range(len(self.chains))]
+        self.chain_colors = get_color_mapping(chain_range)
+        self.names = names
+
+    @classmethod
+    def from_llms(
+        cls, llms: List[LLM], prompt: Optional[Prompt] = None
+    ) -> "ModelLaboratory":
         """Initialize with LLMs to experiment with and optional prompt.

         Args:
@@ -18,18 +53,11 @@ class ModelLaboratory:
             prompt: Optional prompt to use to prompt the LLMs. Defaults to None.
                 If a prompt was provided, it should only have one input variable.
         """
-        self.llms = llms
-        llm_range = [str(i) for i in range(len(self.llms))]
-        self.llm_colors = get_color_mapping(llm_range)
         if prompt is None:
-            self.prompt = Prompt(input_variables=["_input"], template="{_input}")
-        else:
-            if len(prompt.input_variables) != 1:
-                raise ValueError(
-                    "Currently only support prompts with one input variable, "
-                    f"got {prompt}"
-                )
-            self.prompt = prompt
+            prompt = Prompt(input_variables=["_input"], template="{_input}")
+        chains = [LLMChain(llm=llm, prompt=prompt) for llm in llms]
+        names = [str(llm) for llm in llms]
+        return cls(chains, names=names)

     def compare(self, text: str) -> None:
         """Compare model outputs on an input text.
@@ -42,9 +70,11 @@ class ModelLaboratory:
             text: input text to run all models on.
         """
         print(f"\033[1mInput:\033[0m\n{text}\n")
-        for i, llm in enumerate(self.llms):
-            print_text(str(llm), end="\n")
-            chain = LLMChain(llm=llm, prompt=self.prompt)
-            llm_inputs = {self.prompt.input_variables[0]: text}
-            output = chain.predict(**llm_inputs)
-            print_text(output, color=self.llm_colors[str(i)], end="\n\n")
+        for i, chain in enumerate(self.chains):
+            if self.names is not None:
+                name = self.names[i]
+            else:
+                name = str(chain)
+            print_text(name, end="\n")
+            output = chain.run(text)
+            print_text(output, color=self.chain_colors[str(i)], end="\n\n")
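`ModelLaboratory` now compares arbitrary single-input/single-output chains, and the old LLM-based constructor moves to a classmethod. A sketch of both paths (imports assumed to match the notebook above; credentials come from the environment):

```python
from langchain import OpenAI, Cohere
from langchain.model_laboratory import ModelLaboratory  # module path assumed

llms = [OpenAI(temperature=0), Cohere(temperature=0)]

# Old code did ModelLaboratory(llms); passing raw LLMs now raises a
# ValueError pointing at the classmethod instead.
model_lab = ModelLaboratory.from_llms(llms)
model_lab.compare("What color are flamingos?")

# Any chains with one input and one output key also work directly,
# as the notebook does with two SelfAskWithSearchChain instances:
# ModelLaboratory(chains, names=names)
```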
@@ -94,8 +94,7 @@ class Prompt(BaseModel, BasePrompt):
         Returns:
             The final prompt generated.
         """
-        example_str = example_separator.join(examples)
-        template = prefix + example_str + suffix
+        template = example_separator.join([prefix, *examples, suffix])
         return cls(input_variables=input_variables, template=template)

     @classmethod
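This is a small semantic change, not just a reformat: the separator is now inserted between the prefix, every example, and the suffix, rather than only between examples. That is why the test prompt further down drops the explicit `\n\n` from its prefix and suffix. A worked illustration:

```python
# Before: template = prefix + example_separator.join(examples) + suffix
# After:  template = example_separator.join([prefix, *examples, suffix])
prefix = "Test Prompt:"
suffix = "Question: {question}\nAnswer:"
examples = ["Question: who are you?\nAnswer: foo"]

template = "\n\n".join([prefix, *examples, suffix])
# -> "Test Prompt:\n\nQuestion: who are you?\nAnswer: foo\n\nQuestion: {question}\nAnswer:"
```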
@ -1,4 +1,6 @@
|
|||||||
"""SQLAlchemy wrapper around a database."""
|
"""SQLAlchemy wrapper around a database."""
|
||||||
|
from typing import Any, Iterable, List, Optional
|
||||||
|
|
||||||
from sqlalchemy import create_engine, inspect
|
from sqlalchemy import create_engine, inspect
|
||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
|
|
||||||
@ -6,29 +8,57 @@ from sqlalchemy.engine import Engine
|
|||||||
class SQLDatabase:
|
class SQLDatabase:
|
||||||
"""SQLAlchemy wrapper around a database."""
|
"""SQLAlchemy wrapper around a database."""
|
||||||
|
|
||||||
def __init__(self, engine: Engine):
|
def __init__(
|
||||||
|
self,
|
||||||
|
engine: Engine,
|
||||||
|
ignore_tables: Optional[List[str]] = None,
|
||||||
|
include_tables: Optional[List[str]] = None,
|
||||||
|
):
|
||||||
"""Create engine from database URI."""
|
"""Create engine from database URI."""
|
||||||
self._engine = engine
|
self._engine = engine
|
||||||
|
if include_tables and ignore_tables:
|
||||||
|
raise ValueError("Cannot specify both include_tables and ignore_tables")
|
||||||
|
|
||||||
|
self._inspector = inspect(self._engine)
|
||||||
|
self._all_tables = self._inspector.get_table_names()
|
||||||
|
self._include_tables = include_tables or []
|
||||||
|
if self._include_tables:
|
||||||
|
missing_tables = set(self._include_tables).difference(self._all_tables)
|
||||||
|
if missing_tables:
|
||||||
|
raise ValueError(
|
||||||
|
f"include_tables {missing_tables} not found in database"
|
||||||
|
)
|
||||||
|
self._ignore_tables = ignore_tables or []
|
||||||
|
if self._ignore_tables:
|
||||||
|
missing_tables = set(self._ignore_tables).difference(self._all_tables)
|
||||||
|
if missing_tables:
|
||||||
|
raise ValueError(
|
||||||
|
f"ignore_tables {missing_tables} not found in database"
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_uri(cls, database_uri: str) -> "SQLDatabase":
|
def from_uri(cls, database_uri: str, **kwargs: Any) -> "SQLDatabase":
|
||||||
"""Construct a SQLAlchemy engine from URI."""
|
"""Construct a SQLAlchemy engine from URI."""
|
||||||
return cls(create_engine(database_uri))
|
return cls(create_engine(database_uri), **kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dialect(self) -> str:
|
def dialect(self) -> str:
|
||||||
"""Return string representation of dialect to use."""
|
"""Return string representation of dialect to use."""
|
||||||
return self._engine.dialect.name
|
return self._engine.dialect.name
|
||||||
|
|
||||||
|
def _get_table_names(self) -> Iterable[str]:
|
||||||
|
if self._include_tables:
|
||||||
|
return self._include_tables
|
||||||
|
return set(self._all_tables) - set(self._ignore_tables)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def table_info(self) -> str:
|
def table_info(self) -> str:
|
||||||
"""Information about all tables in the database."""
|
"""Information about all tables in the database."""
|
||||||
template = "The '{table_name}' table has columns: {columns}."
|
template = "Table '{table_name}' has columns: {columns}."
|
||||||
tables = []
|
tables = []
|
||||||
inspector = inspect(self._engine)
|
for table_name in self._get_table_names():
|
||||||
for table_name in inspector.get_table_names():
|
|
||||||
columns = []
|
columns = []
|
||||||
for column in inspector.get_columns(table_name):
|
for column in self._inspector.get_columns(table_name):
|
||||||
columns.append(f"{column['name']} ({str(column['type'])})")
|
columns.append(f"{column['name']} ({str(column['type'])})")
|
||||||
column_str = ", ".join(columns)
|
column_str = ", ".join(columns)
|
||||||
table_str = template.format(table_name=table_name, columns=column_str)
|
table_str = template.format(table_name=table_name, columns=column_str)
|
||||||
|
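The wrapper can now restrict which tables are surfaced in `table_info`, with `include_tables` and `ignore_tables` being mutually exclusive and validated against the actual schema. A sketch, assuming a local SQLite file with `company` and `user` tables as in the tests below:

```python
from langchain.sql_database import SQLDatabase  # module path assumed

# Only the `company` table will appear in table_info.
db = SQLDatabase.from_uri("sqlite:///example.db", include_tables=["company"])
print(db.table_info)  # "Table 'company' has columns: ..."

# Passing both filters raises:
# SQLDatabase.from_uri(uri, include_tables=["a"], ignore_tables=["b"])  # ValueError
```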
langchain/utils.py (new file, 17 lines)
@@ -0,0 +1,17 @@
+"""Generic utility functions."""
+import os
+from typing import Any, Dict
+
+
+def get_from_dict_or_env(data: Dict[str, Any], key: str, env_key: str) -> str:
+    """Get a value from a dictionary or an environment variable."""
+    if key in data and data[key]:
+        return data[key]
+    elif env_key in os.environ and os.environ[env_key]:
+        return os.environ[env_key]
+    else:
+        raise ValueError(
+            f"Did not find {key}, please add an environment variable"
+            f" `{env_key}` which contains it, or pass"
+            f" `{key}` as a named parameter."
+        )
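This new helper is what all the validators above now call: explicit (truthy) dict values win, the environment is the fallback, and a single consistent error message covers the miss. A quick illustration with a fake value:

```python
import os
from langchain.utils import get_from_dict_or_env

values = {"cohere_api_key": None}
os.environ["COHERE_API_KEY"] = "fake-key-for-illustration"

# The dict value is falsy (None), so the lookup falls through to the env var.
key = get_from_dict_or_env(values, "cohere_api_key", "COHERE_API_KEY")
assert key == "fake-key-for-illustration"
```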
@@ -1,10 +1,10 @@
 """Wrapper around Elasticsearch vector database."""
-import os
 import uuid
 from typing import Any, Callable, Dict, List

 from langchain.docstore.document import Document
 from langchain.embeddings.base import Embeddings
+from langchain.utils import get_from_dict_or_env
 from langchain.vectorstores.base import VectorStore


@@ -45,10 +45,7 @@ class ElasticVectorSearch(VectorStore):
     """

     def __init__(
-        self,
-        elasticsearch_url: str,
-        index_name: str,
-        embedding_function: Callable,
+        self, elasticsearch_url: str, index_name: str, embedding_function: Callable
     ):
         """Initialize with necessary components."""
         try:
@@ -110,16 +107,9 @@ class ElasticVectorSearch(VectorStore):
                 elasticsearch_url="http://localhost:9200"
             )
         """
-        elasticsearch_url = kwargs.get("elasticsearch_url")
-        if not elasticsearch_url:
-            elasticsearch_url = os.environ.get("ELASTICSEARCH_URL")
-
-        if elasticsearch_url is None or elasticsearch_url == "":
-            raise ValueError(
-                "Did not find Elasticsearch URL, please add an environment variable"
-                " `ELASTICSEARCH_URL` which contains it, or pass"
-                " `elasticsearch_url` as a named parameter."
-            )
+        elasticsearch_url = get_from_dict_or_env(
+            kwargs, "elasticsearch_url", "ELASTICSEARCH_URL"
+        )
         try:
             import elasticsearch
             from elasticsearch.helpers import bulk
@@ -6,9 +6,7 @@ def test_manifest_wrapper() -> None:
     """Test manifest wrapper."""
    from manifest import Manifest

-    manifest = Manifest(
-        client_name="openai",
-    )
+    manifest = Manifest(client_name="openai")
     llm = ManifestWrapper(client=manifest, llm_kwargs={"temperature": 0})
     output = llm("The capital of New York is:")
     assert output == "Albany"
@@ -48,7 +48,7 @@ def test_chained_input_verbose() -> None:
     chained_input.add("baz", color="blue")
     sys.stdout = old_stdout
     output = mystdout.getvalue()
-    assert output == "\x1b[104mbaz\x1b[0m"
+    assert output == "\x1b[36;1m\x1b[1;3mbaz\x1b[0m"
     assert chained_input.input == "foobarbaz"


@@ -70,5 +70,5 @@ def test_get_color_mapping_excluded_colors() -> None:
     """Test getting of color mapping with excluded colors."""
     items = ["foo", "bar"]
     output = get_color_mapping(items, excluded_colors=["blue"])
-    expected_output = {"foo": "yellow", "bar": "red"}
+    expected_output = {"foo": "yellow", "bar": "pink"}
     assert output == expected_output
@@ -51,8 +51,8 @@ Question: {question}
 Answer:"""
     input_variables = ["question"]
     example_separator = "\n\n"
-    prefix = """Test Prompt:\n\n"""
-    suffix = """\n\nQuestion: {question}\nAnswer:"""
+    prefix = """Test Prompt:"""
+    suffix = """Question: {question}\nAnswer:"""
     examples = [
         """Question: who are you?\nAnswer: foo""",
         """Question: what are you?\nAnswer: bar""",
@@ -28,11 +28,11 @@ def test_table_info() -> None:
     db = SQLDatabase(engine)
     output = db.table_info
     expected_output = (
-        "The 'company' table has columns: company_id (INTEGER), "
-        "company_location (VARCHAR).\n"
-        "The 'user' table has columns: user_id (INTEGER), user_name (VARCHAR(16))."
+        "Table 'company' has columns: company_id (INTEGER), "
+        "company_location (VARCHAR).",
+        "Table 'user' has columns: user_id (INTEGER), user_name (VARCHAR(16)).",
     )
-    assert output == expected_output
+    assert sorted(output.split("\n")) == sorted(expected_output)


 def test_sql_database_run() -> None: