Add OpenAIEmbeddings special token params for tiktoken (#2682)
#2681
The original type hints
```python
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), # noqa: B006
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
```
from 46287bfa49/tiktoken/core.py (L79-L80) are not compatible with pydantic:
![pydantic validation error for tiktoken's type hints](https://user-images.githubusercontent.com/5096640/230993236-c744940e-85fb-4baa-b9da-8b00fb60a2a8.png)
I think we could use
```python
allowed_special: Union[Literal["all"], Set[str]] = set()
disallowed_special: Union[Literal["all"], Set[str], Tuple[()]] = "all"
```
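For context, a minimal sketch (the model name is illustrative, not from this PR) showing that the proposed hints resolve cleanly under pydantic:

```python
from typing import Literal, Set, Tuple, Union

from pydantic import BaseModel


class TokenizerParams(BaseModel):
    # Concrete Set/Tuple types are resolvable by pydantic, unlike the
    # AbstractSet/Collection hints it rejects (see the screenshot above).
    allowed_special: Union[Literal["all"], Set[str]] = set()
    disallowed_special: Union[Literal["all"], Set[str], Tuple[()]] = "all"


params = TokenizerParams(allowed_special={"<|endoftext|>"})
print(params.allowed_special)     # {'<|endoftext|>'}
print(params.disallowed_special)  # all
```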
Please let me know if you would like to implement it differently.
parent 1c979e320d
commit 023de9a70b
```diff
@@ -2,7 +2,17 @@
 from __future__ import annotations
 
 import logging
-from typing import Any, Callable, Dict, List, Optional
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+)
 
 import numpy as np
 from pydantic import BaseModel, Extra, root_validator
```
```diff
@@ -99,6 +109,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     embedding_ctx_length: int = 8191
     openai_api_key: Optional[str] = None
     openai_organization: Optional[str] = None
+    allowed_special: Union[Literal["all"], Set[str]] = set()
+    disallowed_special: Union[Literal["all"], Set[str], Tuple[()]] = "all"
     chunk_size: int = 1000
     """Maximum number of texts to embed in each batch"""
     max_retries: int = 6
```
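The new fields let callers tune special-token handling when constructing the embeddings object. A hedged usage sketch (assuming a valid `OPENAI_API_KEY` in the environment; the parameter values are illustrative):

```python
from langchain.embeddings import OpenAIEmbeddings

# Permit <|endoftext|> in input documents: tiktoken encodes it as a
# special token instead of raising.
embeddings = OpenAIEmbeddings(allowed_special={"<|endoftext|>"})

# Or disable the check entirely: special-token text is tokenized as
# ordinary text.
lenient = OpenAIEmbeddings(disallowed_special=())
```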
```diff
@@ -195,7 +207,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         for i, text in enumerate(texts):
             # replace newlines, which can negatively affect performance.
             text = text.replace("\n", " ")
-            token = encoding.encode(text)
+            token = encoding.encode(
+                text,
+                allowed_special=self.allowed_special,
+                disallowed_special=self.disallowed_special,
+            )
             for j in range(0, len(token), self.embedding_ctx_length):
                 tokens += [token[j : j + self.embedding_ctx_length]]
                 indices += [i]
```
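The updated `encoding.encode` call is where these parameters take effect. A small sketch of tiktoken's behavior under each setting (a minimal illustration, not part of the commit):

```python
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
text = "hello <|endoftext|>"

# Default (disallowed_special="all"): encode raises on special-token
# text, which is why documents containing "<|endoftext|>" previously
# broke OpenAIEmbeddings (#2681).
try:
    enc.encode(text)
except ValueError as err:
    print("rejected:", err)

# Explicitly allowed: the marker becomes a single special-token id.
print(enc.encode(text, allowed_special={"<|endoftext|>"}))

# Nothing disallowed: the marker is split into ordinary text tokens.
print(enc.encode(text, disallowed_special=()))
```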