mirror of
https://github.com/python/cpython.git
synced 2025-07-07 19:35:27 +00:00
Fix quopri to operate consistently on bytes.
This commit is contained in:
parent
f3f0c611dd
commit
c582bfca26
2 changed files with 87 additions and 75 deletions
|
@ -6,10 +6,10 @@
|
|||
|
||||
__all__ = ["encode", "decode", "encodestring", "decodestring"]
|
||||
|
||||
# Module-level constants.  They are bytes objects because the functions in
# this module operate on binary data.

# The byte that introduces an escape sequence (and a soft line break).
ESCAPE = b'='
# Longest encoded line that is written; the '=' of a soft line break counts
# towards this limit.
MAXLINESIZE = 76
# Upper-case hexadecimal digits, indexed by nibble value when quoting a byte.
HEX = b'0123456789ABCDEF'
EMPTYSTRING = b''
|
||||
|
||||
try:
|
||||
from binascii import a2b_qp, b2a_qp
|
||||
|
@ -19,23 +19,25 @@ except ImportError:
|
|||
|
||||
|
||||
def needsquoting(c, quotetabs, header):
    """Decide whether a particular byte needs to be quoted.

    'c' is a bytes object holding the byte under consideration (not an
    integer ordinal).

    The 'quotetabs' flag indicates whether embedded tabs and spaces should be
    quoted.  Note that line-ending tabs and spaces are always encoded, as per
    RFC 1521.  The 'header' flag indicates whether an underscore should be
    quoted.
    """
    assert isinstance(c, bytes)
    if c in b' \t':
        return quotetabs
    # if header, we have to escape _ because _ is used to escape space
    if c == b'_':
        return header
    # Quote the escape byte itself and anything outside ' ' .. '~'.
    return c == ESCAPE or not (b' ' <= c <= b'~')
|
||||
|
||||
def quote(c):
    """Quote a single byte.

    'c' must be a bytes object of length 1.  The result is ESCAPE followed by
    the two upper-case hexadecimal digits of the byte's value.
    """
    assert isinstance(c, bytes) and len(c)==1
    high, low = divmod(ord(c), 16)
    # Indexing the bytes constant HEX yields ints, so rebuild a bytes object
    # from the two digit values before concatenating.
    return ESCAPE + bytes((HEX[high], HEX[low]))
|
||||
|
||||
|
||||
|
||||
|
@ -56,12 +58,12 @@ def encode(input, output, quotetabs, header = 0):
|
|||
output.write(odata)
|
||||
return
|
||||
|
||||
def write(s, output=output, lineEnd='\n'):
|
||||
def write(s, output=output, lineEnd=b'\n'):
|
||||
# RFC 1521 requires that the line ending in a space or tab must have
|
||||
# that trailing character encoded.
|
||||
if s and s[-1:] in ' \t':
|
||||
output.write(s[:-1] + quote(s[-1]) + lineEnd)
|
||||
elif s == '.':
|
||||
if s and s[-1:] in b' \t':
|
||||
output.write(s[:-1] + quote(s[-1:]) + lineEnd)
|
||||
elif s == b'.':
|
||||
output.write(quote(s) + lineEnd)
|
||||
else:
|
||||
output.write(s + lineEnd)
|
||||
|
@ -73,16 +75,17 @@ def encode(input, output, quotetabs, header = 0):
|
|||
break
|
||||
outline = []
|
||||
# Strip off any readline induced trailing newline
|
||||
stripped = ''
|
||||
if line[-1:] == '\n':
|
||||
stripped = b''
|
||||
if line[-1:] == b'\n':
|
||||
line = line[:-1]
|
||||
stripped = '\n'
|
||||
stripped = b'\n'
|
||||
# Calculate the un-length-limited encoded line
|
||||
for c in line:
|
||||
c = bytes((c,))
|
||||
if needsquoting(c, quotetabs, header):
|
||||
c = quote(c)
|
||||
if header and c == ' ':
|
||||
outline.append('_')
|
||||
if header and c == b' ':
|
||||
outline.append(b'_')
|
||||
else:
|
||||
outline.append(c)
|
||||
# First, write out the previous line
|
||||
|
@ -94,7 +97,7 @@ def encode(input, output, quotetabs, header = 0):
|
|||
while len(thisline) > MAXLINESIZE:
|
||||
# Don't forget to include the soft line break `=' sign in the
|
||||
# length calculation!
|
||||
write(thisline[:MAXLINESIZE-1], lineEnd='=\n')
|
||||
write(thisline[:MAXLINESIZE-1], lineEnd=b'=\n')
|
||||
thisline = thisline[MAXLINESIZE-1:]
|
||||
# Write out the current line
|
||||
prevline = thisline
|
||||
|
@ -105,9 +108,9 @@ def encode(input, output, quotetabs, header = 0):
|
|||
def encodestring(s, quotetabs = 0, header = 0):
    """Return the quoted-printable encoding of the bytes object 's'.

    When b2a_qp is not None the work is delegated to it.  Otherwise encode()
    is run over in-memory binary buffers.  'quotetabs' and 'header' are
    passed through unchanged in both cases.
    """
    if b2a_qp is not None:
        return b2a_qp(s, quotetabs = quotetabs, header = header)
    # Binary buffers only: a text buffer cannot be initialised from bytes.
    from io import BytesIO
    infp = BytesIO(s)
    outfp = BytesIO()
    encode(infp, outfp, quotetabs, header)
    return outfp.getvalue()
