gh-134938: Add set_pledged_input_size() to ZstdCompressor (GH-135010)

This commit is contained in:
Emma Smith 2025-06-05 04:31:49 -07:00 committed by GitHub
parent 3d396ab759
commit 4b44b3409a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 285 additions and 2 deletions

View file

@ -395,6 +395,115 @@ class CompressorTestCase(unittest.TestCase):
c = ZstdCompressor()
self.assertNotEqual(c.compress(b'', c.FLUSH_FRAME), b'')
def test_set_pledged_input_size(self):
DAT = DECOMPRESSED_100_PLUS_32KB
CHUNK_SIZE = len(DAT) // 3
# wrong value
c = ZstdCompressor()
with self.assertRaisesRegex(ValueError,
r'should be a positive int less than \d+'):
c.set_pledged_input_size(-300)
# overflow
with self.assertRaisesRegex(ValueError,
r'should be a positive int less than \d+'):
c.set_pledged_input_size(2**64)
# ZSTD_CONTENTSIZE_ERROR is invalid
with self.assertRaisesRegex(ValueError,
r'should be a positive int less than \d+'):
c.set_pledged_input_size(2**64-2)
# ZSTD_CONTENTSIZE_UNKNOWN should use None
with self.assertRaisesRegex(ValueError,
r'should be a positive int less than \d+'):
c.set_pledged_input_size(2**64-1)
# check valid values are settable
c.set_pledged_input_size(2**63)
c.set_pledged_input_size(2**64-3)
# check that zero means empty frame
c = ZstdCompressor(level=1)
c.set_pledged_input_size(0)
c.compress(b'')
dat = c.flush()
ret = get_frame_info(dat)
self.assertEqual(ret.decompressed_size, 0)
# wrong mode
c = ZstdCompressor(level=1)
c.compress(b'123456')
self.assertEqual(c.last_mode, c.CONTINUE)
with self.assertRaisesRegex(ValueError,
r'last_mode == FLUSH_FRAME'):
c.set_pledged_input_size(300)
# None value
c = ZstdCompressor(level=1)
c.set_pledged_input_size(None)
dat = c.compress(DAT) + c.flush()
ret = get_frame_info(dat)
self.assertEqual(ret.decompressed_size, None)
# correct value
c = ZstdCompressor(level=1)
c.set_pledged_input_size(len(DAT))
chunks = []
posi = 0
while posi < len(DAT):
dat = c.compress(DAT[posi:posi+CHUNK_SIZE])
posi += CHUNK_SIZE
chunks.append(dat)
dat = c.flush()
chunks.append(dat)
chunks = b''.join(chunks)
ret = get_frame_info(chunks)
self.assertEqual(ret.decompressed_size, len(DAT))
self.assertEqual(decompress(chunks), DAT)
c.set_pledged_input_size(len(DAT)) # the second frame
dat = c.compress(DAT) + c.flush()
ret = get_frame_info(dat)
self.assertEqual(ret.decompressed_size, len(DAT))
self.assertEqual(decompress(dat), DAT)
# not enough data
c = ZstdCompressor(level=1)
c.set_pledged_input_size(len(DAT)+1)
for start in range(0, len(DAT), CHUNK_SIZE):
end = min(start+CHUNK_SIZE, len(DAT))
_dat = c.compress(DAT[start:end])
with self.assertRaises(ZstdError):
c.flush()
# too much data
c = ZstdCompressor(level=1)
c.set_pledged_input_size(len(DAT))
for start in range(0, len(DAT), CHUNK_SIZE):
end = min(start+CHUNK_SIZE, len(DAT))
_dat = c.compress(DAT[start:end])
with self.assertRaises(ZstdError):
c.compress(b'extra', ZstdCompressor.FLUSH_FRAME)
# content size not set if content_size_flag == 0
c = ZstdCompressor(options={CompressionParameter.content_size_flag: 0})
c.set_pledged_input_size(10)
dat1 = c.compress(b"hello")
dat2 = c.compress(b"world")
dat3 = c.flush()
frame_data = get_frame_info(dat1 + dat2 + dat3)
self.assertIsNone(frame_data.decompressed_size)
class DecompressorTestCase(unittest.TestCase):
def test_simple_decompress_bad_args(self):