mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-05 21:12:48 +00:00
notebook fmt (#12498)
This commit is contained in:
@@ -50,6 +50,7 @@
|
||||
"# pick and configure the LLM of your choice\n",
|
||||
"\n",
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"\n",
|
||||
"llm = OpenAI(model=\"text-davinci-003\")"
|
||||
]
|
||||
},
|
||||
@@ -85,8 +86,8 @@
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"PROMPT = PromptTemplate(\n",
|
||||
" input_variables=[\"meal\", \"text_to_personalize\", \"user\", \"preference\"], \n",
|
||||
" template=PROMPT_TEMPLATE\n",
|
||||
" input_variables=[\"meal\", \"text_to_personalize\", \"user\", \"preference\"],\n",
|
||||
" template=PROMPT_TEMPLATE,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@@ -105,7 +106,7 @@
|
||||
"source": [
|
||||
"import langchain_experimental.rl_chain as rl_chain\n",
|
||||
"\n",
|
||||
"chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)\n"
|
||||
"chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -122,10 +123,10 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"response = chain.run(\n",
|
||||
" meal = rl_chain.ToSelectFrom(meals),\n",
|
||||
" user = rl_chain.BasedOn(\"Tom\"),\n",
|
||||
" preference = rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
|
||||
" text_to_personalize = \"This is the weeks specialty dish, our master chefs \\\n",
|
||||
" meal=rl_chain.ToSelectFrom(meals),\n",
|
||||
" user=rl_chain.BasedOn(\"Tom\"),\n",
|
||||
" preference=rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
|
||||
" text_to_personalize=\"This is the weeks specialty dish, our master chefs \\\n",
|
||||
" believe you will love it!\",\n",
|
||||
")"
|
||||
]
|
||||
@@ -193,10 +194,10 @@
|
||||
"for _ in range(5):\n",
|
||||
" try:\n",
|
||||
" response = chain.run(\n",
|
||||
" meal = rl_chain.ToSelectFrom(meals),\n",
|
||||
" user = rl_chain.BasedOn(\"Tom\"),\n",
|
||||
" preference = rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
|
||||
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
|
||||
" meal=rl_chain.ToSelectFrom(meals),\n",
|
||||
" user=rl_chain.BasedOn(\"Tom\"),\n",
|
||||
" preference=rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
|
||||
" text_to_personalize=\"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
|
||||
" )\n",
|
||||
" except Exception as e:\n",
|
||||
" print(e)\n",
|
||||
@@ -223,12 +224,16 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"scoring_criteria_template = \"Given {preference} rank how good or bad this selection is {meal}\"\n",
|
||||
"scoring_criteria_template = (\n",
|
||||
" \"Given {preference} rank how good or bad this selection is {meal}\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"chain = rl_chain.PickBest.from_llm(\n",
|
||||
" llm=llm,\n",
|
||||
" prompt=PROMPT,\n",
|
||||
" selection_scorer=rl_chain.AutoSelectionScorer(llm=llm, scoring_criteria_template_str=scoring_criteria_template),\n",
|
||||
" selection_scorer=rl_chain.AutoSelectionScorer(\n",
|
||||
" llm=llm, scoring_criteria_template_str=scoring_criteria_template\n",
|
||||
" ),\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@@ -255,14 +260,16 @@
|
||||
],
|
||||
"source": [
|
||||
"response = chain.run(\n",
|
||||
" meal = rl_chain.ToSelectFrom(meals),\n",
|
||||
" user = rl_chain.BasedOn(\"Tom\"),\n",
|
||||
" preference = rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
|
||||
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
|
||||
" meal=rl_chain.ToSelectFrom(meals),\n",
|
||||
" user=rl_chain.BasedOn(\"Tom\"),\n",
|
||||
" preference=rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
|
||||
" text_to_personalize=\"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
|
||||
")\n",
|
||||
"print(response[\"response\"])\n",
|
||||
"selection_metadata = response[\"selection_metadata\"]\n",
|
||||
"print(f\"selected index: {selection_metadata.selected.index}, score: {selection_metadata.selected.score}\")"
|
||||
"print(\n",
|
||||
" f\"selected index: {selection_metadata.selected.index}, score: {selection_metadata.selected.score}\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -280,8 +287,8 @@
|
||||
"source": [
|
||||
"class CustomSelectionScorer(rl_chain.SelectionScorer):\n",
|
||||
" def score_response(\n",
|
||||
" self, inputs, llm_response: str, event: rl_chain.PickBestEvent) -> float:\n",
|
||||
"\n",
|
||||
" self, inputs, llm_response: str, event: rl_chain.PickBestEvent\n",
|
||||
" ) -> float:\n",
|
||||
" print(event.based_on)\n",
|
||||
" print(event.to_select_from)\n",
|
||||
"\n",
|
||||
@@ -336,10 +343,10 @@
|
||||
],
|
||||
"source": [
|
||||
"response = chain.run(\n",
|
||||
" meal = rl_chain.ToSelectFrom(meals),\n",
|
||||
" user = rl_chain.BasedOn(\"Tom\"),\n",
|
||||
" preference = rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
|
||||
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
|
||||
" meal=rl_chain.ToSelectFrom(meals),\n",
|
||||
" user=rl_chain.BasedOn(\"Tom\"),\n",
|
||||
" preference=rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
|
||||
" text_to_personalize=\"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@@ -370,9 +377,10 @@
|
||||
" return 1.0\n",
|
||||
" else:\n",
|
||||
" return 0.0\n",
|
||||
" def score_response(\n",
|
||||
" self, inputs, llm_response: str, event: rl_chain.PickBestEvent) -> float:\n",
|
||||
"\n",
|
||||
" def score_response(\n",
|
||||
" self, inputs, llm_response: str, event: rl_chain.PickBestEvent\n",
|
||||
" ) -> float:\n",
|
||||
" selected_meal = event.to_select_from[\"meal\"][event.selected.index]\n",
|
||||
"\n",
|
||||
" if \"Tom\" in event.based_on[\"user\"]:\n",
|
||||
@@ -394,7 +402,7 @@
|
||||
" prompt=PROMPT,\n",
|
||||
" selection_scorer=CustomSelectionScorer(),\n",
|
||||
" metrics_step=5,\n",
|
||||
" metrics_window_size=5, # rolling window average\n",
|
||||
" metrics_window_size=5, # rolling window average\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"random_chain = rl_chain.PickBest.from_llm(\n",
|
||||
@@ -402,8 +410,8 @@
|
||||
" prompt=PROMPT,\n",
|
||||
" selection_scorer=CustomSelectionScorer(),\n",
|
||||
" metrics_step=5,\n",
|
||||
" metrics_window_size=5, # rolling window average\n",
|
||||
" policy=rl_chain.PickBestRandomPolicy # set the random policy instead of default\n",
|
||||
" metrics_window_size=5, # rolling window average\n",
|
||||
" policy=rl_chain.PickBestRandomPolicy, # set the random policy instead of default\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@@ -416,29 +424,29 @@
|
||||
"for _ in range(20):\n",
|
||||
" try:\n",
|
||||
" chain.run(\n",
|
||||
" meal = rl_chain.ToSelectFrom(meals),\n",
|
||||
" user = rl_chain.BasedOn(\"Tom\"),\n",
|
||||
" preference = rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
|
||||
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
|
||||
" meal=rl_chain.ToSelectFrom(meals),\n",
|
||||
" user=rl_chain.BasedOn(\"Tom\"),\n",
|
||||
" preference=rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
|
||||
" text_to_personalize=\"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
|
||||
" )\n",
|
||||
" random_chain.run(\n",
|
||||
" meal = rl_chain.ToSelectFrom(meals),\n",
|
||||
" user = rl_chain.BasedOn(\"Tom\"),\n",
|
||||
" preference = rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
|
||||
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
|
||||
" meal=rl_chain.ToSelectFrom(meals),\n",
|
||||
" user=rl_chain.BasedOn(\"Tom\"),\n",
|
||||
" preference=rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
|
||||
" text_to_personalize=\"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" chain.run(\n",
|
||||
" meal = rl_chain.ToSelectFrom(meals),\n",
|
||||
" user = rl_chain.BasedOn(\"Anna\"),\n",
|
||||
" preference = rl_chain.BasedOn([\"Loves meat\", \"especially beef\"]),\n",
|
||||
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
|
||||
" meal=rl_chain.ToSelectFrom(meals),\n",
|
||||
" user=rl_chain.BasedOn(\"Anna\"),\n",
|
||||
" preference=rl_chain.BasedOn([\"Loves meat\", \"especially beef\"]),\n",
|
||||
" text_to_personalize=\"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
|
||||
" )\n",
|
||||
" random_chain.run(\n",
|
||||
" meal = rl_chain.ToSelectFrom(meals),\n",
|
||||
" user = rl_chain.BasedOn(\"Anna\"),\n",
|
||||
" preference = rl_chain.BasedOn([\"Loves meat\", \"especially beef\"]),\n",
|
||||
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
|
||||
" meal=rl_chain.ToSelectFrom(meals),\n",
|
||||
" user=rl_chain.BasedOn(\"Anna\"),\n",
|
||||
" preference=rl_chain.BasedOn([\"Loves meat\", \"especially beef\"]),\n",
|
||||
" text_to_personalize=\"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
|
||||
" )\n",
|
||||
" except Exception as e:\n",
|
||||
" print(e)"
|
||||
@@ -477,12 +485,17 @@
|
||||
],
|
||||
"source": [
|
||||
"from matplotlib import pyplot as plt\n",
|
||||
"chain.metrics.to_pandas()['score'].plot(label=\"default learning policy\")\n",
|
||||
"random_chain.metrics.to_pandas()['score'].plot(label=\"random selection policy\")\n",
|
||||
"\n",
|
||||
"chain.metrics.to_pandas()[\"score\"].plot(label=\"default learning policy\")\n",
|
||||
"random_chain.metrics.to_pandas()[\"score\"].plot(label=\"random selection policy\")\n",
|
||||
"plt.legend()\n",
|
||||
"\n",
|
||||
"print(f\"The final average score for the default policy, calculated over a rolling window, is: {chain.metrics.to_pandas()['score'].iloc[-1]}\")\n",
|
||||
"print(f\"The final average score for the random policy, calculated over a rolling window, is: {random_chain.metrics.to_pandas()['score'].iloc[-1]}\")"
|
||||
"print(\n",
|
||||
" f\"The final average score for the default policy, calculated over a rolling window, is: {chain.metrics.to_pandas()['score'].iloc[-1]}\"\n",
|
||||
")\n",
|
||||
"print(\n",
|
||||
" f\"The final average score for the random policy, calculated over a rolling window, is: {random_chain.metrics.to_pandas()['score'].iloc[-1]}\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -803,10 +816,10 @@
|
||||
")\n",
|
||||
"\n",
|
||||
"chain.run(\n",
|
||||
" meal = rl_chain.ToSelectFrom(meals),\n",
|
||||
" user = rl_chain.BasedOn(\"Tom\"),\n",
|
||||
" preference = rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
|
||||
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
|
||||
" meal=rl_chain.ToSelectFrom(meals),\n",
|
||||
" user=rl_chain.BasedOn(\"Tom\"),\n",
|
||||
" preference=rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
|
||||
" text_to_personalize=\"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
|
||||
")"
|
||||
]
|
||||
}
|
||||
|
Reference in New Issue
Block a user