bpo-44439: BZ2File.write() / LZMAFile.write() handle buffer protocol correctly (GH-26764)

No longer use len() to get the length of the input data. For some buffer protocol objects,
the length obtained by using len() is wrong.
This commit is contained in:
Ma Lin 2021-06-22 15:04:23 +08:00 committed by GitHub
parent 92c2e91580
commit bc6c12c72a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 55 additions and 9 deletions

View file

@ -219,14 +219,22 @@ class BZ2File(_compression.BaseStream):
"""Write a byte string to the file. """Write a byte string to the file.
Returns the number of uncompressed bytes written, which is Returns the number of uncompressed bytes written, which is
always len(data). Note that due to buffering, the file on disk always the length of data in bytes. Note that due to buffering,
may not reflect the data written until close() is called. the file on disk may not reflect the data written until close()
is called.
""" """
self._check_can_write() self._check_can_write()
if isinstance(data, (bytes, bytearray)):
length = len(data)
else:
# accept any data that supports the buffer protocol
data = memoryview(data)
length = data.nbytes
compressed = self._compressor.compress(data) compressed = self._compressor.compress(data)
self._fp.write(compressed) self._fp.write(compressed)
self._pos += len(data) self._pos += length
return len(data) return length
def writelines(self, seq): def writelines(self, seq):
"""Write a sequence of byte strings to the file. """Write a sequence of byte strings to the file.

View file

@ -278,7 +278,7 @@ class GzipFile(_compression.BaseStream):
if self.fileobj is None: if self.fileobj is None:
raise ValueError("write() on closed GzipFile object") raise ValueError("write() on closed GzipFile object")
if isinstance(data, bytes): if isinstance(data, (bytes, bytearray)):
length = len(data) length = len(data)
else: else:
# accept any data that supports the buffer protocol # accept any data that supports the buffer protocol

View file

@ -229,14 +229,22 @@ class LZMAFile(_compression.BaseStream):
"""Write a bytes object to the file. """Write a bytes object to the file.
Returns the number of uncompressed bytes written, which is Returns the number of uncompressed bytes written, which is
always len(data). Note that due to buffering, the file on disk always the length of data in bytes. Note that due to buffering,
may not reflect the data written until close() is called. the file on disk may not reflect the data written until close()
is called.
""" """
self._check_can_write() self._check_can_write()
if isinstance(data, (bytes, bytearray)):
length = len(data)
else:
# accept any data that supports the buffer protocol
data = memoryview(data)
length = data.nbytes
compressed = self._compressor.compress(data) compressed = self._compressor.compress(data)
self._fp.write(compressed) self._fp.write(compressed)
self._pos += len(data) self._pos += length
return len(data) return length
def seek(self, offset, whence=io.SEEK_SET): def seek(self, offset, whence=io.SEEK_SET):
"""Change the file position. """Change the file position.

View file

@ -1,6 +1,7 @@
from test import support from test import support
from test.support import bigmemtest, _4G from test.support import bigmemtest, _4G
import array
import unittest import unittest
from io import BytesIO, DEFAULT_BUFFER_SIZE from io import BytesIO, DEFAULT_BUFFER_SIZE
import os import os
@ -620,6 +621,14 @@ class BZ2FileTest(BaseTest):
with BZ2File(BytesIO(truncated[:i])) as f: with BZ2File(BytesIO(truncated[:i])) as f:
self.assertRaises(EOFError, f.read, 1) self.assertRaises(EOFError, f.read, 1)
def test_issue44439(self):
q = array.array('Q', [1, 2, 3, 4, 5])
LENGTH = len(q) * q.itemsize
with BZ2File(BytesIO(), 'w') as f:
self.assertEqual(f.write(q), LENGTH)
self.assertEqual(f.tell(), LENGTH)
class BZ2CompressorTest(BaseTest): class BZ2CompressorTest(BaseTest):
def testCompress(self): def testCompress(self):

View file

@ -592,6 +592,15 @@ class TestGzip(BaseTest):
with gzip.open(self.filename, "rb") as f: with gzip.open(self.filename, "rb") as f:
f._buffer.raw._fp.prepend() f._buffer.raw._fp.prepend()
def test_issue44439(self):
q = array.array('Q', [1, 2, 3, 4, 5])
LENGTH = len(q) * q.itemsize
with gzip.GzipFile(fileobj=io.BytesIO(), mode='w') as f:
self.assertEqual(f.write(q), LENGTH)
self.assertEqual(f.tell(), LENGTH)
class TestOpen(BaseTest): class TestOpen(BaseTest):
def test_binary_modes(self): def test_binary_modes(self):
uncompressed = data1 * 50 uncompressed = data1 * 50

View file

@ -1,4 +1,5 @@
import _compression import _compression
import array
from io import BytesIO, UnsupportedOperation, DEFAULT_BUFFER_SIZE from io import BytesIO, UnsupportedOperation, DEFAULT_BUFFER_SIZE
import os import os
import pathlib import pathlib
@ -1231,6 +1232,14 @@ class FileTestCase(unittest.TestCase):
self.assertTrue(d2.eof) self.assertTrue(d2.eof)
self.assertEqual(out1 + out2, entire) self.assertEqual(out1 + out2, entire)
def test_issue44439(self):
q = array.array('Q', [1, 2, 3, 4, 5])
LENGTH = len(q) * q.itemsize
with LZMAFile(BytesIO(), 'w') as f:
self.assertEqual(f.write(q), LENGTH)
self.assertEqual(f.tell(), LENGTH)
class OpenTestCase(unittest.TestCase): class OpenTestCase(unittest.TestCase):

View file

@ -0,0 +1,3 @@
Fix in :meth:`bz2.BZ2File.write` / :meth:`lzma.LZMAFile.write` methods, when
the input data is an object that supports the buffer protocol, the file length
may be wrong.