mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-03 05:34:01 +00:00
Merge pull request #11 from VowpalWabbit/add_notebook
add random policy and notebook example
This commit is contained in:
commit
3a4c895280
File diff suppressed because one or more lines are too long
@ -16,6 +16,7 @@ from langchain.chains.rl_chain.pick_best_chain import (
|
||||
PickBest,
|
||||
PickBestEvent,
|
||||
PickBestFeatureEmbedder,
|
||||
PickBestRandomPolicy,
|
||||
PickBestSelected,
|
||||
)
|
||||
|
||||
@ -39,6 +40,7 @@ __all__ = [
|
||||
"PickBestEvent",
|
||||
"PickBestSelected",
|
||||
"PickBestFeatureEmbedder",
|
||||
"PickBestRandomPolicy",
|
||||
"Embed",
|
||||
"BasedOn",
|
||||
"ToSelectFrom",
|
||||
|
@ -166,7 +166,7 @@ class Event(Generic[TSelected], ABC):
|
||||
TEvent = TypeVar("TEvent", bound=Event)
|
||||
|
||||
|
||||
class Policy(ABC):
|
||||
class Policy(Generic[TEvent], ABC):
|
||||
def __init__(self, **kwargs: Any):
|
||||
pass
|
||||
|
||||
|
@ -223,6 +223,21 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
|
||||
return self.format_auto_embed_off(event)
|
||||
|
||||
|
||||
class PickBestRandomPolicy(base.Policy[PickBestEvent]):
|
||||
def __init__(self, feature_embedder: base.Embedder, **kwargs: Any):
|
||||
self.feature_embedder = feature_embedder
|
||||
|
||||
def predict(self, event: PickBestEvent) -> List[Tuple[int, float]]:
|
||||
num_items = len(event.to_select_from)
|
||||
return [(i, 1.0 / num_items) for i in range(num_items)]
|
||||
|
||||
def learn(self, event: PickBestEvent) -> None:
|
||||
pass
|
||||
|
||||
def log(self, event: PickBestEvent) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class PickBest(base.RLChain[PickBestEvent]):
|
||||
"""
|
||||
`PickBest` is a class designed to leverage the Vowpal Wabbit (VW) model for reinforcement learning with a context, with the goal of modifying the prompt before the LLM call.
|
||||
|
4
libs/langchain/poetry.lock
generated
4
libs/langchain/poetry.lock
generated
@ -10939,7 +10939,7 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\
|
||||
cffi = ["cffi (>=1.11)"]
|
||||
|
||||
[extras]
|
||||
all = ["O365", "aleph-alpha-client", "amadeus", "arxiv", "atlassian-python-api", "awadb", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-cosmos", "azure-identity", "beautifulsoup4", "clarifai", "clickhouse-connect", "cohere", "deeplake", "docarray", "duckduckgo-search", "elasticsearch", "esprima", "faiss-cpu", "google-api-python-client", "google-auth", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jinja2", "jq", "lancedb", "langkit", "lark", "libdeeplake", "librosa", "lxml", "manifest-ml", "marqo", "momento", "nebula3-python", "neo4j", "networkx", "nlpcloud", "nltk", "nomic", "openai", "openlm", "opensearch-py", "pdfminer-six", "pexpect", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pymongo", "pyowm", "pypdf", "pytesseract", "python-arango", "pyvespa", "qdrant-client", "rdflib", "redis", "requests-toolbelt", "sentence-transformers", "singlestoredb", "tensorflow-text", "tigrisdb", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"]
|
||||
all = ["O365", "aleph-alpha-client", "amadeus", "arxiv", "atlassian-python-api", "awadb", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-cosmos", "azure-identity", "beautifulsoup4", "clarifai", "clickhouse-connect", "cohere", "deeplake", "docarray", "duckduckgo-search", "elasticsearch", "esprima", "faiss-cpu", "google-api-python-client", "google-auth", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jinja2", "jq", "lancedb", "langkit", "lark", "libdeeplake", "librosa", "lxml", "manifest-ml", "marqo", "momento", "nebula3-python", "neo4j", "networkx", "nlpcloud", "nltk", "nomic", "openai", "openlm", "opensearch-py", "pdfminer-six", "pexpect", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pymongo", "pyowm", "pypdf", "pytesseract", "python-arango", "pyvespa", "qdrant-client", "rdflib", "redis", "requests-toolbelt", "sentence-transformers", "singlestoredb", "tensorflow-text", "tigrisdb", "tiktoken", "torch", "transformers", "vowpal-wabbit-next", "weaviate-client", "wikipedia", "wolframalpha"]
|
||||
azure = ["azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-core", "azure-cosmos", "azure-identity", "azure-search-documents", "openai"]
|
||||
clarifai = ["clarifai"]
|
||||
cohere = ["cohere"]
|
||||
@ -10955,4 +10955,4 @@ text-helpers = ["chardet"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
content-hash = "71842b0ce1bd5c663e96a8ef14f71ce42667833cab72de4273ca07241c4465a9"
|
||||
content-hash = "7bffde1b8d57bad4b5a48d73250cb8276eb7e40dfe19f8490d5f4a25cb15322d"
|
||||
|
@ -295,6 +295,7 @@ all = [
|
||||
"amadeus",
|
||||
"librosa",
|
||||
"python-arango",
|
||||
"vowpal-wabbit-next",
|
||||
]
|
||||
|
||||
# An extra used to be able to add extended testing.
|
||||
|
Loading…
Reference in New Issue
Block a user