Fix quopri to operate consistently on bytes.

This commit is contained in:
Martin v. Löwis 2007-07-28 17:52:25 +00:00
parent f3f0c611dd
commit c582bfca26
2 changed files with 87 additions and 75 deletions

View file

@ -6,10 +6,10 @@
__all__ = ["encode", "decode", "encodestring", "decodestring"]
ESCAPE = '='
ESCAPE = b'='
MAXLINESIZE = 76
HEX = '0123456789ABCDEF'
EMPTYSTRING = ''
HEX = b'0123456789ABCDEF'
EMPTYSTRING = b''
try:
from binascii import a2b_qp, b2a_qp
@ -19,23 +19,25 @@ except ImportError:
def needsquoting(c, quotetabs, header):
"""Decide whether a particular character needs to be quoted.
"""Decide whether a particular byte ordinal needs to be quoted.
The 'quotetabs' flag indicates whether embedded tabs and spaces should be
quoted. Note that line-ending tabs and spaces are always encoded, as per
RFC 1521.
"""
if c in ' \t':
assert isinstance(c, bytes)
if c in b' \t':
return quotetabs
# if header, we have to escape _ because _ is used to escape space
if c == '_':
if c == b'_':
return header
return c == ESCAPE or not (' ' <= c <= '~')
return c == ESCAPE or not (b' ' <= c <= b'~')
def quote(c):
"""Quote a single character."""
i = ord(c)
return ESCAPE + HEX[i//16] + HEX[i%16]
assert isinstance(c, bytes) and len(c)==1
c = ord(c)
return ESCAPE + bytes((HEX[c//16], HEX[c%16]))
@ -56,12 +58,12 @@ def encode(input, output, quotetabs, header = 0):
output.write(odata)
return
def write(s, output=output, lineEnd='\n'):
def write(s, output=output, lineEnd=b'\n'):
# RFC 1521 requires that the line ending in a space or tab must have
# that trailing character encoded.
if s and s[-1:] in ' \t':
output.write(s[:-1] + quote(s[-1]) + lineEnd)
elif s == '.':
if s and s[-1:] in b' \t':
output.write(s[:-1] + quote(s[-1:]) + lineEnd)
elif s == b'.':
output.write(quote(s) + lineEnd)
else:
output.write(s + lineEnd)
@ -73,16 +75,17 @@ def encode(input, output, quotetabs, header = 0):
break
outline = []
# Strip off any readline induced trailing newline
stripped = ''
if line[-1:] == '\n':
stripped = b''
if line[-1:] == b'\n':
line = line[:-1]
stripped = '\n'
stripped = b'\n'
# Calculate the un-length-limited encoded line
for c in line:
c = bytes((c,))
if needsquoting(c, quotetabs, header):
c = quote(c)
if header and c == ' ':
outline.append('_')
if header and c == b' ':
outline.append(b'_')
else:
outline.append(c)
# First, write out the previous line
@ -94,7 +97,7 @@ def encode(input, output, quotetabs, header = 0):
while len(thisline) > MAXLINESIZE:
# Don't forget to include the soft line break `=' sign in the
# length calculation!
write(thisline[:MAXLINESIZE-1], lineEnd='=\n')
write(thisline[:MAXLINESIZE-1], lineEnd=b'=\n')
thisline = thisline[MAXLINESIZE-1:]
# Write out the current line
prevline = thisline
@ -105,9 +108,9 @@ def encode(input, output, quotetabs, header = 0):
def encodestring(s, quotetabs = 0, header = 0):
if b2a_qp is not None:
return b2a_qp(s, quotetabs = quotetabs, header = header)
from io import StringIO
infp = StringIO(s)
outfp = StringIO()
from io import BytesIO
infp = BytesIO(s)
outfp = BytesIO()
encode(infp, outfp, quotetabs, header)
return outfp.getvalue()
@ -124,44 +127,44 @@ def decode(input, output, header = 0):
output.write(odata)
return
new = ''
new = b''
while 1:
line = input.readline()
if not line: break
i, n = 0, len(line)
if n > 0 and line[n-1] == '\n':
if n > 0 and line[n-1:n] == b'\n':
partial = 0; n = n-1
# Strip trailing whitespace
while n > 0 and line[n-1] in " \t\r":
while n > 0 and line[n-1:n] in b" \t\r":
n = n-1
else:
partial = 1
while i < n:
c = line[i]
if c == '_' and header:
new = new + ' '; i = i+1
c = line[i:i+1]
if c == b'_' and header:
new = new + b' '; i = i+1
elif c != ESCAPE:
new = new + c; i = i+1
elif i+1 == n and not partial:
partial = 1; break
elif i+1 < n and line[i+1] == ESCAPE:
new = new + ESCAPE; i = i+2
elif i+2 < n and ishex(line[i+1]) and ishex(line[i+2]):
new = new + chr(unhex(line[i+1:i+3])); i = i+3
elif i+2 < n and ishex(line[i+1:i+2]) and ishex(line[i+2:i+3]):
new = new + bytes((unhex(line[i+1:i+3]),)); i = i+3
else: # Bad escape sequence -- leave it in
new = new + c; i = i+1
if not partial:
output.write(new + '\n')
new = ''
output.write(new + b'\n')
new = b''
if new:
output.write(new)
def decodestring(s, header = 0):
if a2b_qp is not None:
return a2b_qp(s, header = header)
from io import StringIO
infp = StringIO(s)
outfp = StringIO()
from io import BytesIO
infp = BytesIO(s)
outfp = BytesIO()
decode(infp, outfp, header = header)
return outfp.getvalue()
@ -169,21 +172,23 @@ def decodestring(s, header = 0):
# Other helper functions
def ishex(c):
"""Return true if the character 'c' is a hexadecimal digit."""
return '0' <= c <= '9' or 'a' <= c <= 'f' or 'A' <= c <= 'F'
"""Return true if the byte ordinal 'c' is a hexadecimal digit in ASCII."""
assert isinstance(c, bytes)
return b'0' <= c <= b'9' or b'a' <= c <= b'f' or b'A' <= c <= b'F'
def unhex(s):
"""Get the integer value of a hexadecimal number."""
bits = 0
for c in s:
if '0' <= c <= '9':
c = bytes((c,))
if b'0' <= c <= b'9':
i = ord('0')
elif 'a' <= c <= 'f':
elif b'a' <= c <= b'f':
i = ord('a')-10
elif 'A' <= c <= 'F':
i = ord('A')-10
elif b'A' <= c <= b'F':
i = ord(b'A')-10
else:
break
assert False, "non-hex digit "+repr(c)
bits = bits*16 + (ord(c) - i)
return bits
@ -214,18 +219,18 @@ def main():
sts = 0
for file in args:
if file == '-':
fp = sys.stdin
fp = sys.stdin.buffer
else:
try:
fp = open(file)
fp = open(file, "rb")
except IOError as msg:
sys.stderr.write("%s: can't open (%s)\n" % (file, msg))
sts = 1
continue
if deco:
decode(fp, sys.stdout)
decode(fp, sys.stdout.buffer)
else:
encode(fp, sys.stdout, tabs)
encode(fp, sys.stdout.buffer, tabs)
if fp is not sys.stdin:
fp.close()
if sts: