From ac9609f58f3a6827c754d03392a31c8a77744f7c Mon Sep 17 00:00:00 2001 From: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Date: Wed, 13 Sep 2023 16:49:16 -0400 Subject: [PATCH] fix: unify generation outputs on newer openllm release (#10523) update newer generation format from OpenLLm where it returns a dictionary for one shot generation cc @baskaryan Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --------- Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --- libs/langchain/langchain/llms/openllm.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/libs/langchain/langchain/llms/openllm.py b/libs/langchain/langchain/llms/openllm.py index df8d4bc3818..d0d70f1494f 100644 --- a/libs/langchain/langchain/llms/openllm.py +++ b/libs/langchain/langchain/llms/openllm.py @@ -265,10 +265,16 @@ class OpenLLM(LLM): self._identifying_params["model_name"], **copied ) if self._client: - return self._client.query(prompt, **config.model_dump(flatten=True)) + o = self._client.query(prompt, **config.model_dump(flatten=True)) + if isinstance(o, dict) and "text" in o: + return o["text"] + return o else: assert self._runner is not None - return self._runner(prompt, **config.model_dump(flatten=True)) + o = self._runner(prompt, **config.model_dump(flatten=True)) + if isinstance(o, dict) and "text" in o: + return o["text"] + return o async def _acall( self, @@ -291,9 +297,12 @@ class OpenLLM(LLM): self._identifying_params["model_name"], **copied ) if self._client: - return await self._client.acall( + o = await self._client.acall( "generate", prompt, **config.model_dump(flatten=True) ) + if isinstance(o, dict) and "text" in o: + return o["text"] + return o else: assert self._runner is not None ( @@ -304,6 +313,9 @@ class OpenLLM(LLM): generated_result = await self._runner.generate.async_run( prompt, **generate_kwargs ) - return self._runner.llm.postprocess_generate( + o = self._runner.llm.postprocess_generate( prompt, generated_result, **postprocess_kwargs ) + if isinstance(o, dict) and "text" in o: + return o["text"] + return o