mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-18 10:43:36 +00:00
fix vertexai codey models (#12173)
**Description:** This PR fixes issue #12156 by checking for Codey models appropriately before result parsing. Maintainer: @hwchase17 , @agola11
This commit is contained in:
parent
386ea48432
commit
a2840a2b42
@ -299,7 +299,12 @@ class VertexAI(_VertexAICommon, BaseLLM):
|
||||
res = completion_with_retry(
|
||||
self, prompt, run_manager=run_manager, **params
|
||||
)
|
||||
generations.append([_response_to_generation(r) for r in res.candidates])
|
||||
if self.is_codey_model:
|
||||
generations.append([_response_to_generation(res)])
|
||||
else:
|
||||
generations.append(
|
||||
[_response_to_generation(r) for r in res.candidates]
|
||||
)
|
||||
return LLMResult(generations=generations)
|
||||
|
||||
async def _agenerate(
|
||||
|
@ -36,6 +36,15 @@ def test_vertex_generate() -> None:
|
||||
assert len(output.generations[0]) == 2
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_vertex_generate_code() -> None:
|
||||
llm = VertexAI(temperature=0.3, n=2, model_name="code-bison@001")
|
||||
output = llm.generate(["generate a python method that says foo:"])
|
||||
assert isinstance(output, LLMResult)
|
||||
assert len(output.generations) == 1
|
||||
assert len(output.generations[0]) == 2
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
@pytest.mark.asyncio
|
||||
async def test_vertex_agenerate() -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user