Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 15 additions & 24 deletions mypyc/lib-rt/base64/librt_base64.c
Original file line number Diff line number Diff line change
Expand Up @@ -240,39 +240,21 @@ b64decode_handle_invalid_input(
return PyErr_NoMemory();
}

// Copy base64 characters and some padding to the new buffer
int pad_chars = 0;
// Copy base64 characters to the new buffer. Ignore padding to conform to RFC 4648 section 3.3.
for (size_t i = 0; i < srclen; i++) {
char c = src[i];
if (is_valid_base64_char(c, false)) {
newbuf[newbuf_len++] = c;
pad_chars = 0;
} else if (c == '=') {
// Copy a necessary amount of padding
int remainder = newbuf_len % 4;
if (remainder == 0) {
// No padding needed
break;
}
int numpad = 4 - remainder;
// Check that there is at least the required amount padding (CPython ignores
// extra padding)
while (numpad > 0) {
if (i == srclen || src[i] != '=') {
break;
}
newbuf[newbuf_len++] = '=';
i++;
numpad--;
// Skip non-base64 alphabet characters within padding
while (i < srclen && !is_valid_base64_char(src[i], true)) {
i++;
}
}
break;
pad_chars++;
}
}

int quad_pos = newbuf_len % 4;
// Stdlib always performs a non-strict padding check
if (newbuf_len % 4 != 0) {
if (quad_pos != 0 && quad_pos + pad_chars < 4) {
if (freesrc) {
PyMem_Free((void *)src);
}
Expand All @@ -282,6 +264,15 @@ b64decode_handle_invalid_input(
return NULL;
}

if (quad_pos != 0) {
// Add padding at the end to make the input length a multiple of 4. We know that this padding
// is present in src because otherwise we would report the "Incorrect padding" error above.
while (quad_pos < 4) {
newbuf[newbuf_len++] = '=';
quad_pos++;
}
}

size_t outlen = max_out;
int ret = base64_decode(newbuf, newbuf_len, outbuf, &outlen, 0);
PyMem_Free(newbuf);
Expand Down
21 changes: 14 additions & 7 deletions mypyc/test-data/run-base64.test
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ from typing import Any, cast
import base64
import binascii
import random
import sys

from librt.base64 import b64encode, b64decode, urlsafe_b64encode, urlsafe_b64decode

Expand Down Expand Up @@ -121,6 +122,14 @@ def test_decode_with_non_base64_chars() -> None:
check_decode(b"e" + b + b"A==", encoded=True)
check_decode(b"eA=" + b + b"=", encoded=True)

def has_stdlib_b64decode_bugfix() -> bool:
    """Return True if the running interpreter's stdlib b64decode has the padding bugfix.

    Older Python versions have a bug in stdlib b64decode where it skips processing the
    input data after the first padded quad. It was changed to conform to RFC 4648
    section 3.3 in CPython 3.13.13+, 3.14.4+ and 3.15+. The librt implementation matches
    the corrected behavior regardless of Python version, so some inputs give different
    results than stdlib on older Python. Tests call this helper to decide whether the
    stdlib result can be used as the reference.
    """
    # Compare whole (major, minor, micro) tuples so the major version is taken into
    # account instead of looking at minor/micro in isolation.
    version = sys.version_info[:3]
    # 3.14.4 and anything newer has the fix; on the 3.13 series it is present from 3.13.13.
    return version >= (3, 14, 4) or (3, 13, 13) <= version < (3, 14)

def check_decode_error(b: bytes, ignore_stdlib: bool = False) -> None:
if not ignore_stdlib:
with assertRaises(binascii.Error):
Expand All @@ -135,9 +144,7 @@ def test_decode_with_invalid_padding() -> None:
check_decode_error(b"eA=")
check_decode_error(b"eHk")
check_decode_error(b"eA = ")

# Here stdlib behavior seems nonsensical, so we don't try to duplicate it
check_decode_error(b"eA=a=", ignore_stdlib=True)
check_decode_error(b"eA==x", ignore_stdlib=not has_stdlib_b64decode_bugfix())

def test_decode_with_extra_data_after_padding() -> None:
check_decode(b"=", encoded=True)
Expand All @@ -146,10 +153,10 @@ def test_decode_with_extra_data_after_padding() -> None:
check_decode(b"====", encoded=True)
check_decode(b"eA===", encoded=True)
check_decode(b"eHk==", encoded=True)
# TODO: behavior in these cases changed in Python 3.14.4, we should match that.
# check_decode(b"eA==x", encoded=True)
# check_decode(b"eHk=x", encoded=True)
# check_decode(b"eA==abc=======efg", encoded=True)
if has_stdlib_b64decode_bugfix():
check_decode(b"eA=a=", encoded=True)
check_decode(b"eHk=x", encoded=True)
check_decode(b"eA==abc=======efg", encoded=True)

def test_decode_wrappers() -> None:
funcs: list[Any] = [b64decode, urlsafe_b64decode]
Expand Down
Loading