|
||||
|
||||
|
@ -124,44 +127,44 @@ def decode(input, output, header = 0):
|
|||
output.write(odata)
|
||||
return
|
||||
|
||||
new = ''
|
||||
new = b''
|
||||
while 1:
|
||||
line = input.readline()
|
||||
if not line: break
|
||||
i, n = 0, len(line)
|
||||
if n > 0 and line[n-1] == '\n':
|
||||
if n > 0 and line[n-1:n] == b'\n':
|
||||
partial = 0; n = n-1
|
||||
# Strip trailing whitespace
|
||||
while n > 0 and line[n-1] in " \t\r":
|
||||
while n > 0 and line[n-1:n] in b" \t\r":
|
||||
n = n-1
|
||||
else:
|
||||
partial = 1
|
||||
while i < n:
|
||||
c = line[i]
|
||||
if c == '_' and header:
|
||||
new = new + ' '; i = i+1
|
||||
c = line[i:i+1]
|
||||
if c == b'_' and header:
|
||||
new = new + b' '; i = i+1
|
||||
elif c != ESCAPE:
|
||||
new = new + c; i = i+1
|
||||
elif i+1 == n and not partial:
|
||||
partial = 1; break
|
||||
elif i+1 < n and line[i+1] == ESCAPE:
|
||||
new = new + ESCAPE; i = i+2
|
||||
elif i+2 < n and ishex(line[i+1]) and ishex(line[i+2]):
|
||||
new = new + chr(unhex(line[i+1:i+3])); i = i+3
|
||||
elif i+2 < n and ishex(line[i+1:i+2]) and ishex(line[i+2:i+3]):
|
||||
new = new + bytes((unhex(line[i+1:i+3]),)); i = i+3
|
||||
else: # Bad escape sequence -- leave it in
|
||||
new = new + c; i = i+1
|
||||
if not partial:
|
||||
output.write(new + '\n')
|
||||
new = ''
|
||||
output.write(new + b'\n')
|
||||
new = b''
|
||||
if new:
|
||||
output.write(new)
|
||||
|
||||
def decodestring(s, header = 0):
    """Return the decoded form of the quoted-printable bytes object 's'.

    When a2b_qp is not None the work is delegated to it.  Otherwise decode()
    is run over in-memory binary buffers.  'header' is passed through
    unchanged in both cases.
    """
    if a2b_qp is not None:
        return a2b_qp(s, header = header)
    # Binary buffers only: a text buffer cannot be initialised from bytes.
    from io import BytesIO
    infp = BytesIO(s)
    outfp = BytesIO()
    decode(infp, outfp, header = header)
    return outfp.getvalue()
|
||||
|
||||
|
@ -169,21 +172,23 @@ def decodestring(s, header = 0):
|
|||
|
||||
# Other helper functions
|
||||
def ishex(c):
    """Return true if the bytes object 'c' is one ASCII hexadecimal digit.

    Both upper- and lower-case digits are accepted.  An empty or multi-byte
    'c' is not a digit.
    """
    assert isinstance(c, bytes)
    # A range comparison on bytes is lexicographic and would accept values
    # such as b'1z'; require exactly one byte and test membership instead.
    return len(c) == 1 and c in b'0123456789abcdefABCDEF'
|
||||
|
||||
def unhex(s):
    """Get the integer value of a hexadecimal number.

    's' is a bytes object of ASCII hexadecimal digits in either case.  An
    empty 's' yields 0.  Raises AssertionError if 's' contains a byte that is
    not a hexadecimal digit.
    """
    bits = 0
    for c in s:
        # Iterating over bytes yields ints; rebuild a length-1 bytes object
        # so the range comparisons below work on bytes.
        c = bytes((c,))
        if b'0' <= c <= b'9':
            i = ord(b'0')
        elif b'a' <= c <= b'f':
            i = ord(b'a')-10
        elif b'A' <= c <= b'F':
            i = ord(b'A')-10
        else:
            # Raised explicitly rather than via an assert statement: asserts
            # are removed under -O, which would leave 'i' stale or unbound
            # and produce a wrong value instead of an error.
            raise AssertionError("non-hex digit "+repr(c))
        bits = bits*16 + (ord(c) - i)
    return bits
|
||||
|
||||
|
@ -214,18 +219,18 @@ def main():
|
|||
sts = 0
|
||||
for file in args:
|
||||
if file == '-':
|
||||
fp = sys.stdin
|
||||
fp = sys.stdin.buffer
|
||||
else:
|
||||
try:
|
||||
fp = open(file)
|
||||
fp = open(file, "rb")
|
||||
except IOError as msg:
|
||||
sys.stderr.write("%s: can't open (%s)\n" % (file, msg))
|
||||
sts = 1
|
||||
continue
|
||||
if deco:
|
||||
decode(fp, sys.stdout)
|
||||
decode(fp, sys.stdout.buffer)
|
||||
else:
|
||||
encode(fp, sys.stdout, tabs)
|
||||
encode(fp, sys.stdout.buffer, tabs)
|
||||
if fp is not sys.stdin:
|
||||
fp.close()
|
||||
if sts:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue