mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-17 15:35:14 +00:00
(WIP) add HyDE (#393)
Co-authored-by: cameronccohen <cameron.c.cohen@gmail.com> Co-authored-by: Cameron Cohen <cameron.cohen@quantco.com>
This commit is contained in:
57
tests/unit_tests/test_hyde.py
Normal file
57
tests/unit_tests/test_hyde.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Test HyDE."""
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.embeddings.hyde.base import HypotheticalDocumentEmbedder
|
||||
from langchain.embeddings.hyde.prompts import PROMPT_MAP
|
||||
from langchain.llms.base import BaseLLM, LLMResult
|
||||
from langchain.schema import Generation
|
||||
|
||||
|
||||
class FakeEmbeddings(Embeddings):
|
||||
"""Fake embedding class for tests."""
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Return random floats."""
|
||||
return [list(np.random.uniform(0, 1, 10)) for _ in range(10)]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Return random floats."""
|
||||
return list(np.random.uniform(0, 1, 10))
|
||||
|
||||
|
||||
class FakeLLM(BaseLLM, BaseModel):
|
||||
"""Fake LLM wrapper for testing purposes."""
|
||||
|
||||
n: int = 1
|
||||
|
||||
def _generate(
|
||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||
) -> LLMResult:
|
||||
return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "fake"
|
||||
|
||||
|
||||
def test_hyde_from_llm() -> None:
|
||||
"""Test loading HyDE from all prompts."""
|
||||
for key in PROMPT_MAP:
|
||||
embedding = HypotheticalDocumentEmbedder.from_llm(
|
||||
FakeLLM(), FakeEmbeddings(), key
|
||||
)
|
||||
embedding.embed_query("foo")
|
||||
|
||||
|
||||
def test_hyde_from_llm_with_multiple_n() -> None:
|
||||
"""Test loading HyDE from all prompts."""
|
||||
for key in PROMPT_MAP:
|
||||
embedding = HypotheticalDocumentEmbedder.from_llm(
|
||||
FakeLLM(n=8), FakeEmbeddings(), key
|
||||
)
|
||||
embedding.embed_query("foo")
|
Reference in New Issue
Block a user