Add add_example method to all ExampleSelector classes, with tests (#178)
Also updated the docs, and noticed (and fixed) an issue with the add_texts method on VectorStores that I had missed before: the metadatas arg is needed there to match the classmethod that initializes the VectorStores (otherwise the add_example methods in the ExampleSelectors break).
parent 780ef84cf0
commit 09f301cd38
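In short: every ExampleSelector now exposes an add_example method, and VectorStore.add_texts accepts an optional metadatas argument so the SemanticSimilarityExampleSelector can store the full example alongside its embedded text. A minimal sketch of the changed interfaces, assembled from the hunks below rather than the full files:

from abc import ABC, abstractmethod
from typing import Dict, Iterable, List, Optional


class BaseExampleSelector(ABC):
    """Interface for selecting examples to include in prompts."""

    @abstractmethod
    def add_example(self, example: Dict[str, str]) -> None:
        """Add new example to store for a key."""

    @abstractmethod
    def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
        """Select which examples to use based on the inputs."""


class VectorStore(ABC):
    """Interface for vector stores."""

    @abstractmethod
    def add_texts(
        self, texts: Iterable[str], metadatas: Optional[List[dict]] = None
    ) -> None:
        """Run more texts through the embeddings and add to the vectorstore."""

Concrete usage, taken from the updated notebook: dynamic_prompt.example_selector.add_example({"input": "big", "output": "small"}) appends an example to the length-based selector used there, and the same call on a SemanticSimilarityExampleSelector embeds the example and stores it in the underlying vectorstore via add_texts.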
@@ -310,7 +310,7 @@
" example_prompt=example_prompt, \n",
" # This is the maximum length that the formatted examples should be.\n",
" # Length is measured by the get_text_length function below.\n",
" max_length=18,\n",
" max_length=25,\n",
" # This is the function used to get the length of a string, which is used\n",
" # to determine which examples to include. It is commented out because\n",
" # it is provided as a default value if none is specified.\n",
@@ -378,17 +378,59 @@
"Input: happy\n",
"Output: sad\n",
"\n",
"Input: big and huge and massive and large and gigantic and tall and bigger than everything else\n",
"Input: big and huge and massive and large and gigantic and tall and much much much much much bigger than everything else\n",
"Output:\n"
]
}
],
"source": [
"# An example with long input, so it selects only one example.\n",
"long_string = \"big and huge and massive and large and gigantic and tall and bigger than everything else\"\n",
"long_string = \"big and huge and massive and large and gigantic and tall and much much much much much bigger than everything else\"\n",
"print(dynamic_prompt.format(adjective=long_string))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "e4bebcd9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Give the antonym of every input\n",
"\n",
"Input: happy\n",
"Output: sad\n",
"\n",
"Input: tall\n",
"Output: short\n",
"\n",
"Input: energetic\n",
"Output: lethargic\n",
"\n",
"Input: sunny\n",
"Output: gloomy\n",
"\n",
"Input: windy\n",
"Output: calm\n",
"\n",
"Input: big\n",
"Output: small\n",
"\n",
"Input: enthusiastic\n",
"Output:\n"
]
}
],
"source": [
"# You can add an example to an example selector as well.\n",
"new_example = {\"input\": \"big\", \"output\": \"small\"}\n",
"dynamic_prompt.example_selector.add_example(new_example)\n",
"print(dynamic_prompt.format(adjective=\"enthusiastic\"))"
]
},
{
"cell_type": "markdown",
"id": "2d007b0a",
@@ -401,7 +443,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 14,
"id": "241bfe80",
"metadata": {},
"outputs": [],
@@ -413,7 +455,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 15,
"id": "50d0a701",
"metadata": {},
"outputs": [],
@@ -440,7 +482,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 16,
"id": "4c8fdf45",
"metadata": {},
"outputs": [
@@ -465,9 +507,11 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 17,
"id": "829af21a",
"metadata": {},
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
@@ -484,10 +528,36 @@
}
],
"source": [
"# Input is a measurment, so should select the tall/short example\n",
"# Input is a measurement, so should select the tall/short example\n",
"print(similar_prompt.format(adjective=\"fat\"))"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "3c16fe23",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Give the antonym of every input\n",
"\n",
"Input: enthusiastic\n",
"Output: apathetic\n",
"\n",
"Input: joyful\n",
"Output:\n"
]
}
],
"source": [
"# You can add new examples to the SemanticSimilarityExampleSelector as well\n",
"similar_prompt.example_selector.add_example({\"input\": \"enthusiastic\", \"output\": \"apathetic\"})\n",
"print(similar_prompt.format(adjective=\"joyful\"))"
]
},
{
"cell_type": "markdown",
"id": "dbc32551",
@@ -532,7 +602,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
"version": "3.10.4"
}
},
"nbformat": 4,

@@ -6,6 +6,10 @@ from typing import Dict, List
class BaseExampleSelector(ABC):
    """Interface for selecting examples to include in prompts."""

    @abstractmethod
    def add_example(self, example: Dict[str, str]) -> None:
        """Add new example to store for a key."""

    @abstractmethod
    def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
        """Select which examples to use based on the inputs."""

@@ -25,6 +25,12 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):

    example_text_lengths: List[int] = []  #: :meta private:

    def add_example(self, example: Dict[str, str]) -> None:
        """Add new example to list."""
        self.examples.append(example)
        string_example = self.example_prompt.format(**example)
        self.example_text_lengths.append(self.get_text_length(string_example))

    @validator("example_text_lengths", always=True)
    def calculate_example_text_lengths(cls, v: List[int], values: Dict) -> List[int]:
        """Calculate text lengths if they don't exist."""

@@ -24,6 +24,11 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
        extra = Extra.forbid
        arbitrary_types_allowed = True

    def add_example(self, example: Dict[str, str]) -> None:
        """Add new example to vectorstore."""
        string_example = " ".join(example.values())
        self.vectorstore.add_texts([string_example], metadatas=[example])

    def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
        """Select which examples to use based on semantic similarity."""
        # Get the docs with the highest similarity.

@@ -10,7 +10,9 @@ class VectorStore(ABC):
    """Interface for vector stores."""

    @abstractmethod
    def add_texts(self, texts: Iterable[str]) -> None:
    def add_texts(
        self, texts: Iterable[str], metadatas: Optional[List[dict]] = None
    ) -> None:
        """Run more texts through the embeddings and add to the vectorstore."""

    @abstractmethod

@@ -65,7 +65,9 @@ class ElasticVectorSearch(VectorStore):
            )
        self.client = es_client

    def add_texts(self, texts: Iterable[str]) -> None:
    def add_texts(
        self, texts: Iterable[str], metadatas: Optional[List[dict]] = None
    ) -> None:
        """Run more texts through the embeddings and add to the vectorstore."""
        try:
            from elasticsearch.helpers import bulk
@@ -76,11 +78,13 @@ class ElasticVectorSearch(VectorStore):
            )
        requests = []
        for i, text in enumerate(texts):
            metadata = metadatas[i] if metadatas else {}
            request = {
                "_op_type": "index",
                "_index": self.index_name,
                "vector": self.embedding_function(text),
                "text": text,
                "metadata": metadata,
            }
            requests.append(request)
        bulk(self.client, requests)

@@ -37,7 +37,9 @@ class FAISS(VectorStore):
        self.docstore = docstore
        self.index_to_docstore_id = index_to_docstore_id

    def add_texts(self, texts: Iterable[str]) -> None:
    def add_texts(
        self, texts: Iterable[str], metadatas: Optional[List[dict]] = None
    ) -> None:
        """Run more texts through the embeddings and add to the vectorstore."""
        if not isinstance(self.docstore, AddableMixin):
            raise ValueError(
@@ -46,7 +48,10 @@ class FAISS(VectorStore):
            )
        # Embed and create the documents.
        embeddings = [self.embedding_function(text) for text in texts]
        documents = [Document(page_content=text) for text in texts]
        documents = []
        for i, text in enumerate(texts):
            metadata = metadatas[i] if metadatas else {}
            documents.append(Document(page_content=text, metadata=metadata))
        # Add to the index, the index_to_id mapping, and the docstore.
        starting_len = len(self.index_to_docstore_id)
        self.index.add(np.array(embeddings, dtype=np.float32))

@@ -29,6 +29,15 @@ def test_dynamic_prompt_valid(selector: LengthBasedExampleSelector) -> None:
    assert output == EXAMPLES


def test_dynamic_prompt_add_example(selector: LengthBasedExampleSelector) -> None:
    """Test dynamic prompt can add an example."""
    new_example = {"question": "Question: what are you?\nAnswer: bar"}
    selector.add_example(new_example)
    short_question = "Short question?"
    output = selector.select_examples({"question": short_question})
    assert output == EXAMPLES + [new_example]


def test_dynamic_prompt_trims_one_example(selector: LengthBasedExampleSelector) -> None:
    """Test dynamic prompt can trim one example."""
    long_question = """I am writing a really long question,