Random modifications that slightly improve the chances of this not blowing up.

Walter will fix it for real.
This commit is contained in:
Guido van Rossum 2007-05-09 23:40:37 +00:00
parent ccf4f0f68d
commit 0e02abb791

View file

@ -87,7 +87,7 @@ def ToASCII(label):
raise UnicodeError("label empty or too long") raise UnicodeError("label empty or too long")
# Step 5: Check ACE prefix # Step 5: Check ACE prefix
if label.startswith(uace_prefix): if label.startswith(ace_prefix):
raise UnicodeError("Label starts with ACE prefix") raise UnicodeError("Label starts with ACE prefix")
# Step 6: Encode with PUNYCODE # Step 6: Encode with PUNYCODE
@ -103,7 +103,7 @@ def ToASCII(label):
def ToUnicode(label): def ToUnicode(label):
# Step 1: Check for ASCII # Step 1: Check for ASCII
if isinstance(label, str): if isinstance(label, bytes):
pure_ascii = True pure_ascii = True
else: else:
try: try:
@ -150,19 +150,19 @@ class Codec(codecs.Codec):
raise UnicodeError("unsupported error handling "+errors) raise UnicodeError("unsupported error handling "+errors)
if not input: if not input:
return "", 0 return b"", 0
result = [] result = []
labels = dots.split(input) labels = dots.split(input)
if labels and len(labels[-1])==0: if labels and len(labels[-1])==0:
trailing_dot = '.' trailing_dot = b'.'
del labels[-1] del labels[-1]
else: else:
trailing_dot = '' trailing_dot = b''
for label in labels: for label in labels:
result.append(ToASCII(label)) result.append(ToASCII(label))
# Join with U+002E # Join with U+002E
return ".".join(result)+trailing_dot, len(input) return b".".join(result)+trailing_dot, len(input)
def decode(self,input,errors='strict'): def decode(self,input,errors='strict'):
@ -173,13 +173,12 @@ class Codec(codecs.Codec):
return "", 0 return "", 0
# IDNA allows decoding to operate on Unicode strings, too. # IDNA allows decoding to operate on Unicode strings, too.
if isinstance(input, str): if isinstance(input, bytes):
labels = dots.split(input) labels = dots.split(input)
else: else:
# Must be ASCII string # Force to bytes
input = str(input) input = bytes(input)
str(input, "ascii") labels = input.split(b".")
labels = input.split(".")
if labels and len(labels[-1]) == 0: if labels and len(labels[-1]) == 0:
trailing_dot = '.' trailing_dot = '.'