synthetic-data: linting and typing additions

This commit is contained in:
PaperMoose
2023-09-08 17:45:48 -07:00
parent 86430a41b7
commit bb9fc2105a
4 changed files with 144 additions and 94 deletions

View File

@@ -1,8 +1,7 @@
import asyncio
from typing import List, Optional
from typing import Any, Dict, List, Optional, Union
from pydantic.class_validators import root_validator
from pydantic.error_wrappers import ValidationError
from pydantic.main import BaseModel
from langchain.chains.base import Chain
@@ -12,8 +11,7 @@ from langchain.schema.language_model import BaseLanguageModel
class SyntheticDataGenerator(BaseModel):
"""
Generates synthetic data using the given LLM and few-shot template.
"""Generates synthetic data using the given LLM and few-shot template.
Utilizes the provided LLM to produce synthetic data based on the
few-shot prompt template.
@@ -21,7 +19,7 @@ class SyntheticDataGenerator(BaseModel):
Attributes:
template (FewShotPromptTemplate): Template for few-shot prompting.
llm (Optional[BaseLanguageModel]): Large Language Model to use for generation.
llm_chain (Optional[Chain]): LLM chain initialized with the LLM and few-shot template.
llm_chain (Optional[Chain]): LLM chain with the LLM and few-shot template.
example_input_key (str): Key to use for storing example inputs.
Usage Example:
@@ -41,37 +39,43 @@ class SyntheticDataGenerator(BaseModel):
validate_assignment = True
@root_validator(pre=False, skip_on_failure=True)
def set_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
    """Make sure ``llm_chain`` is populated, building one when it is missing.

    Args:
        values: Field values of the model being validated.

    Returns:
        The same ``values`` mapping, with ``"llm_chain"`` filled in when it was
        falsy and both ``"llm"`` and ``"template"`` were supplied.

    Raises:
        ValueError: If ``llm_chain`` is missing and either ``llm`` or
            ``template`` is ``None``.
    """
    llm_chain = values.get("llm_chain")
    llm = values.get("llm")
    few_shot_template = values.get("template")

    if not llm_chain:  # If llm_chain is None or not present
        if llm is None or few_shot_template is None:
            # A plain ValueError is raised here; the validator machinery is
            # expected to wrap it (pydantic's ValidationError is not meant to
            # be constructed from a bare message string).
            raise ValueError(
                "Both llm and few_shot_template must be provided if llm_chain is "
                "not given."
            )
        values["llm_chain"] = LLMChain(llm=llm, prompt=few_shot_template)

    return values
@staticmethod
def _format_dict_to_string(input_dict: Dict) -> str:
    """Render a mapping as ``"key: value"`` pairs joined by ``", "``.

    Args:
        input_dict: Mapping whose items are formatted in iteration order.

    Returns:
        The joined string; an empty mapping yields an empty string.
    """
    return ", ".join(f"{key}: {value}" for key, value in input_dict.items())
def _update_examples(self, example: Union[BaseModel, Dict[str, Any], str]) -> None:
    """Rotate a freshly generated example into the few-shot example list.

    Prevents duplicates by showing the model its own previous output: the
    first entry of ``self.template.examples`` is dropped and the new example
    is appended under ``self.example_input_key``.

    Args:
        example: The generated item. Pydantic models are converted with
            ``.dict()`` and dicts are rendered as ``"key: value"`` pairs;
            anything else is passed through ``str()``.
    """
    # Nothing to rotate when there is no template or it has no examples;
    # popping from an empty list would raise.
    if not (self.template and self.template.examples):
        return

    if isinstance(example, BaseModel):
        formatted_example = self._format_dict_to_string(example.dict())
    elif isinstance(example, dict):
        formatted_example = self._format_dict_to_string(example)
    else:
        formatted_example = str(example)

    self.template.examples.pop(0)
    self.template.examples.append({self.example_input_key: formatted_example})
def generate(self, subject: str, runs: int, *args: Any, **kwargs: Any) -> List[str]:
    """Generate synthetic data using the given subject string.

    Each run's output is appended to ``self.results`` and fed back into the
    few-shot examples via ``_update_examples`` before the next run.

    Args:
        subject (str): The subject the synthetic data will be about.
        runs (int): Number of times the chain is run.
        *args: Extra positional arguments forwarded to ``llm_chain.run``.
        **kwargs: Extra keyword arguments forwarded to ``llm_chain.run``.

    Returns:
        List[str]: List of generated synthetic data. This is ``self.results``
        itself, so it also holds anything appended by earlier calls.

    Raises:
        ValueError: If ``llm_chain`` is ``None``.

    Usage Example:
        >>> results = generator.generate(subject="climate change", runs=5,
        ...                              extra="Focus on environmental impacts.")
    """
    if self.llm_chain is None:
        raise ValueError(
            "llm_chain is none, either set either llm_chain or llm at generator "
            "construction"
        )
    for _ in range(runs):
        # Positional unpacking is written before the keyword argument so the
        # call reads in the order Python actually binds it.
        result = self.llm_chain.run(*args, subject=subject, **kwargs)
        self.results.append(result)
        self._update_examples(result)
    return self.results
async def agenerate(
    self, subject: str, runs: int, extra: str = "", *args: Any, **kwargs: Any
) -> List[str]:
    """Generate synthetic data using the given subject asynchronously.

    Note: Since the LLM calls run concurrently, the few-shot examples are not
    updated between runs. You may have fewer duplicates by adding specific
    instructions to the "extra" keyword argument.

    Args:
        subject (str): The subject the synthetic data will be about.
        runs (int): Number of concurrent chain calls to make.
        extra (str): Extra instructions passed to the chain as ``extra``.
        *args: Extra positional arguments forwarded to ``llm_chain.arun``.
        **kwargs: Extra keyword arguments forwarded to ``llm_chain.arun``.

    Returns:
        List[str]: List of generated synthetic data for the given subject.
        This is ``self.results`` itself; each result is appended when its
        call finishes.

    Raises:
        ValueError: If ``llm_chain`` is ``None``.

    Usage Example:
        >>> results = await generator.agenerate(subject="climate change", runs=5,
        ...                                     extra="Focus on env impacts.")
    """
    # Fail loudly, like ``generate`` does, instead of returning without
    # having produced anything.
    if self.llm_chain is None:
        raise ValueError(
            "llm_chain is none, either set either llm_chain or llm at generator "
            "construction"
        )
    chain = self.llm_chain

    async def run_chain() -> None:
        # Caller-supplied args/kwargs are forwarded to every call; the
        # closure captures them so nothing is dropped on the way.
        result = await chain.arun(*args, subject=subject, extra=extra, **kwargs)
        self.results.append(result)

    await asyncio.gather(*(run_chain() for _ in range(runs)))
    return self.results

View File

