(WIP) add HyDE (#393)

Co-authored-by: cameronccohen <cameron.c.cohen@gmail.com>
Co-authored-by: Cameron Cohen <cameron.cohen@quantco.com>
This commit is contained in:
Harrison Chase
2022-12-21 20:46:41 -05:00
committed by GitHub
parent 543db9c2df
commit 6b60c509ac
8 changed files with 421 additions and 8 deletions

View File

@@ -0,0 +1,57 @@
"""Test HyDE."""
from typing import List, Optional
import numpy as np
from pydantic import BaseModel
from langchain.embeddings.base import Embeddings
from langchain.embeddings.hyde.base import HypotheticalDocumentEmbedder
from langchain.embeddings.hyde.prompts import PROMPT_MAP
from langchain.llms.base import BaseLLM, LLMResult
from langchain.schema import Generation
class FakeEmbeddings(Embeddings):
"""Fake embedding class for tests."""
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Return random floats."""
return [list(np.random.uniform(0, 1, 10)) for _ in range(10)]
def embed_query(self, text: str) -> List[float]:
"""Return random floats."""
return list(np.random.uniform(0, 1, 10))
class FakeLLM(BaseLLM, BaseModel):
"""Fake LLM wrapper for testing purposes."""
n: int = 1
def _generate(
self, prompts: List[str], stop: Optional[List[str]] = None
) -> LLMResult:
return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "fake"
def test_hyde_from_llm() -> None:
"""Test loading HyDE from all prompts."""
for key in PROMPT_MAP:
embedding = HypotheticalDocumentEmbedder.from_llm(
FakeLLM(), FakeEmbeddings(), key
)
embedding.embed_query("foo")
def test_hyde_from_llm_with_multiple_n() -> None:
"""Test loading HyDE from all prompts."""
for key in PROMPT_MAP:
embedding = HypotheticalDocumentEmbedder.from_llm(
FakeLLM(n=8), FakeEmbeddings(), key
)
embedding.embed_query("foo")