Compare commits

...

2 Commits

Author SHA1 Message Date
William Fu-Hinthorn
e08cf19a0c Merge branch 'master' into wfh/add_llm_output_to_adapter 2023-11-07 11:19:37 -08:00
William Fu-Hinthorn
2153f924d6 Add LLM Output to Adapter Result 2023-11-07 09:24:36 -08:00
2 changed files with 24 additions and 4 deletions

View File

@@ -100,7 +100,7 @@ def convert_message_to_dict(message: BaseMessage) -> dict:
raise TypeError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict
return {**message.additional_kwargs, **message_dict}
def convert_openai_messages(messages: Sequence[Dict[str, Any]]) -> List[BaseMessage]:
@@ -171,12 +171,18 @@ class ChatCompletion:
**kwargs: Any,
) -> Union[dict, Iterable]:
models = importlib.import_module("langchain.chat_models")
callbacks = importlib.import_module("langchain.callbacks.manager")
model_cls = getattr(models, provider)
model_config = model_cls(**kwargs)
converted_messages = convert_openai_messages(messages)
if not stream:
result = model_config.invoke(converted_messages)
return {"choices": [{"message": convert_message_to_dict(result)}]}
with callbacks.collect_runs() as cb:
result = model_config.invoke(converted_messages, {"callbacks": [cb]})
run = cb.traced_runs[0]
return {
**(run.outputs.get("llm_output") or {}),
"choices": [{"message": convert_message_to_dict(result)}],
}
else:
return (
_convert_message_chunk_to_delta(c, i)

View File

@@ -345,6 +345,7 @@ class ChatOpenAI(BaseChatModel):
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
overall_token_usage: dict = {}
system_fingerprint = None
combined = {}
for output in llm_outputs:
if output is None:
# Happens in streaming
@@ -357,7 +358,14 @@ class ChatOpenAI(BaseChatModel):
overall_token_usage[k] = v
if system_fingerprint is None:
system_fingerprint = output.get("system_fingerprint")
combined = {"token_usage": overall_token_usage, "model_name": self.model_name}
for k, v in output.items():
if k in {"token_usage", "system_fingerprint"}:
continue
if k not in combined:
combined[k] = v
combined.update(
{"token_usage": overall_token_usage, "model_name": self.model_name}
)
if system_fingerprint:
combined["system_fingerprint"] = system_fingerprint
return combined
@@ -438,10 +446,16 @@ class ChatOpenAI(BaseChatModel):
)
generations.append(gen)
token_usage = response.get("usage", {})
output_kwargs = {
k: v
for k, v in response.items()
if k not in {"choices", "system_fingerprint", "usage"}
}
llm_output = {
"token_usage": token_usage,
"model_name": self.model_name,
"system_fingerprint": response.get("system_fingerprint", ""),
**output_kwargs,
}
return ChatResult(generations=generations, llm_output=llm_output)