@@ -1,52 +1,63 @@
from typing import Optional, Any, Dict, Type, Union
from typing import Any, Dict, Optional, Type, Union
from pydantic.main import BaseModel
from langchain import BasePromptTemplate, PromptTemplate
from langchain.chains.openai_functions import create_structured_output_chain
from langchain.chains.data_generation.base import SyntheticDataGenerator
from langchain.chains.openai_functions import create_structured_output_chain
from langchain.chat_models import ChatOpenAI
from langchain.schema import BaseLLMOutputParser
from langchain.schema.language_model import BaseLanguageModel
# Pass-through prompt: renders the single ``example`` input variable verbatim.
OPENAI_TEMPLATE = PromptTemplate(input_variables=["example"], template="{example}")
def create_openai_data_generator(
output_schema: Union[Dict[str, Any], Type[BaseModel]],
llm: ChatOpenAI,
prompt: BasePromptTemplate,
output_parser: Optional[BaseLLMOutputParser] = None,
**kwargs: Any
output_schema: Union[Dict[str, Any], Type[BaseModel]],
llm: ChatOpenAI,
prompt: BasePromptTemplate,
output_parser: Optional[BaseLLMOutputParser] = None,
**kwargs: Any
) -> SyntheticDataGenerator:
"""
Create an instance of SyntheticDataGenerator tailored for OpenAI models.
This function creates an LLM chain designed for structured output based on the provided schema,
language model, and prompt template. The resulting chain is then used to instantiate and return
a SyntheticDataGenerator.
This function creates an LLM chain designed for structured output based on the
provided schema, language model, and prompt template. The resulting chain is then
used to instantiate and return a SyntheticDataGenerator.
Args:
output_schema (Union[Dict[str, Any], Type[BaseModel]]): Schema for expected output. This can be either
a dictionary representing a valid JsonSchema or a Pydantic BaseModel class.
llm (ChatOpenAI): OpenAI language model to use.
prompt (BasePromptTemplate): Template to be used for generating prompts.
output_parser (Optional[BaseLLMOutputParser], optional): Parser for processing model outputs. If none
is provided, a default will be inferred from the function types.
**kwargs: Additional keyword arguments to be passed to `create_structured_output_chain`.
output_schema (Union[Dict[str, Any], Type[BaseModel]]): Schema for expected
output. This can be either a dictionary representing a valid JsonSchema or a
Pydantic BaseModel class.
Returns:
SyntheticDataGenerator: An instance of the data generator set up with the constructed chain.
llm (ChatOpenAI): OpenAI language model to use.
prompt (BasePromptTemplate): Template to be used for generating prompts.
output_parser (Optional[BaseLLMOutputParser], optional): Parser for
processing model outputs. If none is provided, a default will be inferred
from the function types.
**kwargs: Additional keyword arguments to be passed to
`create_structured_output_chain`.
Returns: SyntheticDataGenerator: An instance of the data generator set up with
the constructed chain.
Usage:
To generate synthetic data with a structured output, first define your desired output schema. Then,
use this function to create a SyntheticDataGenerator instance. After obtaining the generator, you
can utilize its methods to produce the desired synthetic data.
To generate synthetic data with a structured output, first define your desired
output schema. Then, use this function to create a SyntheticDataGenerator
instance. After obtaining the generator, you can utilize its methods to produce
the desired synthetic data.
"""
# Create function calling chain to ensure structured output
chain = create_structured_output_chain(output_schema, llm, prompt, output_parser=output_parser, **kwargs)
chain = create_structured_output_chain(
output_schema, llm, prompt, output_parser=output_parser, **kwargs
)
# Create the SyntheticDataGenerator instance with the created chain
generator = SyntheticDataGenerator(template=prompt, llm_chain=chain)

View File

@@ -5,7 +5,9 @@ DEFAULT_PROMPT = PromptTemplate(
input_variables=[DEFAULT_INPUT_KEY], template="{example}"
)
# Opening line of the few-shot prompt; ``{subject}`` is filled in at format time.
SYNTHETIC_FEW_SHOT_PREFIX = (
    "This is a test about generating synthetic data about {subject}. Examples below:"
)
# Closing instruction of the few-shot prompt; expects ``{subject}`` and ``{extra}``.
SYNTHETIC_FEW_SHOT_SUFFIX = (
    "Now you generate synthetic data about {subject}. Make sure to {extra}:"
)

View File

