diff --git a/docs/docs/modules/model_io/chat/logprobs.ipynb b/docs/docs/modules/model_io/chat/logprobs.ipynb new file mode 100644 index 00000000000..754f3d126b5 --- /dev/null +++ b/docs/docs/modules/model_io/chat/logprobs.ipynb @@ -0,0 +1,174 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "78b45321-7740-4399-b2ad-459811131de3", + "metadata": {}, + "source": [ + "# Get log probabilities\n", + "\n", + "Certain chat models can be configured to return token-level log probabilities. This guide walks through how to get logprobs for a number of models." + ] + }, + { + "cell_type": "markdown", + "id": "7f5016bf-2a7b-4140-9b80-8c35c7e5c0d5", + "metadata": {}, + "source": [ + "## OpenAI\n", + "\n", + "Install the LangChain x OpenAI integration package and set your API key:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fe5143fe-84d3-4a91-bae8-629807bbe2cb", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install -qU langchain-openai" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fd1a2bff-7ac8-46cb-ab95-72c616b45f2c", + "metadata": {}, + "outputs": [], + "source": [ + "import getpass\n", + "import os\n", + "\n", + "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()" + ] + }, + { + "cell_type": "markdown", + "id": "f88ffa0d-f4a7-482c-88de-cbec501a79b1", + "metadata": {}, + "source": [ + "For the OpenAI API to return log probabilities, we need to configure the `logprobs=True` parameter." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "d1bf0a9a-e402-4931-ab53-32899f8e0326", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_openai import ChatOpenAI\n", + "\n", + "llm = ChatOpenAI(model=\"gpt-3.5-turbo-0125\").bind(logprobs=True)\n", + "\n", + "msg = llm.invoke((\"human\", \"how are you today\"))" + ] + }, + { + "cell_type": "markdown", + "id": "e002c48a-af03-4796-a367-a69c5c8ae0c4", + "metadata": {}, + "source": [ + "The logprobs are included on each output Message as part of the `response_metadata`:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "e3e17872-62df-4b17-a8d4-4cae713a301b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'token': 'As',\n", + " 'bytes': [65, 115],\n", + " 'logprob': -1.5358024,\n", + " 'top_logprobs': []},\n", + " {'token': ' an',\n", + " 'bytes': [32, 97, 110],\n", + " 'logprob': -0.028062303,\n", + " 'top_logprobs': []},\n", + " {'token': ' AI',\n", + " 'bytes': [32, 65, 73],\n", + " 'logprob': -0.009415812,\n", + " 'top_logprobs': []},\n", + " {'token': ',', 'bytes': [44], 'logprob': -0.07371779, 'top_logprobs': []},\n", + " {'token': ' I',\n", + " 'bytes': [32, 73],\n", + " 'logprob': -4.298773e-05,\n", + " 'top_logprobs': []}]" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "msg.response_metadata[\"logprobs\"][\"content\"][:5]" + ] + }, + { + "cell_type": "markdown", + "id": "d1ee1c29-d27e-4353-8c3c-2ed7e7f95ff5", + "metadata": {}, + "source": [ + "They are part of streamed Message chunks as well:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "4bfaf309-3b23-43b7-b333-01fc4848992d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[]\n", + "[{'token': 'As', 'bytes': [65, 115], 'logprob': -1.7523563, 'top_logprobs': []}]\n", + "[{'token': 'As', 'bytes': [65, 115], 'logprob': -1.7523563, 'top_logprobs': []}, {'token': ' an', 'bytes': [32, 97, 110], 'logprob': -0.019908238, 'top_logprobs': []}]\n", +
"[{'token': 'As', 'bytes': [65, 115], 'logprob': -1.7523563, 'top_logprobs': []}, {'token': ' an', 'bytes': [32, 97, 110], 'logprob': -0.019908238, 'top_logprobs': []}, {'token': ' AI', 'bytes': [32, 65, 73], 'logprob': -0.0093033705, 'top_logprobs': []}]\n", + "[{'token': 'As', 'bytes': [65, 115], 'logprob': -1.7523563, 'top_logprobs': []}, {'token': ' an', 'bytes': [32, 97, 110], 'logprob': -0.019908238, 'top_logprobs': []}, {'token': ' AI', 'bytes': [32, 65, 73], 'logprob': -0.0093033705, 'top_logprobs': []}, {'token': ',', 'bytes': [44], 'logprob': -0.08852102, 'top_logprobs': []}]\n" + ] + } + ], + "source": [ + "ct = 0\n", + "full = None\n", + "for chunk in llm.stream((\"human\", \"how are you today\")):\n", + " if ct < 5:\n", + " full = chunk if full is None else full + chunk\n", + " if \"logprobs\" in full.response_metadata:\n", + " print(full.response_metadata[\"logprobs\"][\"content\"])\n", + " else:\n", + " break\n", + " ct += 1" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "poetry-venv-2", + "language": "python", + "name": "poetry-venv-2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index 99c366ef591..dc8b313fde6 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -15,6 +15,7 @@ from typing import ( List, Optional, Sequence, + Union, cast, ) @@ -240,6 +241,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): for chunk in self._stream( messages, stop=stop, run_manager=run_manager, **kwargs ): + chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk) yield chunk.message if generation is None: generation = chunk @@ -317,6 +319,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): async for chunk in _stream_implementation( messages, stop=stop, run_manager=run_manager, **kwargs ): + chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk) yield chunk.message if generation is None: generation = chunk @@ -586,38 +589,35 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: - new_arg_supported = inspect.signature(self._generate).parameters.get( - "run_manager" - ) - disregard_cache = self.cache is not None and not self.cache llm_cache = get_llm_cache() - if llm_cache is None or disregard_cache: - # This happens when langchain.cache is None, but self.cache is True - if self.cache is not None and self.cache: + check_cache = self.cache or self.cache is None + if check_cache: + if llm_cache: + llm_string = self._get_llm_string(stop=stop, **kwargs) + prompt = dumps(messages) + cache_val = llm_cache.lookup(prompt, llm_string) + if isinstance(cache_val, list): + return ChatResult(generations=cache_val) + elif self.cache is None: + pass + else: raise ValueError( "Asked to cache, but no cache found at `langchain.cache`." 
) - if new_arg_supported: - return self._generate( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - else: - return self._generate(messages, stop=stop, **kwargs) + if inspect.signature(self._generate).parameters.get("run_manager"): + result = self._generate( + messages, stop=stop, run_manager=run_manager, **kwargs + ) else: - llm_string = self._get_llm_string(stop=stop, **kwargs) - prompt = dumps(messages) - cache_val = llm_cache.lookup(prompt, llm_string) - if isinstance(cache_val, list): - return ChatResult(generations=cache_val) - else: - if new_arg_supported: - result = self._generate( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - else: - result = self._generate(messages, stop=stop, **kwargs) - llm_cache.update(prompt, llm_string, result.generations) - return result + result = self._generate(messages, stop=stop, **kwargs) + + for generation in result.generations: + generation.message.response_metadata = _gen_info_and_msg_metadata( + generation + ) + if check_cache and llm_cache: + llm_cache.update(prompt, llm_string, result.generations) + return result async def _agenerate_with_cache( self, @@ -626,38 +626,34 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: - new_arg_supported = inspect.signature(self._agenerate).parameters.get( - "run_manager" - ) - disregard_cache = self.cache is not None and not self.cache llm_cache = get_llm_cache() - if llm_cache is None or disregard_cache: - # This happens when langchain.cache is None, but self.cache is True - if self.cache is not None and self.cache: + check_cache = self.cache or self.cache is None + if check_cache: + if llm_cache: + llm_string = self._get_llm_string(stop=stop, **kwargs) + prompt = dumps(messages) + cache_val = await llm_cache.alookup(prompt, llm_string) + if isinstance(cache_val, list): + return ChatResult(generations=cache_val) + elif self.cache is None: + pass + else: raise ValueError( "Asked to cache, but no cache found at `langchain.cache`." 
) - if new_arg_supported: - return await self._agenerate( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - else: - return await self._agenerate(messages, stop=stop, **kwargs) + if inspect.signature(self._agenerate).parameters.get("run_manager"): + result = await self._agenerate( + messages, stop=stop, run_manager=run_manager, **kwargs + ) else: - llm_string = self._get_llm_string(stop=stop, **kwargs) - prompt = dumps(messages) - cache_val = await llm_cache.alookup(prompt, llm_string) - if isinstance(cache_val, list): - return ChatResult(generations=cache_val) - else: - if new_arg_supported: - result = await self._agenerate( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - else: - result = await self._agenerate(messages, stop=stop, **kwargs) - await llm_cache.aupdate(prompt, llm_string, result.generations) - return result + result = await self._agenerate(messages, stop=stop, **kwargs) + for generation in result.generations: + generation.message.response_metadata = _gen_info_and_msg_metadata( + generation + ) + if check_cache and llm_cache: + await llm_cache.aupdate(prompt, llm_string, result.generations) + return result @abstractmethod def _generate( @@ -852,3 +848,12 @@ class SimpleChatModel(BaseChatModel): run_manager=run_manager.get_sync() if run_manager else None, **kwargs, ) + + +def _gen_info_and_msg_metadata( + generation: Union[ChatGeneration, ChatGenerationChunk], +) -> dict: + return { + **(generation.generation_info or {}), + **generation.message.response_metadata, + } diff --git a/libs/core/langchain_core/messages/ai.py b/libs/core/langchain_core/messages/ai.py index c667aa54e95..3db98588342 100644 --- a/libs/core/langchain_core/messages/ai.py +++ b/libs/core/langchain_core/messages/ai.py @@ -5,6 +5,7 @@ from langchain_core.messages.base import ( BaseMessageChunk, merge_content, ) +from langchain_core.utils._merge import merge_dicts class AIMessage(BaseMessage): @@ -49,9 +50,12 @@ class AIMessageChunk(AIMessage, BaseMessageChunk): return self.__class__( example=self.example, content=merge_content(self.content, other.content), - additional_kwargs=self._merge_kwargs_dict( + additional_kwargs=merge_dicts( self.additional_kwargs, other.additional_kwargs ), + response_metadata=merge_dicts( + self.response_metadata, other.response_metadata + ), ) return super().__add__(other) diff --git a/libs/core/langchain_core/messages/base.py b/libs/core/langchain_core/messages/base.py index fdae1dbf8e4..26d9558e36f 100644 --- a/libs/core/langchain_core/messages/base.py +++ b/libs/core/langchain_core/messages/base.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union from langchain_core.load.serializable import Serializable from langchain_core.pydantic_v1 import Extra, Field from langchain_core.utils import get_bolded_text +from langchain_core.utils._merge import merge_dicts from langchain_core.utils.interactive_env import is_interactive_env if TYPE_CHECKING: @@ -114,54 +115,6 @@ class BaseMessageChunk(BaseMessage): """Get the namespace of the langchain object.""" return ["langchain", "schema", "messages"] - def _merge_kwargs_dict( - self, left: Dict[str, Any], right: Dict[str, Any] - ) -> Dict[str, Any]: - """Merge additional_kwargs from another BaseMessageChunk into this one, - handling specific scenarios where a key exists in both dictionaries - but has a value of None in 'left'. In such cases, the method uses the - value from 'right' for that key in the merged dictionary. 
- Example: - If left = {"function_call": {"arguments": None}} and - right = {"function_call": {"arguments": "{\n"}} - then, after merging, for the key "function_call", - the value from 'right' is used, - resulting in merged = {"function_call": {"arguments": "{\n"}}. - """ - merged = left.copy() - for k, v in right.items(): - if k not in merged: - merged[k] = v - elif merged[k] is None and v: - merged[k] = v - elif v is None: - continue - elif merged[k] == v: - continue - elif type(merged[k]) != type(v): - raise TypeError( - f'additional_kwargs["{k}"] already exists in this message,' - " but with a different type." - ) - elif isinstance(merged[k], str): - merged[k] += v - elif isinstance(merged[k], dict): - merged[k] = self._merge_kwargs_dict(merged[k], v) - elif isinstance(merged[k], list): - merged[k] = merged[k].copy() - for i, e in enumerate(v): - if isinstance(e, dict) and isinstance(e.get("index"), int): - i = e["index"] - if i < len(merged[k]): - merged[k][i] = self._merge_kwargs_dict(merged[k][i], e) - else: - merged[k] = merged[k] + [e] - else: - raise TypeError( - f"Additional kwargs key {k} already exists in this message." - ) - return merged - def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore if isinstance(other, BaseMessageChunk): # If both are (subclasses of) BaseMessageChunk, @@ -170,9 +123,12 @@ class BaseMessageChunk(BaseMessage): return self.__class__( # type: ignore[call-arg] id=self.id, content=merge_content(self.content, other.content), - additional_kwargs=self._merge_kwargs_dict( + additional_kwargs=merge_dicts( self.additional_kwargs, other.additional_kwargs ), + response_metadata=merge_dicts( + self.response_metadata, other.response_metadata + ), ) else: raise TypeError( diff --git a/libs/core/langchain_core/messages/chat.py b/libs/core/langchain_core/messages/chat.py index 3c7ed975b65..fad0d265ea6 100644 --- a/libs/core/langchain_core/messages/chat.py +++ b/libs/core/langchain_core/messages/chat.py @@ -5,6 +5,7 @@ from langchain_core.messages.base import ( BaseMessageChunk, merge_content, ) +from langchain_core.utils._merge import merge_dicts class ChatMessage(BaseMessage): @@ -47,17 +48,23 @@ class ChatMessageChunk(ChatMessage, BaseMessageChunk): return self.__class__( role=self.role, content=merge_content(self.content, other.content), - additional_kwargs=self._merge_kwargs_dict( + additional_kwargs=merge_dicts( self.additional_kwargs, other.additional_kwargs ), + response_metadata=merge_dicts( + self.response_metadata, other.response_metadata + ), ) elif isinstance(other, BaseMessageChunk): return self.__class__( role=self.role, content=merge_content(self.content, other.content), - additional_kwargs=self._merge_kwargs_dict( + additional_kwargs=merge_dicts( self.additional_kwargs, other.additional_kwargs ), + response_metadata=merge_dicts( + self.response_metadata, other.response_metadata + ), ) else: return super().__add__(other) diff --git a/libs/core/langchain_core/messages/function.py b/libs/core/langchain_core/messages/function.py index e852aa37276..a98242756a6 100644 --- a/libs/core/langchain_core/messages/function.py +++ b/libs/core/langchain_core/messages/function.py @@ -5,6 +5,7 @@ from langchain_core.messages.base import ( BaseMessageChunk, merge_content, ) +from langchain_core.utils._merge import merge_dicts class FunctionMessage(BaseMessage): @@ -47,9 +48,12 @@ class FunctionMessageChunk(FunctionMessage, BaseMessageChunk): return self.__class__( name=self.name, content=merge_content(self.content, other.content), - 
additional_kwargs=self._merge_kwargs_dict( + additional_kwargs=merge_dicts( self.additional_kwargs, other.additional_kwargs ), + response_metadata=merge_dicts( + self.response_metadata, other.response_metadata + ), ) return super().__add__(other) diff --git a/libs/core/langchain_core/messages/tool.py b/libs/core/langchain_core/messages/tool.py index a83894a10f7..9891cab771e 100644 --- a/libs/core/langchain_core/messages/tool.py +++ b/libs/core/langchain_core/messages/tool.py @@ -5,6 +5,7 @@ from langchain_core.messages.base import ( BaseMessageChunk, merge_content, ) +from langchain_core.utils._merge import merge_dicts class ToolMessage(BaseMessage): @@ -47,9 +48,12 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk): return self.__class__( tool_call_id=self.tool_call_id, content=merge_content(self.content, other.content), - additional_kwargs=self._merge_kwargs_dict( + additional_kwargs=merge_dicts( self.additional_kwargs, other.additional_kwargs ), + response_metadata=merge_dicts( + self.response_metadata, other.response_metadata + ), ) return super().__add__(other) diff --git a/libs/core/langchain_core/utils/_merge.py b/libs/core/langchain_core/utils/_merge.py index 13b91270c70..27dbbdd5ac5 100644 --- a/libs/core/langchain_core/utils/_merge.py +++ b/libs/core/langchain_core/utils/_merge.py @@ -16,27 +16,44 @@ def merge_dicts(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]: resulting in merged = {"function_call": {"arguments": "{\n"}}. """ merged = left.copy() - for k, v in right.items(): - if k not in merged: - merged[k] = v - elif v is not None and merged[k] is None: - merged[k] = v - elif v is None or merged[k] == v: + for right_k, right_v in right.items(): + if right_k not in merged: + merged[right_k] = right_v + elif right_v is not None and merged[right_k] is None: + merged[right_k] = right_v + elif right_v is None: continue - elif type(merged[k]) != type(v): + elif type(merged[right_k]) != type(right_v): raise TypeError( - f'additional_kwargs["{k}"] already exists in this message,' + f'additional_kwargs["{right_k}"] already exists in this message,' " but with a different type." ) - elif isinstance(merged[k], str): - merged[k] += v - elif isinstance(merged[k], dict): - merged[k] = merge_dicts(merged[k], v) - elif isinstance(merged[k], list): - merged[k] = merged[k] + v + elif isinstance(merged[right_k], str): + merged[right_k] += right_v + elif isinstance(merged[right_k], dict): + merged[right_k] = merge_dicts(merged[right_k], right_v) + elif isinstance(merged[right_k], list): + merged[right_k] = merged[right_k].copy() + for e in right_v: + if isinstance(e, dict) and "index" in e and isinstance(e["index"], int): + to_merge = [ + i + for i, e_left in enumerate(merged[right_k]) + if e_left.get("index") == e["index"] + ] + if to_merge: + merged[right_k][to_merge[0]] = merge_dicts( + merged[right_k][to_merge[0]], e + ) + else: + merged[right_k] = merged[right_k] + [e] + else: + merged[right_k] = merged[right_k] + [e] + elif merged[right_k] == right_v: + continue else: raise TypeError( - f"Additional kwargs key {k} already exists in left dict and value has " - f"unsupported type {type(merged[k])}." + f"Additional kwargs key {right_k} already exists in left dict and " + f"value has unsupported type {type(merged[right_k])}."
) return merged diff --git a/libs/core/tests/unit_tests/utils/test_utils.py b/libs/core/tests/unit_tests/utils/test_utils.py index a69ea2510b4..9cffdf80196 100644 --- a/libs/core/tests/unit_tests/utils/test_utils.py +++ b/libs/core/tests/unit_tests/utils/test_utils.py @@ -48,9 +48,9 @@ def test_check_package_version( ({"a": 1.5}, {"a": 1.5}, {"a": 1.5}), ({"a": True}, {"a": True}, {"a": True}), ({"a": False}, {"a": False}, {"a": False}), - ({"a": "txt"}, {"a": "txt"}, {"a": "txt"}), - ({"a": [1, 2]}, {"a": [1, 2]}, {"a": [1, 2]}), - ({"a": {"b": "txt"}}, {"a": {"b": "txt"}}, {"a": {"b": "txt"}}), + ({"a": "txt"}, {"a": "txt"}, {"a": "txttxt"}), + ({"a": [1, 2]}, {"a": [1, 2]}, {"a": [1, 2, 1, 2]}), + ({"a": {"b": "txt"}}, {"a": {"b": "txt"}}, {"a": {"b": "txttxt"}}), # Merge strings. ({"a": "one"}, {"a": "two"}, {"a": "onetwo"}), # Merge dicts. @@ -89,6 +89,17 @@ def test_check_package_version( ), ), ), + # 'index' keyword has special handling + ( + {"a": [{"index": 0, "b": "{"}]}, + {"a": [{"index": 0, "b": "f"}]}, + {"a": [{"index": 0, "b": "{f"}]}, + ), + ( + {"a": [{"idx": 0, "b": "{"}]}, + {"a": [{"idx": 0, "b": "f"}]}, + {"a": [{"idx": 0, "b": "{"}, {"idx": 0, "b": "f"}]}, + ), ), ) def test_merge_dicts( diff --git a/libs/langchain/tests/unit_tests/test_cache.py b/libs/langchain/tests/unit_tests/test_cache.py index 5b2e1a4da4e..42cc6c36b68 100644 --- a/libs/langchain/tests/unit_tests/test_cache.py +++ b/libs/langchain/tests/unit_tests/test_cache.py @@ -14,10 +14,7 @@ from langchain_core.outputs import ChatGeneration, Generation from sqlalchemy import create_engine from sqlalchemy.orm import Session -from langchain.cache import ( - InMemoryCache, - SQLAlchemyCache, -) +from langchain.cache import InMemoryCache, SQLAlchemyCache from langchain.globals import get_llm_cache, set_llm_cache @@ -67,7 +64,7 @@ async def test_llm_caching() -> None: llm_string=create_llm_string(llm), return_val=[Generation(text=cached_response)], ) - assert llm(prompt) == cached_response + assert llm.invoke(prompt) == cached_response # async test await llm_cache.aupdate( prompt=prompt, diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py b/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py index 6cb6ec95f3d..24a5e26ee79 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py @@ -1,10 +1,10 @@ """Test AzureChatOpenAI wrapper.""" import os -from typing import Any +from typing import Any, Optional import pytest from langchain_core.callbacks import CallbackManager -from langchain_core.messages import BaseMessage, HumanMessage +from langchain_core.messages import BaseMessage, BaseMessageChunk, HumanMessage from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult from langchain_openai import AzureChatOpenAI @@ -164,16 +164,20 @@ async def test_async_chat_openai_streaming() -> None: @pytest.mark.scheduled def test_openai_streaming(llm: AzureChatOpenAI) -> None: """Test streaming tokens from OpenAI.""" - - for token in llm.stream("I'm Pickle Rick"): - assert isinstance(token.content, str) + full: Optional[BaseMessageChunk] = None + for chunk in llm.stream("I'm Pickle Rick"): + assert isinstance(chunk.content, str) + full = chunk if full is None else full + chunk @pytest.mark.scheduled async def test_openai_astream(llm: AzureChatOpenAI) -> None: """Test streaming tokens from OpenAI.""" - async for token in llm.astream("I'm Pickle Rick"): - assert 
isinstance(token.content, str) + + full: Optional[BaseMessageChunk] = None + async for chunk in llm.astream("I'm Pickle Rick"): + assert isinstance(chunk.content, str) + full = chunk if full is None else full + chunk @pytest.mark.scheduled diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py index 33006a624cf..462e2531cf6 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py @@ -1,9 +1,15 @@ """Test ChatOpenAI chat model.""" -from typing import Any, Optional +from typing import Any, Optional, cast import pytest from langchain_core.callbacks import CallbackManager -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage +from langchain_core.messages import ( + AIMessage, + BaseMessage, + BaseMessageChunk, + HumanMessage, + SystemMessage, +) from langchain_core.outputs import ( ChatGeneration, ChatResult, @@ -99,7 +105,7 @@ def test_chat_openai_streaming() -> None: verbose=True, ) message = HumanMessage(content="Hello") - response = chat([message]) + response = chat.invoke([message]) assert callback_handler.llm_streams > 0 assert isinstance(response, BaseMessage) @@ -336,16 +342,20 @@ def test_stream() -> None: """Test streaming tokens from OpenAI.""" llm = ChatOpenAI() - for token in llm.stream("I'm Pickle Rick"): - assert isinstance(token.content, str) + full: Optional[BaseMessageChunk] = None + for chunk in llm.stream("I'm Pickle Rick"): + assert isinstance(chunk.content, str) + full = chunk if full is None else full + chunk async def test_astream() -> None: """Test streaming tokens from OpenAI.""" llm = ChatOpenAI() - async for token in llm.astream("I'm Pickle Rick"): - assert isinstance(token.content, str) + full: Optional[BaseMessageChunk] = None + async for chunk in llm.astream("I'm Pickle Rick"): + assert isinstance(chunk.content, str) + full = chunk if full is None else full + chunk async def test_abatch() -> None: @@ -395,33 +405,33 @@ def test_invoke() -> None: def test_logprobs() -> None: llm = ChatOpenAI() - result = llm.generate([[HumanMessage(content="I'm PickleRick")]], logprobs=True) - assert result.generations[0][0].generation_info - assert "content" in result.generations[0][0].generation_info["logprobs"] + result = llm.invoke([HumanMessage(content="I'm PickleRick")], logprobs=True) + assert result.response_metadata + assert "content" in result.response_metadata["logprobs"] async def test_async_logprobs() -> None: llm = ChatOpenAI() - result = await llm.agenerate( - [[HumanMessage(content="I'm PickleRick")]], logprobs=True - ) - assert result.generations[0][0].generation_info - assert "content" in result.generations[0][0].generation_info["logprobs"] + result = await llm.ainvoke([HumanMessage(content="I'm PickleRick")], logprobs=True) + assert result.response_metadata + assert "content" in result.response_metadata["logprobs"] def test_logprobs_streaming() -> None: llm = ChatOpenAI() - result = llm.generate( - [[HumanMessage(content="I'm PickleRick")]], logprobs=True, stream=True - ) - assert result.generations[0][0].generation_info - assert "content" in result.generations[0][0].generation_info["logprobs"] + full: Optional[BaseMessageChunk] = None + for chunk in llm.stream("I'm Pickle Rick", logprobs=True): + assert isinstance(chunk.content, str) + full = chunk if full is None else full + chunk + assert cast(BaseMessageChunk, 
full).response_metadata + assert "content" in cast(BaseMessageChunk, full).response_metadata["logprobs"] async def test_async_logprobs_streaming() -> None: llm = ChatOpenAI() - result = await llm.agenerate( - [[HumanMessage(content="I'm PickleRick")]], logprobs=True, stream=True - ) - assert result.generations[0][0].generation_info - assert "content" in result.generations[0][0].generation_info["logprobs"] + full: Optional[BaseMessageChunk] = None + async for chunk in llm.astream("I'm Pickle Rick", logprobs=True): + assert isinstance(chunk.content, str) + full = chunk if full is None else full + chunk + assert cast(BaseMessageChunk, full).response_metadata + assert "content" in cast(BaseMessageChunk, full).response_metadata["logprobs"]
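
Note on the `_generate_with_cache` / `_agenerate_with_cache` refactor above: the old nested cache branches collapse into a single `check_cache` gate. The sketch below restates that decision logic outside the class so it is easier to follow; `should_use_cache` is a hypothetical helper for illustration, not part of the diff.

from typing import Any, Optional

def should_use_cache(self_cache: Optional[bool], llm_cache: Any) -> bool:
    """Restatement of the cache gate; mirrors the branches in _generate_with_cache."""
    check_cache = self_cache or self_cache is None
    if not check_cache:
        # self.cache is explicitly False: never read or write the cache.
        return False
    if llm_cache:
        # A global cache is configured: look up first, update after generating.
        return True
    if self_cache is None:
        # self.cache unset and no global cache configured: quietly skip caching.
        return False
    # self.cache is True but no global cache exists: a configuration error.
    raise ValueError("Asked to cache, but no cache found at `langchain.cache`.")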
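
Because every message-chunk `__add__` now merges `response_metadata` with `merge_dicts`, metadata such as logprobs accumulates across streamed chunks the same way `additional_kwargs` always has. A minimal sketch of the effect, assuming `langchain-core` at this revision is installed:

from langchain_core.messages import AIMessageChunk

left = AIMessageChunk(
    content="Hello",
    response_metadata={"logprobs": {"content": [{"token": "Hello"}]}},
)
right = AIMessageChunk(
    content="!",
    response_metadata={"logprobs": {"content": [{"token": "!"}]}},
)

merged = left + right
assert merged.content == "Hello!"
# The nested "logprobs" dicts merge and their "content" lists append, so the
# accumulated chunk carries an entry for every token streamed so far.
tokens = [t["token"] for t in merged.response_metadata["logprobs"]["content"]]
assert tokens == ["Hello", "!"]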
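
Finally, the rewritten `merge_dicts` gives list elements a special case: dicts carrying an integer "index" are merged into the matching element on the left rather than appended, which is what lets streamed tool- and function-call fragments knit back together. A short sketch restating the behavior that the new unit tests pin down (`merge_dicts` lives in a private `_merge` module, so importing it directly is for illustration only):

from langchain_core.utils._merge import merge_dicts

# Strings concatenate and nested dicts merge recursively.
assert merge_dicts({"a": "one"}, {"a": "two"}) == {"a": "onetwo"}

# List elements that are dicts with a matching integer "index" merge in place...
left = {"a": [{"index": 0, "b": "{"}]}
right = {"a": [{"index": 0, "b": "f"}]}
assert merge_dicts(left, right) == {"a": [{"index": 0, "b": "{f"}]}

# ...while elements without an "index" key are simply appended.
assert merge_dicts(
    {"a": [{"idx": 0, "b": "{"}]}, {"a": [{"idx": 0, "b": "f"}]}
) == {"a": [{"idx": 0, "b": "{"}, {"idx": 0, "b": "f"}]}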