mirror of
https://github.com/python/cpython.git
synced 2025-08-03 16:39:00 +00:00
Issue #24291: Avoid WSGIRequestHandler doing partial writes
If the underlying send() method indicates a partial write, such as when the call is interrupted to handle a signal, the server would silently drop the remaining data. Also add deprecated support for SimpleHandler.stdout.write() doing partial writes.
This commit is contained in:
parent
889f914edb
commit
ed0425c60a
5 changed files with 111 additions and 7 deletions
|
@ -1,18 +1,22 @@
|
|||
from unittest import mock
|
||||
from test import support
|
||||
from test.test_httpservers import NoLogRequestHandler
|
||||
from unittest import TestCase
|
||||
from wsgiref.util import setup_testing_defaults
|
||||
from wsgiref.headers import Headers
|
||||
from wsgiref.handlers import BaseHandler, BaseCGIHandler
|
||||
from wsgiref.handlers import BaseHandler, BaseCGIHandler, SimpleHandler
|
||||
from wsgiref import util
|
||||
from wsgiref.validate import validator
|
||||
from wsgiref.simple_server import WSGIServer, WSGIRequestHandler
|
||||
from wsgiref.simple_server import make_server
|
||||
from http.client import HTTPConnection
|
||||
from io import StringIO, BytesIO, BufferedReader
|
||||
from socketserver import BaseServer
|
||||
from platform import python_implementation
|
||||
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
|
@ -245,6 +249,56 @@ class IntegrationTests(TestCase):
|
|||
],
|
||||
out.splitlines())
|
||||
|
||||
def test_interrupted_write(self):
|
||||
# BaseHandler._write() and _flush() have to write all data, even if
|
||||
# it takes multiple send() calls. Test this by interrupting a send()
|
||||
# call with a Unix signal.
|
||||
threading = support.import_module("threading")
|
||||
pthread_kill = support.get_attribute(signal, "pthread_kill")
|
||||
|
||||
def app(environ, start_response):
|
||||
start_response("200 OK", [])
|
||||
return [bytes(support.SOCK_MAX_SIZE)]
|
||||
|
||||
class WsgiHandler(NoLogRequestHandler, WSGIRequestHandler):
|
||||
pass
|
||||
|
||||
server = make_server(support.HOST, 0, app, handler_class=WsgiHandler)
|
||||
self.addCleanup(server.server_close)
|
||||
interrupted = threading.Event()
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
interrupted.set()
|
||||
|
||||
original = signal.signal(signal.SIGUSR1, signal_handler)
|
||||
self.addCleanup(signal.signal, signal.SIGUSR1, original)
|
||||
received = None
|
||||
main_thread = threading.get_ident()
|
||||
|
||||
def run_client():
|
||||
http = HTTPConnection(*server.server_address)
|
||||
http.request("GET", "/")
|
||||
with http.getresponse() as response:
|
||||
response.read(100)
|
||||
# The main thread should now be blocking in a send() system
|
||||
# call. But in theory, it could get interrupted by other
|
||||
# signals, and then retried. So keep sending the signal in a
|
||||
# loop, in case an earlier signal happens to be delivered at
|
||||
# an inconvenient moment.
|
||||
while True:
|
||||
pthread_kill(main_thread, signal.SIGUSR1)
|
||||
if interrupted.wait(timeout=float(1)):
|
||||
break
|
||||
nonlocal received
|
||||
received = len(response.read())
|
||||
http.close()
|
||||
|
||||
background = threading.Thread(target=run_client)
|
||||
background.start()
|
||||
server.handle_request()
|
||||
background.join()
|
||||
self.assertEqual(received, support.SOCK_MAX_SIZE - 100)
|
||||
|
||||
|
||||
class UtilityTests(TestCase):
|
||||
|
||||
|
@ -701,6 +755,31 @@ class HandlerTests(TestCase):
|
|||
h.run(error_app)
|
||||
self.assertEqual(side_effects['close_called'], True)
|
||||
|
||||
def testPartialWrite(self):
|
||||
written = bytearray()
|
||||
|
||||
class PartialWriter:
|
||||
def write(self, b):
|
||||
partial = b[:7]
|
||||
written.extend(partial)
|
||||
return len(partial)
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
environ = {"SERVER_PROTOCOL": "HTTP/1.0"}
|
||||
h = SimpleHandler(BytesIO(), PartialWriter(), sys.stderr, environ)
|
||||
msg = "should not do partial writes"
|
||||
with self.assertWarnsRegex(DeprecationWarning, msg):
|
||||
h.run(hello_app)
|
||||
self.assertEqual(b"HTTP/1.0 200 OK\r\n"
|
||||
b"Content-Type: text/plain\r\n"
|
||||
b"Date: Mon, 05 Jun 2006 18:49:54 GMT\r\n"
|
||||
b"Content-Length: 13\r\n"
|
||||
b"\r\n"
|
||||
b"Hello, world!",
|
||||
written)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue