Use serialized format for messages in tracer (#6827)

<!-- Thank you for contributing to LangChain!

Replace this comment with:
  - Description: a description of the change, 
  - Issue: the issue # it fixes (if applicable),
  - Dependencies: any dependencies required for this change,
- Tag maintainer: for a quicker response, tag the relevant maintainer
(see below),
- Twitter handle: we announce bigger features on Twitter. If your PR
gets announced and you'd like a mention, we'll gladly shout you out!

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
  2. an example notebook showing its use.

Maintainer responsibilities:
  - General / Misc / if you don't know who to tag: @dev2049
  - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev
  - Models / Prompts: @hwchase17, @dev2049
  - Memory: @hwchase17
  - Agents / Tools / Toolkits: @vowelparrot
  - Tracing / Callbacks: @agola11
  - Async: @agola11

If no one reviews your PR within a few days, feel free to @-mention the
same people again.

See contribution guidelines for more information on how to write/run
tests, lint, etc:
https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md
 -->
This commit is contained in:
Nuno Campos 2023-07-04 10:19:08 +01:00 committed by GitHub
parent 0b69a7e9ab
commit 696886f397
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 17 additions and 11 deletions

View File

@ -4,12 +4,14 @@ from __future__ import annotations
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Optional, Sequence, Union from typing import Any, Dict, List, Optional, Sequence, Union, cast
from uuid import UUID from uuid import UUID
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum from langchain.callbacks.tracers.schemas import Run, RunTypeEnum
from langchain.schema import Document, LLMResult from langchain.load.dump import dumpd
from langchain.schema.document import Document
from langchain.schema.output import ChatGeneration, LLMResult
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -143,6 +145,13 @@ class BaseTracer(BaseCallbackHandler, ABC):
if llm_run is None or llm_run.run_type != RunTypeEnum.llm: if llm_run is None or llm_run.run_type != RunTypeEnum.llm:
raise TracerException("No LLM Run found to be traced") raise TracerException("No LLM Run found to be traced")
llm_run.outputs = response.dict() llm_run.outputs = response.dict()
for i, generations in enumerate(response.generations):
for j, generation in enumerate(generations):
output_generation = llm_run.outputs["generations"][i][j]
if "message" in output_generation:
output_generation["message"] = dumpd(
cast(ChatGeneration, generation).message
)
llm_run.end_time = datetime.utcnow() llm_run.end_time = datetime.utcnow()
llm_run.events.append({"name": "end", "time": llm_run.end_time}) llm_run.events.append({"name": "end", "time": llm_run.end_time})
self._end_trace(llm_run) self._end_trace(llm_run)

View File

@ -11,13 +11,10 @@ from uuid import UUID
from langchainplus_sdk import LangChainPlusClient from langchainplus_sdk import LangChainPlusClient
from langchain.callbacks.tracers.base import BaseTracer from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import ( from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSession
Run,
RunTypeEnum,
TracerSession,
)
from langchain.env import get_runtime_environment from langchain.env import get_runtime_environment
from langchain.schema.messages import BaseMessage, messages_to_dict from langchain.load.dump import dumpd
from langchain.schema.messages import BaseMessage
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_LOGGED = set() _LOGGED = set()
@ -83,7 +80,7 @@ class LangChainTracer(BaseTracer):
id=run_id, id=run_id,
parent_run_id=parent_run_id, parent_run_id=parent_run_id,
serialized=serialized, serialized=serialized,
inputs={"messages": [messages_to_dict(batch) for batch in messages]}, inputs={"messages": [[dumpd(msg) for msg in batch] for batch in messages]},
extra=kwargs, extra=kwargs,
events=[{"name": "start", "time": start_time}], events=[{"name": "start", "time": start_time}],
start_time=start_time, start_time=start_time,

View File

@ -15,8 +15,8 @@ from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
def test_openai_call() -> None: def test_openai_call() -> None:
"""Test valid call to openai.""" """Test valid call to openai."""
llm = OpenAI(max_tokens=10) llm = OpenAI(max_tokens=10, n=3)
output = llm("Say foo:") output = llm("Say something nice:")
assert isinstance(output, str) assert isinstance(output, str)