fix(base64): allow padded chunks mid-stream

This commit is contained in:
karanabe 2025-11-09 02:41:15 +09:00 committed by Sylvestre Ledru
parent 859a1ed2e3
commit 92bb655b0c
3 changed files with 111 additions and 19 deletions

View file

@ -171,18 +171,16 @@ pub fn get_input(config: &Config) -> UResult<Box<dyn ReadSeek>> {
}
}
/// Reads `input` to the end and reports whether the data contains a padding byte.
///
/// Returns `(has_padding, buf)`, where `buf` holds every byte that was read and
/// `has_padding` is `true` when a `b'='` occurs anywhere in `buf` — not only as
/// the last non-whitespace byte.
///
/// # Errors
///
/// A failed read is reported as a `USimpleError` with exit code 1 whose message
/// is built by `format_read_error` from the I/O error kind.
fn read_and_has_padding<R: Read>(input: &mut R) -> UResult<(bool, Vec<u8>)> {
    let mut buf = Vec::new();
    input
        .read_to_end(&mut buf)
        .map_err(|err| USimpleError::new(1, format_read_error(err.kind())))?;

    // Treat the stream as padded if any '=' exists (GNU coreutils continues decoding
    // even when padding bytes are followed by more data).
    let has_padding = buf.contains(&b'=');

    Ok((has_padding, buf))
}
@ -665,6 +663,8 @@ mod tests {
("aGVsbG8sIHdvcmxkIQ== \n", true),
("aGVsbG8sIHdvcmxkIQ=", true),
("aGVsbG8sIHdvcmxkIQ= ", true),
("MTIzNA==MTIzNA", true),
("MTIzNA==\nMTIzNA", true),
("aGVsbG8sIHdvcmxkIQ \n", false),
("aGVsbG8sIHdvcmxkIQ", false),
];

View file

@ -22,6 +22,26 @@ pub struct Base64SimdWrapper {
}
impl Base64SimdWrapper {
fn decode_with_standard(input: &[u8], output: &mut Vec<u8>) -> Result<(), ()> {
match base64_simd::STANDARD.decode_to_vec(input) {
Ok(decoded_bytes) => {
output.extend_from_slice(&decoded_bytes);
Ok(())
}
Err(_) => Err(()),
}
}
fn decode_with_no_pad(input: &[u8], output: &mut Vec<u8>) -> Result<(), ()> {
match base64_simd::STANDARD_NO_PAD.decode_to_vec(input) {
Ok(decoded_bytes) => {
output.extend_from_slice(&decoded_bytes);
Ok(())
}
Err(_) => Err(()),
}
}
pub fn new(
use_padding: bool,
valid_decoding_multiple: usize,
@ -47,22 +67,64 @@ impl SupportsFastDecodeAndEncode for Base64SimdWrapper {
}
fn decode_into_vec(&self, input: &[u8], output: &mut Vec<u8>) -> UResult<()> {
let decoded = if self.use_padding {
base64_simd::STANDARD.decode_to_vec(input)
let original_len = output.len();
let decode_result = if self.use_padding {
// GNU coreutils keeps decoding even when '=' appears before the true end
// of the stream (e.g. concatenated padded chunks). Mirror that logic
// by splitting at each '='-containing quantum, decoding those 4-byte
// groups with the padded variant, then letting the remainder fall back
// to whichever alphabet fits.
let mut start = 0usize;
while start < input.len() {
let remaining = &input[start..];
if remaining.is_empty() {
break;
}
if let Some(eq_rel_idx) = remaining.iter().position(|&b| b == b'=') {
let blocks = (eq_rel_idx / 4) + 1;
let segment_len = blocks * 4;
if segment_len > remaining.len() {
return Err(USimpleError::new(1, "error: invalid input".to_owned()));
}
if Self::decode_with_standard(&remaining[..segment_len], output).is_err() {
return Err(USimpleError::new(1, "error: invalid input".to_owned()));
}
start += segment_len;
} else {
// If there are no more '=' bytes the tail might still be padded
// (len % 4 == 0) or purposely unpadded (GNU --ignore-garbage or
// concatenated streams), so select the matching alphabet.
let decoder = if remaining.len() % 4 == 0 {
Self::decode_with_standard
} else {
Self::decode_with_no_pad
};
if decoder(remaining, output).is_err() {
return Err(USimpleError::new(1, "error: invalid input".to_owned()));
}
break;
}
}
Ok(())
} else {
base64_simd::STANDARD_NO_PAD.decode_to_vec(input)
Self::decode_with_no_pad(input, output)
.map_err(|_| USimpleError::new(1, "error: invalid input".to_owned()))
};
match decoded {
Ok(decoded_bytes) => {
output.extend_from_slice(&decoded_bytes);
Ok(())
}
Err(_) => {
// Restore original length on error
output.truncate(output.len());
Err(USimpleError::new(1, "error: invalid input".to_owned()))
}
if let Err(err) = decode_result {
output.truncate(original_len);
Err(err)
} else {
Ok(())
}
}

View file

@ -2,6 +2,9 @@
//
// For the full copyright and license information, please view the LICENSE
// file that was distributed with this source code.
// spell-checker:ignore unpadded, QUJD
#[cfg(target_os = "linux")]
use uutests::at_and_ucmd;
use uutests::new_ucmd;
@ -108,6 +111,33 @@ fn test_decode_repeat_flags() {
.stdout_only("hello, world!");
}
// A padded group ("MTIzNA==") directly followed by a tail whose length is not
// a multiple of four must decode as one continuous stream.
#[test]
fn test_decode_padded_block_followed_by_unpadded_tail() {
    let encoded = "MTIzNA==MTIzNA";
    let expected = "12341234";

    new_ucmd!()
        .arg("--decode")
        .pipe_in(encoded)
        .succeeds()
        .stdout_only(expected);
}
// A padded group ("MTIzNA==") directly followed by a 4-byte-aligned tail
// ("QUJD") must decode as one continuous stream.
#[test]
fn test_decode_padded_block_followed_by_aligned_tail() {
    let encoded = "MTIzNA==QUJD";
    let expected = "1234ABC";

    new_ucmd!()
        .arg("--decode")
        .pipe_in(encoded)
        .succeeds()
        .stdout_only(expected);
}
// Input with no '=' at all and a length that is not a multiple of four must
// still decode successfully.
#[test]
fn test_decode_unpadded_stream_without_equals() {
    let encoded = "MTIzNA";
    let expected = "1234";

    new_ucmd!()
        .arg("--decode")
        .pipe_in(encoded)
        .succeeds()
        .stdout_only(expected);
}
#[test]
fn test_garbage() {
let input = "aGVsbG8sIHdvcmxkIQ==\0"; // spell-checker:disable-line