
Commit 89153d7

Sync codebase
1 parent 55c8d83 · commit 89153d7

5 files changed: +53, -16 lines

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
@@ -2,6 +2,15 @@
 
 This is the changelog for the open source version of tiktoken.
 
+## [v0.6.0]
+- Optimise regular expressions for a 20% performance improvement
+- Add `text-embedding-3-*` models to `encoding_for_model`
+- Check content hash for downloaded files
+- Allow pickling `Encoding` objects. Registered `Encoding` will be pickled by reference
+- Workaround PyO3 bug for frozenset conversion
+
+Thank you to @paplorinc, @mdwelsh, @Praneet460!
+
 ## [v0.5.2]
 - Build wheels for Python 3.12
 - Update version of PyO3 to allow multiple imports
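For context on the `encoding_for_model` entry above, a minimal usage sketch (not part of this commit); it assumes tiktoken v0.6.0 is installed and that the model name maps to one of the registered encodings:

```python
import tiktoken

# As of this release, the text-embedding-3-* model names resolve through
# encoding_for_model instead of raising an error for an unknown model.
enc = tiktoken.encoding_for_model("text-embedding-3-small")

print(enc.name)                   # name of the underlying registered encoding
print(enc.encode("hello world"))  # a short list of token ids
```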

Cargo.toml

Lines changed: 1 addition & 4 deletions
@@ -1,6 +1,6 @@
 [package]
 name = "tiktoken"
-version = "0.5.2"
+version = "0.6.0"
 edition = "2021"
 rust-version = "1.57.0"
 
@@ -16,6 +16,3 @@ fancy-regex = "0.11.0"
 regex = "1.8.3"
 rustc-hash = "1.1.0"
 bstr = "1.5.0"
-
-[profile.release]
-incremental = true

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "tiktoken"
-version = "0.5.2"
+version = "0.6.0"
 description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models"
 readme = "README.md"
 license = {file = "LICENSE"}

tiktoken/core.py

Lines changed: 24 additions & 0 deletions
@@ -116,6 +116,10 @@ def encode(
         if match := _special_token_regex(disallowed_special).search(text):
             raise_disallowed_special_token(match.group())
 
+        # https://github.com/PyO3/pyo3/pull/3632
+        if isinstance(allowed_special, frozenset):
+            allowed_special = set(allowed_special)
+
         try:
             return self._core_bpe.encode(text, allowed_special)
         except UnicodeEncodeError:
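The shim above converts a frozenset `allowed_special` into a plain set before it reaches the Rust extension, sidestepping the PyO3 frozenset conversion bug linked in the comment. A small usage sketch (the text and special token are illustrative; it assumes the `cl100k_base` encoding is available):

```python
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")

# allowed_special can now be a frozenset as well as a set; the conversion in
# Encoding.encode happens before the call into the Rust core BPE.
tokens = enc.encode(
    "hello <|endoftext|>",
    allowed_special=frozenset({"<|endoftext|>"}),
)
print(tokens)
```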
@@ -364,6 +368,26 @@ def _encode_only_native_bpe(self, text: str) -> list[int]:
     def _encode_bytes(self, text: bytes) -> list[int]:
         return self._core_bpe._encode_bytes(text)
 
+    def __getstate__(self) -> object:
+        import tiktoken.registry
+
+        # As an optimisation, pickle registered encodings by reference
+        if self is tiktoken.registry.ENCODINGS.get(self.name):
+            return self.name
+        return {
+            "name": self.name,
+            "pat_str": self._pat_str,
+            "mergeable_ranks": self._mergeable_ranks,
+            "special_tokens": self._special_tokens,
+        }
+
+    def __setstate__(self, value: object) -> None:
+        import tiktoken.registry
+
+        if isinstance(value, str):
+            self.__dict__ = tiktoken.registry.get_encoding(value).__dict__
+            return
+        self.__init__(**value)
 
 
 @functools.lru_cache(maxsize=128)
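With `__getstate__` and `__setstate__` in place, `Encoding` objects round-trip through pickle. A registered encoding is reduced to its name and looked up again via `tiktoken.registry` on load; an encoding constructed by hand is pickled by value from its constructor arguments. A minimal sketch of the registered case:

```python
import pickle

import tiktoken

enc = tiktoken.get_encoding("cl100k_base")

# __getstate__ returns just the name for a registered encoding, so the pickle
# payload stays small; __setstate__ resolves the name back through the registry.
restored = pickle.loads(pickle.dumps(enc))

assert restored.name == enc.name
assert restored.encode("hello world") == enc.encode("hello world")
```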

tiktoken/load.py

Lines changed: 18 additions & 11 deletions
@@ -27,12 +27,12 @@ def read_file(blobpath: str) -> bytes:
     return resp.content
 
 
-def check_hash(data: bytes, hash: str) -> bool:
-    data_hash = hashlib.sha256(data).hexdigest()
-    return data_hash == hash
+def check_hash(data: bytes, expected_hash: str) -> bool:
+    actual_hash = hashlib.sha256(data).hexdigest()
+    return actual_hash == expected_hash
 
 
-def read_file_cached(blobpath: str, expected_hash: Optional[str]=None) -> bytes:
+def read_file_cached(blobpath: str, expected_hash: Optional[str] = None) -> bytes:
     user_specified_cache = True
     if "TIKTOKEN_CACHE_DIR" in os.environ:
         cache_dir = os.environ["TIKTOKEN_CACHE_DIR"]
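The renamed parameters make the contract of `check_hash` explicit: the caller supplies the expected SHA-256 hex digest, and it is compared against the digest of the bytes actually read. A quick demonstration (the payloads are made up):

```python
import hashlib

from tiktoken.load import check_hash

data = b"example payload"
expected_hash = hashlib.sha256(data).hexdigest()  # 64-character hex digest

assert check_hash(data, expected_hash)                     # digests match
assert not check_hash(b"tampered payload", expected_hash)  # mismatch detected
```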
@@ -52,13 +52,15 @@ def read_file_cached(blobpath: str, expected_hash: Optional[str]=None) -> bytes:
     if os.path.exists(cache_path):
         with open(cache_path, "rb") as f:
             data = f.read()
-        if expected_hash and not check_hash(data, expected_hash):
-            raise ValueError(
-                f"Hash mismatch for cached data from {blobpath} (expected {expected_hash}). "
-                f"Please delete the cache file at {cache_path} and try again."
-            )
+        if expected_hash is None or check_hash(data, expected_hash):
             return data
 
+        # the cached file does not match the hash, remove it and re-fetch
+        try:
+            os.remove(cache_path)
+        except OSError:
+            pass
+
     contents = read_file(blobpath)
     if expected_hash and not check_hash(contents, expected_hash):
         raise ValueError(
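The behavioural change in this hunk: a cached file that fails the hash check is no longer a hard error. The stale copy is deleted and the blob is re-fetched, and `ValueError` is only raised if the freshly downloaded contents still do not match. A simplified, self-contained sketch of that validate-or-refetch pattern (the helper names here are illustrative, not tiktoken's API):

```python
import hashlib
import os
from typing import Optional


def _fetch(source: str) -> bytes:
    # Stand-in for tiktoken.load.read_file; here it just reads a local file.
    with open(source, "rb") as f:
        return f.read()


def read_cached_or_refetch(source: str, cache_path: str, expected_hash: Optional[str]) -> bytes:
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            data = f.read()
        if expected_hash is None or hashlib.sha256(data).hexdigest() == expected_hash:
            return data
        # Cached copy is stale or corrupted: drop it and fall through to re-fetch.
        try:
            os.remove(cache_path)
        except OSError:
            pass

    data = _fetch(source)
    if expected_hash and hashlib.sha256(data).hexdigest() != expected_hash:
        raise ValueError(f"Hash mismatch for data fetched from {source}")
    with open(cache_path, "wb") as f:
        f.write(data)
    return data
```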
@@ -81,7 +83,10 @@ def read_file_cached(blobpath: str, expected_hash: Optional[str]=None) -> bytes:
 
 
 def data_gym_to_mergeable_bpe_ranks(
-    vocab_bpe_file: str, encoder_json_file: str, vocab_bpe_hash: Optional[str]=None, encoder_json_hash: Optional[str]=None
+    vocab_bpe_file: str,
+    encoder_json_file: str,
+    vocab_bpe_hash: Optional[str] = None,
+    encoder_json_hash: Optional[str] = None,
 ) -> dict[bytes, int]:
     # NB: do not add caching to this function
     rank_to_intbyte = [b for b in range(2**8) if chr(b).isprintable() and chr(b) != " "]
@@ -135,7 +140,9 @@ def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> None:
             f.write(base64.b64encode(token) + b" " + str(rank).encode() + b"\n")
 
 
-def load_tiktoken_bpe(tiktoken_bpe_file: str, expected_hash: Optional[str]=None) -> dict[bytes, int]:
+def load_tiktoken_bpe(
+    tiktoken_bpe_file: str, expected_hash: Optional[str] = None
+) -> dict[bytes, int]:
     # NB: do not add caching to this function
     contents = read_file_cached(tiktoken_bpe_file, expected_hash)
     return {
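The reflowed `load_tiktoken_bpe` signature threads `expected_hash` straight through to `read_file_cached`, so BPE rank files can be integrity-checked when downloaded. A usage sketch (the URL is a placeholder; pass a real `.tiktoken` blob path, and a real SHA-256 hex digest if you want the check enforced):

```python
from tiktoken.load import load_tiktoken_bpe

# Placeholder blob path. Supply expected_hash=<sha256 hex digest of the file>
# to enable the integrity check, or leave it as None to skip it.
mergeable_ranks = load_tiktoken_bpe(
    "https://example.com/cl100k_base.tiktoken",  # hypothetical URL
    expected_hash=None,
)
print(len(mergeable_ranks))  # number of (token bytes -> rank) entries
```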
