core[minor]: generation info on msg (#18592)

Related to #16403 and #17188.
Bagatur 2024-03-11 21:43:17 -07:00 committed by GitHub
parent cda43c5a11
commit e0e688a277
12 changed files with 357 additions and 164 deletions
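
This change merges each chat generation's `generation_info` (finish reason, logprobs, and so on) into the corresponding message's `response_metadata`, on both the invoke and streaming code paths, and teaches message chunks to merge that metadata when they are concatenated. A minimal sketch of what this enables, taken from the new logprobs guide below:

```python
from langchain_openai import ChatOpenAI

# logprobs=True asks the OpenAI API to return token-level log probabilities.
llm = ChatOpenAI(model="gpt-3.5-turbo-0125").bind(logprobs=True)

msg = llm.invoke(("human", "how are you today"))
# Provider generation info now lands directly on the message:
print(msg.response_metadata["logprobs"]["content"][:5])
```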

View File

@@ -0,0 +1,174 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "78b45321-7740-4399-b2ad-459811131de3",
"metadata": {},
"source": [
"# Get log probabilities\n",
"\n",
"Certain chat models can be configured to return token-level log probabilities. This guide walks through how to get logprobs for a number of models."
]
},
{
"cell_type": "markdown",
"id": "7f5016bf-2a7b-4140-9b80-8c35c7e5c0d5",
"metadata": {},
"source": [
"## OpenAI\n",
"\n",
"Install the LangChain x OpenAI package and set your API key"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fe5143fe-84d3-4a91-bae8-629807bbe2cb",
"metadata": {},
"outputs": [],
"source": [
"%pip install -qU langchain-openai"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fd1a2bff-7ac8-46cb-ab95-72c616b45f2c",
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()"
]
},
{
"cell_type": "markdown",
"id": "f88ffa0d-f4a7-482c-88de-cbec501a79b1",
"metadata": {},
"source": [
"For the OpenAI API to return log probabilities we need to configure the `logprobs=True` param"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "d1bf0a9a-e402-4931-ab53-32899f8e0326",
"metadata": {},
"outputs": [],
"source": [
"from langchain_openai import ChatOpenAI\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-3.5-turbo-0125\").bind(logprobs=True)\n",
"\n",
"msg = llm.invoke((\"human\", \"how are you today\"))"
]
},
{
"cell_type": "markdown",
"id": "e002c48a-af03-4796-a367-a69c5c8ae0c4",
"metadata": {},
"source": [
"The logprobs are included on each output Message as part of the `response_metadata`:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "e3e17872-62df-4b17-a8d4-4cae713a301b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'token': 'As',\n",
" 'bytes': [65, 115],\n",
" 'logprob': -1.5358024,\n",
" 'top_logprobs': []},\n",
" {'token': ' an',\n",
" 'bytes': [32, 97, 110],\n",
" 'logprob': -0.028062303,\n",
" 'top_logprobs': []},\n",
" {'token': ' AI',\n",
" 'bytes': [32, 65, 73],\n",
" 'logprob': -0.009415812,\n",
" 'top_logprobs': []},\n",
" {'token': ',', 'bytes': [44], 'logprob': -0.07371779, 'top_logprobs': []},\n",
" {'token': ' I',\n",
" 'bytes': [32, 73],\n",
" 'logprob': -4.298773e-05,\n",
" 'top_logprobs': []}]"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"msg.response_metadata[\"logprobs\"][\"content\"][:5]"
]
},
{
"cell_type": "markdown",
"id": "d1ee1c29-d27e-4353-8c3c-2ed7e7f95ff5",
"metadata": {},
"source": [
"And are part of streamed Message chunks as well:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4bfaf309-3b23-43b7-b333-01fc4848992d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[]\n",
"[{'token': 'As', 'bytes': [65, 115], 'logprob': -1.7523563, 'top_logprobs': []}]\n",
"[{'token': 'As', 'bytes': [65, 115], 'logprob': -1.7523563, 'top_logprobs': []}, {'token': ' an', 'bytes': [32, 97, 110], 'logprob': -0.019908238, 'top_logprobs': []}]\n",
"[{'token': 'As', 'bytes': [65, 115], 'logprob': -1.7523563, 'top_logprobs': []}, {'token': ' an', 'bytes': [32, 97, 110], 'logprob': -0.019908238, 'top_logprobs': []}, {'token': ' AI', 'bytes': [32, 65, 73], 'logprob': -0.0093033705, 'top_logprobs': []}]\n",
"[{'token': 'As', 'bytes': [65, 115], 'logprob': -1.7523563, 'top_logprobs': []}, {'token': ' an', 'bytes': [32, 97, 110], 'logprob': -0.019908238, 'top_logprobs': []}, {'token': ' AI', 'bytes': [32, 65, 73], 'logprob': -0.0093033705, 'top_logprobs': []}, {'token': ',', 'bytes': [44], 'logprob': -0.08852102, 'top_logprobs': []}]\n"
]
}
],
"source": [
"ct = 0\n",
"full = None\n",
"for chunk in llm.stream((\"human\", \"how are you today\")):\n",
" if ct < 5:\n",
" full = chunk if full is None else full + chunk\n",
" if \"logprobs\" in full.response_metadata:\n",
" print(full.response_metadata[\"logprobs\"][\"content\"])\n",
" else:\n",
" break\n",
" ct += 1"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "poetry-venv-2",
"language": "python",
"name": "poetry-venv-2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -15,6 +15,7 @@ from typing import (
List,
Optional,
Sequence,
Union,
cast,
)
@@ -240,6 +241,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
for chunk in self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
):
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
yield chunk.message
if generation is None:
generation = chunk
@@ -317,6 +319,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
async for chunk in _stream_implementation(
messages, stop=stop, run_manager=run_manager, **kwargs
):
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
yield chunk.message
if generation is None:
generation = chunk
@@ -586,38 +589,35 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
new_arg_supported = inspect.signature(self._generate).parameters.get(
"run_manager"
)
disregard_cache = self.cache is not None and not self.cache
llm_cache = get_llm_cache()
if llm_cache is None or disregard_cache:
# This happens when langchain.cache is None, but self.cache is True
if self.cache is not None and self.cache:
check_cache = self.cache or self.cache is None
if check_cache:
if llm_cache:
llm_string = self._get_llm_string(stop=stop, **kwargs)
prompt = dumps(messages)
cache_val = llm_cache.lookup(prompt, llm_string)
if isinstance(cache_val, list):
return ChatResult(generations=cache_val)
elif self.cache is None:
pass
else:
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
if new_arg_supported:
return self._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
else:
return self._generate(messages, stop=stop, **kwargs)
if inspect.signature(self._generate).parameters.get("run_manager"):
result = self._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
else:
llm_string = self._get_llm_string(stop=stop, **kwargs)
prompt = dumps(messages)
cache_val = llm_cache.lookup(prompt, llm_string)
if isinstance(cache_val, list):
return ChatResult(generations=cache_val)
else:
if new_arg_supported:
result = self._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
else:
result = self._generate(messages, stop=stop, **kwargs)
llm_cache.update(prompt, llm_string, result.generations)
return result
result = self._generate(messages, stop=stop, **kwargs)
for generation in result.generations:
generation.message.response_metadata = _gen_info_and_msg_metadata(
generation
)
if check_cache and llm_cache:
llm_cache.update(prompt, llm_string, result.generations)
return result
async def _agenerate_with_cache(
self,
@@ -626,38 +626,34 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
new_arg_supported = inspect.signature(self._agenerate).parameters.get(
"run_manager"
)
disregard_cache = self.cache is not None and not self.cache
llm_cache = get_llm_cache()
if llm_cache is None or disregard_cache:
# This happens when langchain.cache is None, but self.cache is True
if self.cache is not None and self.cache:
check_cache = self.cache or self.cache is None
if check_cache:
if llm_cache:
llm_string = self._get_llm_string(stop=stop, **kwargs)
prompt = dumps(messages)
cache_val = await llm_cache.alookup(prompt, llm_string)
if isinstance(cache_val, list):
return ChatResult(generations=cache_val)
elif self.cache is None:
pass
else:
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
if new_arg_supported:
return await self._agenerate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
else:
return await self._agenerate(messages, stop=stop, **kwargs)
if inspect.signature(self._agenerate).parameters.get("run_manager"):
result = await self._agenerate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
else:
llm_string = self._get_llm_string(stop=stop, **kwargs)
prompt = dumps(messages)
cache_val = await llm_cache.alookup(prompt, llm_string)
if isinstance(cache_val, list):
return ChatResult(generations=cache_val)
else:
if new_arg_supported:
result = await self._agenerate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
else:
result = await self._agenerate(messages, stop=stop, **kwargs)
await llm_cache.aupdate(prompt, llm_string, result.generations)
return result
result = await self._agenerate(messages, stop=stop, **kwargs)
for generation in result.generations:
generation.message.response_metadata = _gen_info_and_msg_metadata(
generation
)
if check_cache and llm_cache:
await llm_cache.aupdate(prompt, llm_string, result.generations)
return result
@abstractmethod
def _generate(
@@ -852,3 +848,12 @@ class SimpleChatModel(BaseChatModel):
run_manager=run_manager.get_sync() if run_manager else None,
**kwargs,
)
def _gen_info_and_msg_metadata(
generation: Union[ChatGeneration, ChatGenerationChunk],
) -> dict:
return {
**(generation.generation_info or {}),
**generation.message.response_metadata,
}
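
`_gen_info_and_msg_metadata` simply merges a generation's `generation_info` with whatever `response_metadata` the message already carries; the message's existing metadata is unpacked last, so it wins on key collisions. A rough sketch with plain dicts (the key names here are illustrative, not prescribed):

```python
# Illustrative values only; in practice they come from ChatGeneration objects.
generation_info = {"finish_reason": "stop", "logprobs": {"content": []}}
existing_response_metadata = {"model_name": "gpt-3.5-turbo-0125"}

merged = {**generation_info, **existing_response_metadata}
# {'finish_reason': 'stop', 'logprobs': {'content': []},
#  'model_name': 'gpt-3.5-turbo-0125'}
```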

View File

@@ -5,6 +5,7 @@ from langchain_core.messages.base import (
BaseMessageChunk,
merge_content,
)
from langchain_core.utils._merge import merge_dicts
class AIMessage(BaseMessage):
@@ -49,9 +50,12 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
return self.__class__(
example=self.example,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
additional_kwargs=merge_dicts(
self.additional_kwargs, other.additional_kwargs
),
response_metadata=merge_dicts(
self.response_metadata, other.response_metadata
),
)
return super().__add__(other)
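
With `response_metadata` now merged in `__add__`, concatenating streamed chunks accumulates the metadata alongside the content. A small sketch, assuming metadata shaped like the OpenAI logprobs payload:

```python
from langchain_core.messages import AIMessageChunk

left = AIMessageChunk(
    content="Hel",
    response_metadata={"logprobs": {"content": [{"token": "Hel"}]}},
)
right = AIMessageChunk(
    content="lo",
    response_metadata={"logprobs": {"content": [{"token": "lo"}]}},
)

merged = left + right
print(merged.content)            # "Hello"
print(merged.response_metadata)  # the logprobs "content" lists are concatenated
```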

View File

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import Extra, Field
from langchain_core.utils import get_bolded_text
from langchain_core.utils._merge import merge_dicts
from langchain_core.utils.interactive_env import is_interactive_env
if TYPE_CHECKING:
@@ -114,54 +115,6 @@ class BaseMessageChunk(BaseMessage):
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "messages"]
def _merge_kwargs_dict(
self, left: Dict[str, Any], right: Dict[str, Any]
) -> Dict[str, Any]:
"""Merge additional_kwargs from another BaseMessageChunk into this one,
handling specific scenarios where a key exists in both dictionaries
but has a value of None in 'left'. In such cases, the method uses the
value from 'right' for that key in the merged dictionary.
Example:
If left = {"function_call": {"arguments": None}} and
right = {"function_call": {"arguments": "{\n"}}
then, after merging, for the key "function_call",
the value from 'right' is used,
resulting in merged = {"function_call": {"arguments": "{\n"}}.
"""
merged = left.copy()
for k, v in right.items():
if k not in merged:
merged[k] = v
elif merged[k] is None and v:
merged[k] = v
elif v is None:
continue
elif merged[k] == v:
continue
elif type(merged[k]) != type(v):
raise TypeError(
f'additional_kwargs["{k}"] already exists in this message,'
" but with a different type."
)
elif isinstance(merged[k], str):
merged[k] += v
elif isinstance(merged[k], dict):
merged[k] = self._merge_kwargs_dict(merged[k], v)
elif isinstance(merged[k], list):
merged[k] = merged[k].copy()
for i, e in enumerate(v):
if isinstance(e, dict) and isinstance(e.get("index"), int):
i = e["index"]
if i < len(merged[k]):
merged[k][i] = self._merge_kwargs_dict(merged[k][i], e)
else:
merged[k] = merged[k] + [e]
else:
raise TypeError(
f"Additional kwargs key {k} already exists in this message."
)
return merged
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, BaseMessageChunk):
# If both are (subclasses of) BaseMessageChunk,
@@ -170,9 +123,12 @@ class BaseMessageChunk(BaseMessage):
return self.__class__( # type: ignore[call-arg]
id=self.id,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
additional_kwargs=merge_dicts(
self.additional_kwargs, other.additional_kwargs
),
response_metadata=merge_dicts(
self.response_metadata, other.response_metadata
),
)
else:
raise TypeError(

View File

@@ -5,6 +5,7 @@ from langchain_core.messages.base import (
BaseMessageChunk,
merge_content,
)
from langchain_core.utils._merge import merge_dicts
class ChatMessage(BaseMessage):
@@ -47,17 +48,23 @@ class ChatMessageChunk(ChatMessage, BaseMessageChunk):
return self.__class__(
role=self.role,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
additional_kwargs=merge_dicts(
self.additional_kwargs, other.additional_kwargs
),
response_metadata=merge_dicts(
self.response_metadata, other.response_metadata
),
)
elif isinstance(other, BaseMessageChunk):
return self.__class__(
role=self.role,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
additional_kwargs=merge_dicts(
self.additional_kwargs, other.additional_kwargs
),
response_metadata=merge_dicts(
self.response_metadata, other.response_metadata
),
)
else:
return super().__add__(other)

View File

@@ -5,6 +5,7 @@ from langchain_core.messages.base import (
BaseMessageChunk,
merge_content,
)
from langchain_core.utils._merge import merge_dicts
class FunctionMessage(BaseMessage):
@@ -47,9 +48,12 @@ class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
return self.__class__(
name=self.name,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
additional_kwargs=merge_dicts(
self.additional_kwargs, other.additional_kwargs
),
response_metadata=merge_dicts(
self.response_metadata, other.response_metadata
),
)
return super().__add__(other)

View File

@@ -5,6 +5,7 @@ from langchain_core.messages.base import (
BaseMessageChunk,
merge_content,
)
from langchain_core.utils._merge import merge_dicts
class ToolMessage(BaseMessage):
@@ -47,9 +48,12 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk):
return self.__class__(
tool_call_id=self.tool_call_id,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
additional_kwargs=merge_dicts(
self.additional_kwargs, other.additional_kwargs
),
response_metadata=merge_dicts(
self.response_metadata, other.response_metadata
),
)
return super().__add__(other)

View File

@@ -16,27 +16,44 @@ def merge_dicts(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]:
resulting in merged = {"function_call": {"arguments": "{\n"}}.
"""
merged = left.copy()
for k, v in right.items():
if k not in merged:
merged[k] = v
elif v is not None and merged[k] is None:
merged[k] = v
elif v is None or merged[k] == v:
for right_k, right_v in right.items():
if right_k not in merged:
merged[right_k] = right_v
elif right_v is not None and merged[right_k] is None:
merged[right_k] = right_v
elif right_v is None:
continue
elif type(merged[k]) != type(v):
elif type(merged[right_k]) != type(right_v):
raise TypeError(
f'additional_kwargs["{k}"] already exists in this message,'
f'additional_kwargs["{right_k}"] already exists in this message,'
" but with a different type."
)
elif isinstance(merged[k], str):
merged[k] += v
elif isinstance(merged[k], dict):
merged[k] = merge_dicts(merged[k], v)
elif isinstance(merged[k], list):
merged[k] = merged[k] + v
elif isinstance(merged[right_k], str):
merged[right_k] += right_v
elif isinstance(merged[right_k], dict):
merged[right_k] = merge_dicts(merged[right_k], right_v)
elif isinstance(merged[right_k], list):
merged[right_k] = merged[right_k].copy()
for e in right_v:
if isinstance(e, dict) and "index" in e and isinstance(e["index"], int):
to_merge = [
i
for i, e_left in enumerate(merged[right_k])
if e_left["index"] == e["index"]
]
if to_merge:
merged[right_k][to_merge[0]] = merge_dicts(
merged[right_k][to_merge[0]], e
)
else:
merged[right_k] = merged[right_k] + [e]
else:
merged[right_k] = merged[right_k] + [e]
elif merged[right_k] == right_v:
continue
else:
raise TypeError(
f"Additional kwargs key {k} already exists in left dict and value has "
f"unsupported type {type(merged[k])}."
f"Additional kwargs key {right_k} already exists in left dict and "
f"value has unsupported type {type(merged[right_k])}."
)
return merged
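
`merge_dicts` now treats list elements that carry an integer `index` specially: an incoming element is merged into the existing element with the same index instead of being appended, which lets indexed streaming deltas (for example function/tool-call fragments) accumulate correctly. A quick sketch; the `tool_calls`/`args` key names are hypothetical, and `langchain_core.utils._merge` is an internal module:

```python
from langchain_core.utils._merge import merge_dicts

# Elements sharing the same integer "index" are merged in place...
left = {"tool_calls": [{"index": 0, "args": '{"query": "lang'}]}
right = {"tool_calls": [{"index": 0, "args": 'chain"}'}]}
print(merge_dicts(left, right))
# {'tool_calls': [{'index': 0, 'args': '{"query": "langchain"}'}]}

# ...while elements without an integer "index" are simply appended.
print(merge_dicts({"a": [1]}, {"a": [2]}))  # {'a': [1, 2]}
```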

View File

@@ -48,9 +48,9 @@ def test_check_package_version(
({"a": 1.5}, {"a": 1.5}, {"a": 1.5}),
({"a": True}, {"a": True}, {"a": True}),
({"a": False}, {"a": False}, {"a": False}),
({"a": "txt"}, {"a": "txt"}, {"a": "txt"}),
({"a": [1, 2]}, {"a": [1, 2]}, {"a": [1, 2]}),
({"a": {"b": "txt"}}, {"a": {"b": "txt"}}, {"a": {"b": "txt"}}),
({"a": "txt"}, {"a": "txt"}, {"a": "txttxt"}),
({"a": [1, 2]}, {"a": [1, 2]}, {"a": [1, 2, 1, 2]}),
({"a": {"b": "txt"}}, {"a": {"b": "txt"}}, {"a": {"b": "txttxt"}}),
# Merge strings.
({"a": "one"}, {"a": "two"}, {"a": "onetwo"}),
# Merge dicts.
@@ -89,6 +89,17 @@ def test_check_package_version(
),
),
),
# 'index' keyword has special handling
(
{"a": [{"index": 0, "b": "{"}]},
{"a": [{"index": 0, "b": "f"}]},
{"a": [{"index": 0, "b": "{f"}]},
),
(
{"a": [{"idx": 0, "b": "{"}]},
{"a": [{"idx": 0, "b": "f"}]},
{"a": [{"idx": 0, "b": "{"}, {"idx": 0, "b": "f"}]},
),
),
)
def test_merge_dicts(

View File

@@ -14,10 +14,7 @@ from langchain_core.outputs import ChatGeneration, Generation
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from langchain.cache import (
InMemoryCache,
SQLAlchemyCache,
)
from langchain.cache import InMemoryCache, SQLAlchemyCache
from langchain.globals import get_llm_cache, set_llm_cache
@@ -67,7 +64,7 @@ async def test_llm_caching() -> None:
llm_string=create_llm_string(llm),
return_val=[Generation(text=cached_response)],
)
assert llm(prompt) == cached_response
assert llm.invoke(prompt) == cached_response
# async test
await llm_cache.aupdate(
prompt=prompt,

View File

@@ -1,10 +1,10 @@
"""Test AzureChatOpenAI wrapper."""
import os
from typing import Any
from typing import Any, Optional
import pytest
from langchain_core.callbacks import CallbackManager
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.messages import BaseMessage, BaseMessageChunk, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
from langchain_openai import AzureChatOpenAI
@@ -164,16 +164,20 @@ async def test_async_chat_openai_streaming() -> None:
@pytest.mark.scheduled
def test_openai_streaming(llm: AzureChatOpenAI) -> None:
"""Test streaming tokens from OpenAI."""
for token in llm.stream("I'm Pickle Rick"):
assert isinstance(token.content, str)
full: Optional[BaseMessageChunk] = None
for chunk in llm.stream("I'm Pickle Rick"):
assert isinstance(chunk.content, str)
full = chunk if full is None else full + chunk
@pytest.mark.scheduled
async def test_openai_astream(llm: AzureChatOpenAI) -> None:
"""Test streaming tokens from OpenAI."""
async for token in llm.astream("I'm Pickle Rick"):
assert isinstance(token.content, str)
full: Optional[BaseMessageChunk] = None
async for chunk in llm.astream("I'm Pickle Rick"):
assert isinstance(chunk.content, str)
full = chunk if full is None else full + chunk
@pytest.mark.scheduled

View File

@@ -1,9 +1,15 @@
"""Test ChatOpenAI chat model."""
from typing import Any, Optional
from typing import Any, Optional, cast
import pytest
from langchain_core.callbacks import CallbackManager
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.messages import (
AIMessage,
BaseMessage,
BaseMessageChunk,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import (
ChatGeneration,
ChatResult,
@@ -99,7 +105,7 @@ def test_chat_openai_streaming() -> None:
verbose=True,
)
message = HumanMessage(content="Hello")
response = chat([message])
response = chat.invoke([message])
assert callback_handler.llm_streams > 0
assert isinstance(response, BaseMessage)
@@ -336,16 +342,20 @@ def test_stream() -> None:
"""Test streaming tokens from OpenAI."""
llm = ChatOpenAI()
for token in llm.stream("I'm Pickle Rick"):
assert isinstance(token.content, str)
full: Optional[BaseMessageChunk] = None
for chunk in llm.stream("I'm Pickle Rick"):
assert isinstance(chunk.content, str)
full = chunk if full is None else full + chunk
async def test_astream() -> None:
"""Test streaming tokens from OpenAI."""
llm = ChatOpenAI()
async for token in llm.astream("I'm Pickle Rick"):
assert isinstance(token.content, str)
full: Optional[BaseMessageChunk] = None
async for chunk in llm.astream("I'm Pickle Rick"):
assert isinstance(chunk.content, str)
full = chunk if full is None else full + chunk
async def test_abatch() -> None:
@@ -395,33 +405,33 @@ def test_invoke() -> None:
def test_logprobs() -> None:
llm = ChatOpenAI()
result = llm.generate([[HumanMessage(content="I'm PickleRick")]], logprobs=True)
assert result.generations[0][0].generation_info
assert "content" in result.generations[0][0].generation_info["logprobs"]
result = llm.invoke([HumanMessage(content="I'm PickleRick")], logprobs=True)
assert result.response_metadata
assert "content" in result.response_metadata["logprobs"]
async def test_async_logprobs() -> None:
llm = ChatOpenAI()
result = await llm.agenerate(
[[HumanMessage(content="I'm PickleRick")]], logprobs=True
)
assert result.generations[0][0].generation_info
assert "content" in result.generations[0][0].generation_info["logprobs"]
result = await llm.ainvoke([HumanMessage(content="I'm PickleRick")], logprobs=True)
assert result.response_metadata
assert "content" in result.response_metadata["logprobs"]
def test_logprobs_streaming() -> None:
llm = ChatOpenAI()
result = llm.generate(
[[HumanMessage(content="I'm PickleRick")]], logprobs=True, stream=True
)
assert result.generations[0][0].generation_info
assert "content" in result.generations[0][0].generation_info["logprobs"]
full: Optional[BaseMessageChunk] = None
for chunk in llm.stream("I'm Pickle Rick", logprobs=True):
assert isinstance(chunk.content, str)
full = chunk if full is None else full + chunk
assert cast(BaseMessageChunk, full).response_metadata
assert "content" in cast(BaseMessageChunk, full).response_metadata["logprobs"]
async def test_async_logprobs_streaming() -> None:
llm = ChatOpenAI()
result = await llm.agenerate(
[[HumanMessage(content="I'm PickleRick")]], logprobs=True, stream=True
)
assert result.generations[0][0].generation_info
assert "content" in result.generations[0][0].generation_info["logprobs"]
full: Optional[BaseMessageChunk] = None
async for chunk in llm.astream("I'm Pickle Rick", logprobs=True):
assert isinstance(chunk.content, str)
full = chunk if full is None else full + chunk
assert cast(BaseMessageChunk, full).response_metadata
assert "content" in cast(BaseMessageChunk, full).response_metadata["logprobs"]