core[minor]: generation info on msg (#18592)

related to #16403 #17188
This commit is contained in:
Bagatur
2024-03-11 21:43:17 -07:00
committed by GitHub
parent cda43c5a11
commit e0e688a277
12 changed files with 357 additions and 164 deletions

View File

@@ -0,0 +1,174 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "78b45321-7740-4399-b2ad-459811131de3",
"metadata": {},
"source": [
"# Get log probabilities\n",
"\n",
"Certain chat models can be configured to return token-level log probabilities. This guide walks through how to get logprobs for a number of models."
]
},
{
"cell_type": "markdown",
"id": "7f5016bf-2a7b-4140-9b80-8c35c7e5c0d5",
"metadata": {},
"source": [
"## OpenAI\n",
"\n",
"Install the LangChain x OpenAI package and set your API key"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fe5143fe-84d3-4a91-bae8-629807bbe2cb",
"metadata": {},
"outputs": [],
"source": [
"%pip install -qU langchain-openai"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fd1a2bff-7ac8-46cb-ab95-72c616b45f2c",
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()"
]
},
{
"cell_type": "markdown",
"id": "f88ffa0d-f4a7-482c-88de-cbec501a79b1",
"metadata": {},
"source": [
"For the OpenAI API to return log probabilities we need to configure the `logprobs=True` param"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "d1bf0a9a-e402-4931-ab53-32899f8e0326",
"metadata": {},
"outputs": [],
"source": [
"from langchain_openai import ChatOpenAI\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-3.5-turbo-0125\").bind(logprobs=True)\n",
"\n",
"msg = llm.invoke((\"human\", \"how are you today\"))"
]
},
{
"cell_type": "markdown",
"id": "e002c48a-af03-4796-a367-a69c5c8ae0c4",
"metadata": {},
"source": [
"The logprobs are included on each output Message as part of the `response_metadata`:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "e3e17872-62df-4b17-a8d4-4cae713a301b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'token': 'As',\n",
" 'bytes': [65, 115],\n",
" 'logprob': -1.5358024,\n",
" 'top_logprobs': []},\n",
" {'token': ' an',\n",
" 'bytes': [32, 97, 110],\n",
" 'logprob': -0.028062303,\n",
" 'top_logprobs': []},\n",
" {'token': ' AI',\n",
" 'bytes': [32, 65, 73],\n",
" 'logprob': -0.009415812,\n",
" 'top_logprobs': []},\n",
" {'token': ',', 'bytes': [44], 'logprob': -0.07371779, 'top_logprobs': []},\n",
" {'token': ' I',\n",
" 'bytes': [32, 73],\n",
" 'logprob': -4.298773e-05,\n",
" 'top_logprobs': []}]"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"msg.response_metadata[\"logprobs\"][\"content\"][:5]"
]
},
{
"cell_type": "markdown",
"id": "d1ee1c29-d27e-4353-8c3c-2ed7e7f95ff5",
"metadata": {},
"source": [
"And are part of streamed Message chunks as well:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4bfaf309-3b23-43b7-b333-01fc4848992d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[]\n",
"[{'token': 'As', 'bytes': [65, 115], 'logprob': -1.7523563, 'top_logprobs': []}]\n",
"[{'token': 'As', 'bytes': [65, 115], 'logprob': -1.7523563, 'top_logprobs': []}, {'token': ' an', 'bytes': [32, 97, 110], 'logprob': -0.019908238, 'top_logprobs': []}]\n",
"[{'token': 'As', 'bytes': [65, 115], 'logprob': -1.7523563, 'top_logprobs': []}, {'token': ' an', 'bytes': [32, 97, 110], 'logprob': -0.019908238, 'top_logprobs': []}, {'token': ' AI', 'bytes': [32, 65, 73], 'logprob': -0.0093033705, 'top_logprobs': []}]\n",
"[{'token': 'As', 'bytes': [65, 115], 'logprob': -1.7523563, 'top_logprobs': []}, {'token': ' an', 'bytes': [32, 97, 110], 'logprob': -0.019908238, 'top_logprobs': []}, {'token': ' AI', 'bytes': [32, 65, 73], 'logprob': -0.0093033705, 'top_logprobs': []}, {'token': ',', 'bytes': [44], 'logprob': -0.08852102, 'top_logprobs': []}]\n"
]
}
],
"source": [
"ct = 0\n",
"full = None\n",
"for chunk in llm.stream((\"human\", \"how are you today\")):\n",
" if ct < 5:\n",
" full = chunk if full is None else full + chunk\n",
" if \"logprobs\" in full.response_metadata:\n",
" print(full.response_metadata[\"logprobs\"][\"content\"])\n",
" else:\n",
" break\n",
" ct += 1"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "poetry-venv-2",
"language": "python",
"name": "poetry-venv-2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}