mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 23:29:21 +00:00
Harrison/semantic subset (#1079)
Co-authored-by: Chen Wu (吴尘) <henrychenwu@cmu.edu>
This commit is contained in:
parent
19c2797bed
commit
c96ac3e591
@ -24,6 +24,9 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
|
||||
"""Number of examples to select."""
|
||||
example_keys: Optional[List[str]] = None
|
||||
"""Optional keys to filter examples to."""
|
||||
input_keys: Optional[List[str]] = None
|
||||
"""Optional keys to filter input to. If provided, the search is based on
|
||||
the input variables instead of all variables."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@ -33,13 +36,20 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
|
||||
|
||||
def add_example(self, example: Dict[str, str]) -> str:
|
||||
"""Add new example to vectorstore."""
|
||||
string_example = " ".join(sorted_values(example))
|
||||
if self.input_keys:
|
||||
string_example = " ".join(
|
||||
sorted_values({key: example[key] for key in self.input_keys})
|
||||
)
|
||||
else:
|
||||
string_example = " ".join(sorted_values(example))
|
||||
ids = self.vectorstore.add_texts([string_example], metadatas=[example])
|
||||
return ids[0]
|
||||
|
||||
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
|
||||
"""Select which examples to use based on semantic similarity."""
|
||||
# Get the docs with the highest similarity.
|
||||
if self.input_keys:
|
||||
input_variables = {key: input_variables[key] for key in self.input_keys}
|
||||
query = " ".join(sorted_values(input_variables))
|
||||
example_docs = self.vectorstore.similarity_search(query, k=self.k)
|
||||
# Get the examples from the metadata.
|
||||
@ -57,6 +67,7 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
|
||||
embeddings: Embeddings,
|
||||
vectorstore_cls: VectorStore,
|
||||
k: int = 4,
|
||||
input_keys: Optional[List[str]] = None,
|
||||
**vectorstore_cls_kwargs: Any,
|
||||
) -> SemanticSimilarityExampleSelector:
|
||||
"""Create k-shot example selector using example list and embeddings.
|
||||
@ -68,16 +79,24 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
|
||||
embeddings: An iniialized embedding API interface, e.g. OpenAIEmbeddings().
|
||||
vectorstore_cls: A vector store DB interface class, e.g. FAISS.
|
||||
k: Number of examples to select
|
||||
input_keys: If provided, the search is based on the input variables
|
||||
instead of all variables.
|
||||
vectorstore_cls_kwargs: optional kwargs containing url for vector store
|
||||
|
||||
Returns:
|
||||
The ExampleSelector instantiated, backed by a vector store.
|
||||
"""
|
||||
string_examples = [" ".join(sorted_values(eg)) for eg in examples]
|
||||
if input_keys:
|
||||
string_examples = [
|
||||
" ".join(sorted_values({k: eg[k] for k in input_keys}))
|
||||
for eg in examples
|
||||
]
|
||||
else:
|
||||
string_examples = [" ".join(sorted_values(eg)) for eg in examples]
|
||||
vectorstore = vectorstore_cls.from_texts(
|
||||
string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs
|
||||
)
|
||||
return cls(vectorstore=vectorstore, k=k)
|
||||
return cls(vectorstore=vectorstore, k=k, input_keys=input_keys)
|
||||
|
||||
|
||||
class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector, BaseModel):
|
||||
@ -93,6 +112,8 @@ class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector, Bas
|
||||
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
|
||||
"""Select which examples to use based on semantic similarity."""
|
||||
# Get the docs with the highest similarity.
|
||||
if self.input_keys:
|
||||
input_variables = {key: input_variables[key] for key in self.input_keys}
|
||||
query = " ".join(sorted_values(input_variables))
|
||||
example_docs = self.vectorstore.max_marginal_relevance_search(
|
||||
query, k=self.k, fetch_k=self.fetch_k
|
||||
@ -112,6 +133,7 @@ class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector, Bas
|
||||
embeddings: Embeddings,
|
||||
vectorstore_cls: VectorStore,
|
||||
k: int = 4,
|
||||
input_keys: Optional[List[str]] = None,
|
||||
fetch_k: int = 20,
|
||||
**vectorstore_cls_kwargs: Any,
|
||||
) -> MaxMarginalRelevanceExampleSelector:
|
||||
@ -124,13 +146,21 @@ class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector, Bas
|
||||
embeddings: An iniialized embedding API interface, e.g. OpenAIEmbeddings().
|
||||
vectorstore_cls: A vector store DB interface class, e.g. FAISS.
|
||||
k: Number of examples to select
|
||||
input_keys: If provided, the search is based on the input variables
|
||||
instead of all variables.
|
||||
vectorstore_cls_kwargs: optional kwargs containing url for vector store
|
||||
|
||||
Returns:
|
||||
The ExampleSelector instantiated, backed by a vector store.
|
||||
"""
|
||||
string_examples = [" ".join(sorted_values(eg)) for eg in examples]
|
||||
if input_keys:
|
||||
string_examples = [
|
||||
" ".join(sorted_values({k: eg[k] for k in input_keys}))
|
||||
for eg in examples
|
||||
]
|
||||
else:
|
||||
string_examples = [" ".join(sorted_values(eg)) for eg in examples]
|
||||
vectorstore = vectorstore_cls.from_texts(
|
||||
string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs
|
||||
)
|
||||
return cls(vectorstore=vectorstore, k=k, fetch_k=fetch_k)
|
||||
return cls(vectorstore=vectorstore, k=k, fetch_k=fetch_k, input_keys=input_keys)
|
||||
|
Loading…
Reference in New Issue
Block a user