mirror of
https://github.com/python/cpython.git
synced 2025-12-04 08:34:25 +00:00
improvements to test_smtplib per issue2423
merged the socket mock introduced in test_smtpd
This commit is contained in:
parent
0db85e5d46
commit
64b02de010
3 changed files with 188 additions and 75 deletions
153
Lib/test/mock_socket.py
Normal file
153
Lib/test/mock_socket.py
Normal file
|
|
@ -0,0 +1,153 @@
|
||||||
|
"""Mock socket module used by the smtpd and smtplib tests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# imported for _GLOBAL_DEFAULT_TIMEOUT
|
||||||
|
import socket as socket_module
|
||||||
|
|
||||||
|
# Mock socket module
|
||||||
|
_defaulttimeout = None
|
||||||
|
_reply_data = None
|
||||||
|
|
||||||
|
# This is used to queue up data to be read through socket.makefile, typically
|
||||||
|
# *before* the socket object is even created. It is intended to handle a single
|
||||||
|
# line which the socket will feed on recv() or makefile().
|
||||||
|
def reply_with(line):
|
||||||
|
global _reply_data
|
||||||
|
_reply_data = line
|
||||||
|
|
||||||
|
|
||||||
|
class MockFile:
|
||||||
|
"""Mock file object returned by MockSocket.makefile().
|
||||||
|
"""
|
||||||
|
def __init__(self, lines):
|
||||||
|
self.lines = lines
|
||||||
|
def readline(self):
|
||||||
|
return self.lines.pop(0) + b'\r\n'
|
||||||
|
def close(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MockSocket:
|
||||||
|
"""Mock socket object used by smtpd and smtplib tests.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
global _reply_data
|
||||||
|
self.output = []
|
||||||
|
self.lines = []
|
||||||
|
if _reply_data:
|
||||||
|
self.lines.append(_reply_data)
|
||||||
|
self.conn = None
|
||||||
|
self.timeout = None
|
||||||
|
|
||||||
|
def queue_recv(self, line):
|
||||||
|
self.lines.append(line)
|
||||||
|
|
||||||
|
def recv(self, bufsize, flags=None):
|
||||||
|
data = self.lines.pop(0) + b'\r\n'
|
||||||
|
return data
|
||||||
|
|
||||||
|
def fileno(self):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def settimeout(self, timeout):
|
||||||
|
if timeout is None:
|
||||||
|
self.timeout = _defaulttimeout
|
||||||
|
else:
|
||||||
|
self.timeout = timeout
|
||||||
|
|
||||||
|
def gettimeout(self):
|
||||||
|
return self.timeout
|
||||||
|
|
||||||
|
def setsockopt(self, level, optname, value):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def getsockopt(self, level, optname, buflen=None):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def bind(self, address):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def accept(self):
|
||||||
|
self.conn = MockSocket()
|
||||||
|
return self.conn, 'c'
|
||||||
|
|
||||||
|
def getsockname(self):
|
||||||
|
return ('0.0.0.0', 0)
|
||||||
|
|
||||||
|
def setblocking(self, flag):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def listen(self, backlog):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def makefile(self, mode='r', bufsize=-1):
|
||||||
|
handle = MockFile(self.lines)
|
||||||
|
return handle
|
||||||
|
|
||||||
|
def sendall(self, buffer, flags=None):
|
||||||
|
self.last = data
|
||||||
|
self.output.append(data)
|
||||||
|
return len(data)
|
||||||
|
|
||||||
|
def send(self, data, flags=None):
|
||||||
|
self.last = data
|
||||||
|
self.output.append(data)
|
||||||
|
return len(data)
|
||||||
|
|
||||||
|
def getpeername(self):
|
||||||
|
return 'peer'
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def socket(family=None, type=None, proto=None):
|
||||||
|
return MockSocket()
|
||||||
|
|
||||||
|
|
||||||
|
def create_connection(address, timeout=socket_module._GLOBAL_DEFAULT_TIMEOUT):
|
||||||
|
try:
|
||||||
|
int_port = int(address[1])
|
||||||
|
except ValueError:
|
||||||
|
raise error
|
||||||
|
ms = MockSocket()
|
||||||
|
if timeout is socket_module._GLOBAL_DEFAULT_TIMEOUT:
|
||||||
|
timeout = getdefaulttimeout()
|
||||||
|
ms.settimeout(timeout)
|
||||||
|
return ms
|
||||||
|
|
||||||
|
|
||||||
|
def setdefaulttimeout(timeout):
|
||||||
|
global _defaulttimeout
|
||||||
|
_defaulttimeout = timeout
|
||||||
|
|
||||||
|
|
||||||
|
def getdefaulttimeout():
|
||||||
|
return _defaulttimeout
|
||||||
|
|
||||||
|
|
||||||
|
def getfqdn():
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def gethostname():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def gethostbyname(name):
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
class gaierror(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class error(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Constants
|
||||||
|
AF_INET = None
|
||||||
|
SOCK_STREAM = None
|
||||||
|
SOL_SOCKET = None
|
||||||
|
SO_REUSEADDR = None
|
||||||
|
|
@ -1,53 +1,16 @@
|
||||||
import asynchat
|
|
||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
|
from test import support, mock_socket
|
||||||
import socket
|
import socket
|
||||||
from test import support
|
|
||||||
import asyncore
|
|
||||||
import io
|
import io
|
||||||
import smtpd
|
import smtpd
|
||||||
|
import asyncore
|
||||||
|
|
||||||
# mock-ish socket to sit underneath asyncore
|
|
||||||
class DummySocket:
|
|
||||||
def __init__(self):
|
|
||||||
self.output = []
|
|
||||||
self.queue = []
|
|
||||||
self.conn = None
|
|
||||||
def queue_recv(self, line):
|
|
||||||
self.queue.append(line)
|
|
||||||
def recv(self, *args):
|
|
||||||
data = self.queue.pop(0) + b'\r\n'
|
|
||||||
return data
|
|
||||||
def fileno(self):
|
|
||||||
return 0
|
|
||||||
def setsockopt(self, *args):
|
|
||||||
pass
|
|
||||||
def getsockopt(self, *args):
|
|
||||||
return 0
|
|
||||||
def bind(self, *args):
|
|
||||||
pass
|
|
||||||
def accept(self):
|
|
||||||
self.conn = DummySocket()
|
|
||||||
return self.conn, 'c'
|
|
||||||
def listen(self, *args):
|
|
||||||
pass
|
|
||||||
def setblocking(self, *args):
|
|
||||||
pass
|
|
||||||
def send(self, data):
|
|
||||||
self.last = data
|
|
||||||
self.output.append(data)
|
|
||||||
return len(data)
|
|
||||||
def getpeername(self):
|
|
||||||
return 'peer'
|
|
||||||
def close(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class DummyServer(smtpd.SMTPServer):
|
class DummyServer(smtpd.SMTPServer):
|
||||||
def __init__(self, *args):
|
def __init__(self, *args):
|
||||||
smtpd.SMTPServer.__init__(self, *args)
|
smtpd.SMTPServer.__init__(self, *args)
|
||||||
self.messages = []
|
self.messages = []
|
||||||
def create_socket(self, family, type):
|
|
||||||
self.family_and_type = (socket.AF_INET, socket.SOCK_STREAM)
|
|
||||||
self.set_socket(DummySocket())
|
|
||||||
def process_message(self, peer, mailfrom, rcpttos, data):
|
def process_message(self, peer, mailfrom, rcpttos, data):
|
||||||
self.messages.append((peer, mailfrom, rcpttos, data))
|
self.messages.append((peer, mailfrom, rcpttos, data))
|
||||||
if data == 'return status':
|
if data == 'return status':
|
||||||
|
|
@ -62,11 +25,15 @@ class BrokenDummyServer(DummyServer):
|
||||||
|
|
||||||
class SMTPDChannelTest(TestCase):
|
class SMTPDChannelTest(TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
smtpd.socket = asyncore.socket = mock_socket
|
||||||
self.debug = smtpd.DEBUGSTREAM = io.StringIO()
|
self.debug = smtpd.DEBUGSTREAM = io.StringIO()
|
||||||
self.server = DummyServer('a', 'b')
|
self.server = DummyServer('a', 'b')
|
||||||
conn, addr = self.server.accept()
|
conn, addr = self.server.accept()
|
||||||
self.channel = smtpd.SMTPChannel(self.server, conn, addr)
|
self.channel = smtpd.SMTPChannel(self.server, conn, addr)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
asyncore.socket = smtpd.socket = socket
|
||||||
|
|
||||||
def write_line(self, line):
|
def write_line(self, line):
|
||||||
self.channel.socket.queue_recv(line)
|
self.channel.socket.queue_recv(line)
|
||||||
self.channel.handle_read()
|
self.channel.handle_read()
|
||||||
|
|
@ -88,7 +55,7 @@ class SMTPDChannelTest(TestCase):
|
||||||
b'502 Error: command "EHLO" not implemented\r\n')
|
b'502 Error: command "EHLO" not implemented\r\n')
|
||||||
|
|
||||||
def test_HELO(self):
|
def test_HELO(self):
|
||||||
name = socket.getfqdn()
|
name = smtpd.socket.getfqdn()
|
||||||
self.write_line(b'HELO test.example')
|
self.write_line(b'HELO test.example')
|
||||||
self.assertEqual(self.channel.socket.last,
|
self.assertEqual(self.channel.socket.last,
|
||||||
'250 {}\r\n'.format(name).encode('ascii'))
|
'250 {}\r\n'.format(name).encode('ascii'))
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ import time
|
||||||
import select
|
import select
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
from test import support
|
from test import support, mock_socket
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import threading
|
import threading
|
||||||
|
|
@ -48,27 +48,17 @@ def server(evt, buf, serv):
|
||||||
serv.close()
|
serv.close()
|
||||||
evt.set()
|
evt.set()
|
||||||
|
|
||||||
@unittest.skipUnless(threading, 'Threading required for this test.')
|
|
||||||
class GeneralTests(unittest.TestCase):
|
class GeneralTests(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self._threads = support.threading_setup()
|
smtplib.socket = mock_socket
|
||||||
self.evt = threading.Event()
|
self.port = 25
|
||||||
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
||||||
self.sock.settimeout(15)
|
|
||||||
self.port = support.bind_port(self.sock)
|
|
||||||
servargs = (self.evt, b"220 Hola mundo\n", self.sock)
|
|
||||||
self.thread = threading.Thread(target=server, args=servargs)
|
|
||||||
self.thread.start()
|
|
||||||
self.evt.wait()
|
|
||||||
self.evt.clear()
|
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
self.evt.wait()
|
smtplib.socket = socket
|
||||||
self.thread.join()
|
|
||||||
support.threading_cleanup(*self._threads)
|
|
||||||
|
|
||||||
def testBasic1(self):
|
def testBasic1(self):
|
||||||
|
mock_socket.reply_with(b"220 Hola mundo")
|
||||||
# connects
|
# connects
|
||||||
smtp = smtplib.SMTP(HOST, self.port)
|
smtp = smtplib.SMTP(HOST, self.port)
|
||||||
smtp.close()
|
smtp.close()
|
||||||
|
|
@ -85,12 +75,13 @@ class GeneralTests(unittest.TestCase):
|
||||||
smtp.close()
|
smtp.close()
|
||||||
|
|
||||||
def testTimeoutDefault(self):
|
def testTimeoutDefault(self):
|
||||||
self.assertTrue(socket.getdefaulttimeout() is None)
|
self.assertTrue(mock_socket.getdefaulttimeout() is None)
|
||||||
socket.setdefaulttimeout(30)
|
mock_socket.setdefaulttimeout(30)
|
||||||
|
self.assertEqual(mock_socket.getdefaulttimeout(), 30)
|
||||||
try:
|
try:
|
||||||
smtp = smtplib.SMTP(HOST, self.port)
|
smtp = smtplib.SMTP(HOST, self.port)
|
||||||
finally:
|
finally:
|
||||||
socket.setdefaulttimeout(None)
|
mock_socket.setdefaulttimeout(None)
|
||||||
self.assertEqual(smtp.sock.gettimeout(), 30)
|
self.assertEqual(smtp.sock.gettimeout(), 30)
|
||||||
smtp.close()
|
smtp.close()
|
||||||
|
|
||||||
|
|
@ -155,6 +146,8 @@ MSG_END = '------------ END MESSAGE ------------\n'
|
||||||
class DebuggingServerTests(unittest.TestCase):
|
class DebuggingServerTests(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
self.real_getfqdn = socket.getfqdn
|
||||||
|
socket.getfqdn = mock_socket.getfqdn
|
||||||
# temporarily replace sys.stdout to capture DebuggingServer output
|
# temporarily replace sys.stdout to capture DebuggingServer output
|
||||||
self.old_stdout = sys.stdout
|
self.old_stdout = sys.stdout
|
||||||
self.output = io.StringIO()
|
self.output = io.StringIO()
|
||||||
|
|
@ -176,6 +169,7 @@ class DebuggingServerTests(unittest.TestCase):
|
||||||
self.serv_evt.clear()
|
self.serv_evt.clear()
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
|
socket.getfqdn = self.real_getfqdn
|
||||||
# indicate that the client is finished
|
# indicate that the client is finished
|
||||||
self.client_evt.set()
|
self.client_evt.set()
|
||||||
# wait for the server thread to terminate
|
# wait for the server thread to terminate
|
||||||
|
|
@ -251,6 +245,12 @@ class DebuggingServerTests(unittest.TestCase):
|
||||||
|
|
||||||
class NonConnectingTests(unittest.TestCase):
|
class NonConnectingTests(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
smtplib.socket = mock_socket
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
smtplib.socket = socket
|
||||||
|
|
||||||
def testNotConnected(self):
|
def testNotConnected(self):
|
||||||
# Test various operations on an unconnected SMTP object that
|
# Test various operations on an unconnected SMTP object that
|
||||||
# should raise exceptions (at present the attempt in SMTP.send
|
# should raise exceptions (at present the attempt in SMTP.send
|
||||||
|
|
@ -263,9 +263,9 @@ class NonConnectingTests(unittest.TestCase):
|
||||||
|
|
||||||
def testNonnumericPort(self):
|
def testNonnumericPort(self):
|
||||||
# check that non-numeric port raises socket.error
|
# check that non-numeric port raises socket.error
|
||||||
self.assertRaises(socket.error, smtplib.SMTP,
|
self.assertRaises(mock_socket.error, smtplib.SMTP,
|
||||||
"localhost", "bogus")
|
"localhost", "bogus")
|
||||||
self.assertRaises(socket.error, smtplib.SMTP,
|
self.assertRaises(mock_socket.error, smtplib.SMTP,
|
||||||
"localhost:bogus")
|
"localhost:bogus")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -274,25 +274,15 @@ class NonConnectingTests(unittest.TestCase):
|
||||||
class BadHELOServerTests(unittest.TestCase):
|
class BadHELOServerTests(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
smtplib.socket = mock_socket
|
||||||
|
mock_socket.reply_with(b"199 no hello for you!")
|
||||||
self.old_stdout = sys.stdout
|
self.old_stdout = sys.stdout
|
||||||
self.output = io.StringIO()
|
self.output = io.StringIO()
|
||||||
sys.stdout = self.output
|
sys.stdout = self.output
|
||||||
|
self.port = 25
|
||||||
self._threads = support.threading_setup()
|
|
||||||
self.evt = threading.Event()
|
|
||||||
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
||||||
self.sock.settimeout(15)
|
|
||||||
self.port = support.bind_port(self.sock)
|
|
||||||
servargs = (self.evt, b"199 no hello for you!\n", self.sock)
|
|
||||||
self.thread = threading.Thread(target=server, args=servargs)
|
|
||||||
self.thread.start()
|
|
||||||
self.evt.wait()
|
|
||||||
self.evt.clear()
|
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
self.evt.wait()
|
smtplib.socket = socket
|
||||||
self.thread.join()
|
|
||||||
support.threading_cleanup(*self._threads)
|
|
||||||
sys.stdout = self.old_stdout
|
sys.stdout = self.old_stdout
|
||||||
|
|
||||||
def testFailingHELO(self):
|
def testFailingHELO(self):
|
||||||
|
|
@ -405,6 +395,8 @@ class SimSMTPServer(smtpd.SMTPServer):
|
||||||
class SMTPSimTests(unittest.TestCase):
|
class SMTPSimTests(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
self.real_getfqdn = socket.getfqdn
|
||||||
|
socket.getfqdn = mock_socket.getfqdn
|
||||||
self._threads = support.threading_setup()
|
self._threads = support.threading_setup()
|
||||||
self.serv_evt = threading.Event()
|
self.serv_evt = threading.Event()
|
||||||
self.client_evt = threading.Event()
|
self.client_evt = threading.Event()
|
||||||
|
|
@ -421,6 +413,7 @@ class SMTPSimTests(unittest.TestCase):
|
||||||
self.serv_evt.clear()
|
self.serv_evt.clear()
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
|
socket.getfqdn = self.real_getfqdn
|
||||||
# indicate that the client is finished
|
# indicate that the client is finished
|
||||||
self.client_evt.set()
|
self.client_evt.set()
|
||||||
# wait for the server thread to terminate
|
# wait for the server thread to terminate
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue