Merge pull request #11 from VowpalWabbit/add_notebook

add random policy and notebook example
This commit is contained in:
olgavrou 2023-09-05 09:36:20 -04:00 committed by GitHub
commit 3a4c895280
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 851 additions and 3 deletions

File diff suppressed because one or more lines are too long

View File

@ -16,6 +16,7 @@ from langchain.chains.rl_chain.pick_best_chain import (
PickBest,
PickBestEvent,
PickBestFeatureEmbedder,
PickBestRandomPolicy,
PickBestSelected,
)
@ -39,6 +40,7 @@ __all__ = [
"PickBestEvent",
"PickBestSelected",
"PickBestFeatureEmbedder",
"PickBestRandomPolicy",
"Embed",
"BasedOn",
"ToSelectFrom",

View File

@ -166,7 +166,7 @@ class Event(Generic[TSelected], ABC):
TEvent = TypeVar("TEvent", bound=Event)
class Policy(ABC):
class Policy(Generic[TEvent], ABC):
def __init__(self, **kwargs: Any):
pass

View File

@ -223,6 +223,21 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
return self.format_auto_embed_off(event)
class PickBestRandomPolicy(base.Policy[PickBestEvent]):
    """Baseline policy that picks uniformly at random among candidate actions.

    `predict` assigns equal probability mass to every action in the event's
    `to_select_from`; `learn` and `log` are intentional no-ops since a random
    policy neither trains on feedback nor records examples.
    """

    def __init__(self, feature_embedder: base.Embedder, **kwargs: Any):
        # Forward remaining kwargs to the base Policy rather than silently
        # dropping them (base.Policy.__init__ accepts **kwargs).
        super().__init__(**kwargs)
        # Kept for interface parity with the VW-backed policy; the uniform
        # random prediction itself does not use the embedder.
        self.feature_embedder = feature_embedder

    def predict(self, event: PickBestEvent) -> List[Tuple[int, float]]:
        """Return (action_index, probability) pairs with uniform probability.

        For an empty `to_select_from` this yields an empty list — the
        comprehension body (and its division) is never evaluated.
        """
        num_items = len(event.to_select_from)
        return [(i, 1.0 / num_items) for i in range(num_items)]

    def learn(self, event: PickBestEvent) -> None:
        """No-op: a random policy does not update from feedback."""
        pass

    def log(self, event: PickBestEvent) -> None:
        """No-op: nothing to record for the random policy."""
        pass
class PickBest(base.RLChain[PickBestEvent]):
"""
`PickBest` is a class designed to leverage the Vowpal Wabbit (VW) model for reinforcement learning with a context, with the goal of modifying the prompt before the LLM call.

View File

@ -10939,7 +10939,7 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\
cffi = ["cffi (>=1.11)"]
[extras]
all = ["O365", "aleph-alpha-client", "amadeus", "arxiv", "atlassian-python-api", "awadb", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-cosmos", "azure-identity", "beautifulsoup4", "clarifai", "clickhouse-connect", "cohere", "deeplake", "docarray", "duckduckgo-search", "elasticsearch", "esprima", "faiss-cpu", "google-api-python-client", "google-auth", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jinja2", "jq", "lancedb", "langkit", "lark", "libdeeplake", "librosa", "lxml", "manifest-ml", "marqo", "momento", "nebula3-python", "neo4j", "networkx", "nlpcloud", "nltk", "nomic", "openai", "openlm", "opensearch-py", "pdfminer-six", "pexpect", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pymongo", "pyowm", "pypdf", "pytesseract", "python-arango", "pyvespa", "qdrant-client", "rdflib", "redis", "requests-toolbelt", "sentence-transformers", "singlestoredb", "tensorflow-text", "tigrisdb", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"]
all = ["O365", "aleph-alpha-client", "amadeus", "arxiv", "atlassian-python-api", "awadb", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-cosmos", "azure-identity", "beautifulsoup4", "clarifai", "clickhouse-connect", "cohere", "deeplake", "docarray", "duckduckgo-search", "elasticsearch", "esprima", "faiss-cpu", "google-api-python-client", "google-auth", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jinja2", "jq", "lancedb", "langkit", "lark", "libdeeplake", "librosa", "lxml", "manifest-ml", "marqo", "momento", "nebula3-python", "neo4j", "networkx", "nlpcloud", "nltk", "nomic", "openai", "openlm", "opensearch-py", "pdfminer-six", "pexpect", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pymongo", "pyowm", "pypdf", "pytesseract", "python-arango", "pyvespa", "qdrant-client", "rdflib", "redis", "requests-toolbelt", "sentence-transformers", "singlestoredb", "tensorflow-text", "tigrisdb", "tiktoken", "torch", "transformers", "vowpal-wabbit-next", "weaviate-client", "wikipedia", "wolframalpha"]
azure = ["azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-core", "azure-cosmos", "azure-identity", "azure-search-documents", "openai"]
clarifai = ["clarifai"]
cohere = ["cohere"]
@ -10955,4 +10955,4 @@ text-helpers = ["chardet"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "71842b0ce1bd5c663e96a8ef14f71ce42667833cab72de4273ca07241c4465a9"
content-hash = "7bffde1b8d57bad4b5a48d73250cb8276eb7e40dfe19f8490d5f4a25cb15322d"

View File

@ -295,6 +295,7 @@ all = [
"amadeus",
"librosa",
"python-arango",
"vowpal-wabbit-next",
]
# An extra used to be able to add extended testing.