@@ -1,11 +1,17 @@
import pytest
from pydantic import BaseModel
from langchain import FewShotPromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.chains.data_generation.base import SyntheticDataGenerator
from langchain.chains.data_generation.openai import create_openai_data_generator, OPENAI_TEMPLATE
from langchain.chains.data_generation.prompts import SYNTHETIC_FEW_SHOT_PREFIX, SYNTHETIC_FEW_SHOT_SUFFIX
from pydantic import BaseModel
from langchain.chains.data_generation.openai import (
OPENAI_TEMPLATE,
create_openai_data_generator,
)
from langchain.chains.data_generation.prompts import (
SYNTHETIC_FEW_SHOT_PREFIX,
SYNTHETIC_FEW_SHOT_SUFFIX,
)
from langchain.chat_models import ChatOpenAI
# Define the desired output schema for individual medical billing record
@@ -20,29 +26,35 @@ class MedicalBilling(BaseModel):
# Seed few-shot examples, one flat "field: value" record per entry. Each
# record is assembled from adjacent string literals so that source-level line
# wrapping does not leak newlines or indentation into the prompt text.
examples = [
    {
        "example": (
            "Patient ID: 123456, Patient Name: John Doe, Diagnosis Code: J20.9, "
            "Procedure Code: 99203, Total Charge: $500, Insurance Claim Amount: $350"
        )
    },
    {
        "example": (
            "Patient ID: 789012, Patient Name: Johnson Smith, Diagnosis Code: M54.5, "
            "Procedure Code: 99213, Total Charge: $150, Insurance Claim Amount: $120"
        )
    },
    {
        "example": (
            "Patient ID: 345678, Patient Name: Emily Stone, Diagnosis Code: E11.9, "
            "Procedure Code: 99214, Total Charge: $300, Insurance Claim Amount: $250"
        )
    },
    {
        "example": (
            "Patient ID: 901234, Patient Name: Robert Miles, Diagnosis Code: B07.9, "
            "Procedure Code: 99204, Total Charge: $200, Insurance Claim Amount: $160"
        )
    },
    {
        "example": (
            "Patient ID: 567890, Patient Name: Clara Jensen, Diagnosis Code: F41.9, "
            "Procedure Code: 99205, Total Charge: $450, Insurance Claim Amount: $310"
        )
    },
    {
        "example": (
            "Patient ID: 234567, Patient Name: Alan Turing, Diagnosis Code: G40.909, "
            "Procedure Code: 99215, Total Charge: $220, Insurance Claim Amount: $180"
        )
    },
]
prompt_template = FewShotPromptTemplate(
@@ -55,21 +67,22 @@ prompt_template = FewShotPromptTemplate(
@pytest.fixture(scope="function")
def synthetic_data_generator() -> SyntheticDataGenerator:
    """Build a fresh OpenAI-backed generator for each test function.

    Returns:
        SyntheticDataGenerator: Generator created from the ``MedicalBilling``
        schema and the module-level ``prompt_template``.
    """
    return create_openai_data_generator(
        output_schema=MedicalBilling,
        llm=ChatOpenAI(temperature=1),  # replace with your LLM instance
        prompt=prompt_template,
    )
@pytest.mark.requires("openai")
def test_generate_synthetic(synthetic_data_generator: SyntheticDataGenerator):
synthetic_results = synthetic_data_generator.generate(subject="medical_billing",
extra="""the name must be chosen at random. Make it
something you wouldn't normally choose. The CPT
codes must make sense with the ICD-10 code""",
runs=10)
def test_generate_synthetic(synthetic_data_generator: SyntheticDataGenerator) -> None:
synthetic_results = synthetic_data_generator.generate(
subject="medical_billing",
extra="""the name must be chosen at random. Make it something you wouldn't
normally choose.""",
runs=10,
)
assert len(synthetic_results) == 10
for row in synthetic_results:
assert isinstance(row, MedicalBilling)
@@ -78,11 +91,15 @@ def test_generate_synthetic(synthetic_data_generator: SyntheticDataGenerator):
@pytest.mark.requires("openai")
@pytest.mark.asyncio
async def test_agenerate_synthetic(synthetic_data_generator: SyntheticDataGenerator):
synthetic_results = await synthetic_data_generator.agenerate(subject="medical_billing",
extra="""the name must be chosen at random. Make it
something you wouldn't normally choose.""",
runs=10)
async def test_agenerate_synthetic(
synthetic_data_generator: SyntheticDataGenerator,
) -> None:
synthetic_results = await synthetic_data_generator.agenerate(
subject="medical_billing",
extra="""the name must be chosen at random. Make it something you wouldn't
normally choose.""",
runs=10,
)
assert len(synthetic_results) == 10
for row in synthetic_results:
assert isinstance(row, MedicalBilling)