Add OpenAIEmbeddings special token params for tiktoken (#2682)
Addresses #2681.
The original type hints
```python
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), # noqa: B006
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
```
from 46287bfa49/tiktoken/core.py (L79-L80) are not compatible with pydantic:
<img width="718" alt="image"
src="https://user-images.githubusercontent.com/5096640/230993236-c744940e-85fb-4baa-b9da-8b00fb60a2a8.png">
I think we could use
```python
allowed_special: Union[Literal["all"], Set[str]] = set()
disallowed_special: Union[Literal["all"], Set[str], Tuple[()]] = "all"
```
Please let me know if you would like to implement it differently.
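As a quick sanity check, the proposed hints validate under pydantic v1 (the `Demo` model name below is illustrative only, not part of the change):

```python
from typing import Literal, Set, Tuple, Union

from pydantic import BaseModel


class Demo(BaseModel):
    # Same field declarations as proposed above.
    allowed_special: Union[Literal["all"], Set[str]] = set()
    disallowed_special: Union[Literal["all"], Set[str], Tuple[()]] = "all"


# All of these should validate cleanly.
Demo()
Demo(allowed_special="all")
Demo(allowed_special={"<|endoftext|>"}, disallowed_special=())
```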
This commit is contained in: parent 1c979e320d, commit 023de9a70b.
The resulting change:

```diff
@@ -2,7 +2,17 @@
 from __future__ import annotations
 
 import logging
-from typing import Any, Callable, Dict, List, Optional
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+)
 
 import numpy as np
 from pydantic import BaseModel, Extra, root_validator
@@ -99,6 +109,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     embedding_ctx_length: int = 8191
     openai_api_key: Optional[str] = None
     openai_organization: Optional[str] = None
+    allowed_special: Union[Literal["all"], Set[str]] = set()
+    disallowed_special: Union[Literal["all"], Set[str], Tuple[()]] = "all"
     chunk_size: int = 1000
     """Maximum number of texts to embed in each batch"""
     max_retries: int = 6
@@ -195,7 +207,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         for i, text in enumerate(texts):
             # replace newlines, which can negatively affect performance.
             text = text.replace("\n", " ")
-            token = encoding.encode(text)
+            token = encoding.encode(
+                text,
+                allowed_special=self.allowed_special,
+                disallowed_special=self.disallowed_special,
+            )
             for j in range(0, len(token), self.embedding_ctx_length):
                 tokens += [token[j : j + self.embedding_ctx_length]]
             indices += [i]
```
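With this change, callers can forward special-token handling to tiktoken when `OpenAIEmbeddings` chunks long inputs. A hedged usage sketch (the API key is a placeholder, and the new parameters only take effect on texts that reach the tiktoken chunking path):

```python
from langchain.embeddings import OpenAIEmbeddings

# Allow "<|endoftext|>" to be encoded instead of raising during chunking.
embeddings = OpenAIEmbeddings(
    openai_api_key="sk-...",  # placeholder
    allowed_special={"<|endoftext|>"},
)
vectors = embeddings.embed_documents(["long document ... <|endoftext|> ..."])
```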