mirror of
				https://github.com/django/django.git
				synced 2025-11-03 21:25:09 +00:00 
			
		
		
		
	Fixed #20038 -- Better error message for host validation.
This commit is contained in:
		
							parent
							
								
									c8deaa9e7b
								
							
						
					
					
						commit
						c250f9c99b
					
				
					 2 changed files with 87 additions and 25 deletions
				
			
		| 
						 | 
				
			
			@ -4,7 +4,6 @@ import copy
 | 
			
		|||
import os
 | 
			
		||||
import re
 | 
			
		||||
import sys
 | 
			
		||||
import warnings
 | 
			
		||||
from io import BytesIO
 | 
			
		||||
from pprint import pformat
 | 
			
		||||
try:
 | 
			
		||||
| 
						 | 
				
			
			@ -66,11 +65,14 @@ class HttpRequest(object):
 | 
			
		|||
                host = '%s:%s' % (host, server_port)
 | 
			
		||||
 | 
			
		||||
        allowed_hosts = ['*'] if settings.DEBUG else settings.ALLOWED_HOSTS
 | 
			
		||||
        if validate_host(host, allowed_hosts):
 | 
			
		||||
        domain, port = split_domain_port(host)
 | 
			
		||||
        if domain and validate_host(domain, allowed_hosts):
 | 
			
		||||
            return host
 | 
			
		||||
        else:
 | 
			
		||||
            raise SuspiciousOperation(
 | 
			
		||||
                "Invalid HTTP_HOST header (you may need to set ALLOWED_HOSTS): %s" % host)
 | 
			
		||||
            msg = "Invalid HTTP_HOST header: %r." % host
 | 
			
		||||
            if domain:
 | 
			
		||||
                msg += "You may need to add %r to ALLOWED_HOSTS." % domain
 | 
			
		||||
            raise SuspiciousOperation(msg)
 | 
			
		||||
 | 
			
		||||
    def get_full_path(self):
 | 
			
		||||
        # RFC 3986 requires query string arguments to be in the ASCII range.
 | 
			
		||||
| 
						 | 
				
			
			@ -454,9 +456,30 @@ def bytes_to_text(s, encoding):
 | 
			
		|||
        return s
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def split_domain_port(host):
 | 
			
		||||
    """
 | 
			
		||||
    Return a (domain, port) tuple from a given host.
 | 
			
		||||
 | 
			
		||||
    Returned domain is lower-cased. If the host is invalid, the domain will be
 | 
			
		||||
    empty.
 | 
			
		||||
    """
 | 
			
		||||
    host = host.lower()
 | 
			
		||||
 | 
			
		||||
    if not host_validation_re.match(host):
 | 
			
		||||
        return '', ''
 | 
			
		||||
 | 
			
		||||
    if host[-1] == ']':
 | 
			
		||||
        # It's an IPv6 address without a port.
 | 
			
		||||
        return host, ''
 | 
			
		||||
    bits = host.rsplit(':', 1)
 | 
			
		||||
    if len(bits) == 2:
 | 
			
		||||
        return tuple(bits)
 | 
			
		||||
    return bits[0], ''
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def validate_host(host, allowed_hosts):
 | 
			
		||||
    """
 | 
			
		||||
    Validate the given host header value for this site.
 | 
			
		||||
    Validate the given host for this site.
 | 
			
		||||
 | 
			
		||||
    Check that the host looks valid and matches a host or host pattern in the
 | 
			
		||||
    given list of ``allowed_hosts``. Any pattern beginning with a period
 | 
			
		||||
| 
						 | 
				
			
			@ -464,31 +487,20 @@ def validate_host(host, allowed_hosts):
 | 
			
		|||
    ``example.com`` and any subdomain), ``*`` matches anything, and anything
 | 
			
		||||
    else must match exactly.
 | 
			
		||||
 | 
			
		||||
    Note: This function assumes that the given host is lower-cased and has
 | 
			
		||||
    already had the port, if any, stripped off.
 | 
			
		||||
 | 
			
		||||
    Return ``True`` for a valid host, ``False`` otherwise.
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
    # All validation is case-insensitive
 | 
			
		||||
    host = host.lower()
 | 
			
		||||
 | 
			
		||||
    # Basic sanity check
 | 
			
		||||
    if not host_validation_re.match(host):
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    # Validate only the domain part.
 | 
			
		||||
    if host[-1] == ']':
 | 
			
		||||
        # It's an IPv6 address without a port.
 | 
			
		||||
        domain = host
 | 
			
		||||
    else:
 | 
			
		||||
        domain = host.rsplit(':', 1)[0]
 | 
			
		||||
 | 
			
		||||
    for pattern in allowed_hosts:
 | 
			
		||||
        pattern = pattern.lower()
 | 
			
		||||
        match = (
 | 
			
		||||
            pattern == '*' or
 | 
			
		||||
            pattern.startswith('.') and (
 | 
			
		||||
                domain.endswith(pattern) or domain == pattern[1:]
 | 
			
		||||
                host.endswith(pattern) or host == pattern[1:]
 | 
			
		||||
                ) or
 | 
			
		||||
            pattern == domain
 | 
			
		||||
            pattern == host
 | 
			
		||||
            )
 | 
			
		||||
        if match:
 | 
			
		||||
            return True
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -11,16 +11,16 @@ from django.core import signals
 | 
			
		|||
from django.core.exceptions import SuspiciousOperation
 | 
			
		||||
from django.core.handlers.wsgi import WSGIRequest, LimitedStream
 | 
			
		||||
from django.http import HttpRequest, HttpResponse, parse_cookie, build_request_repr, UnreadablePostError
 | 
			
		||||
from django.test import TransactionTestCase
 | 
			
		||||
from django.test import SimpleTestCase, TransactionTestCase
 | 
			
		||||
from django.test.client import FakePayload
 | 
			
		||||
from django.test.utils import override_settings, str_prefix
 | 
			
		||||
from django.utils import six
 | 
			
		||||
from django.utils import unittest
 | 
			
		||||
from django.utils.unittest import skipIf
 | 
			
		||||
from django.utils.http import cookie_date, urlencode
 | 
			
		||||
from django.utils.timezone import utc
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RequestsTests(unittest.TestCase):
 | 
			
		||||
class RequestsTests(SimpleTestCase):
 | 
			
		||||
    def test_httprequest(self):
 | 
			
		||||
        request = HttpRequest()
 | 
			
		||||
        self.assertEqual(list(request.GET.keys()), [])
 | 
			
		||||
| 
						 | 
				
			
			@ -287,6 +287,56 @@ class RequestsTests(unittest.TestCase):
 | 
			
		|||
        self.assertEqual(request.get_host(), 'example.com')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    @override_settings(ALLOWED_HOSTS=[])
 | 
			
		||||
    def test_get_host_suggestion_of_allowed_host(self):
 | 
			
		||||
        """get_host() makes helpful suggestions if a valid-looking host is not in ALLOWED_HOSTS."""
 | 
			
		||||
        msg_invalid_host = "Invalid HTTP_HOST header: %r."
 | 
			
		||||
        msg_suggestion = msg_invalid_host + "You may need to add %r to ALLOWED_HOSTS."
 | 
			
		||||
 | 
			
		||||
        for host in [ # Valid-looking hosts
 | 
			
		||||
            'example.com',
 | 
			
		||||
            '12.34.56.78',
 | 
			
		||||
            '[2001:19f0:feee::dead:beef:cafe]',
 | 
			
		||||
            'xn--4ca9at.com', # Punnycode for öäü.com
 | 
			
		||||
        ]:
 | 
			
		||||
            request = HttpRequest()
 | 
			
		||||
            request.META = {'HTTP_HOST': host}
 | 
			
		||||
            self.assertRaisesMessage(
 | 
			
		||||
                SuspiciousOperation,
 | 
			
		||||
                msg_suggestion % (host, host),
 | 
			
		||||
                request.get_host
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        for domain, port in [ # Valid-looking hosts with a port number
 | 
			
		||||
            ('example.com', 80),
 | 
			
		||||
            ('12.34.56.78', 443),
 | 
			
		||||
            ('[2001:19f0:feee::dead:beef:cafe]', 8080),
 | 
			
		||||
        ]:
 | 
			
		||||
            host = '%s:%s' % (domain, port)
 | 
			
		||||
            request = HttpRequest()
 | 
			
		||||
            request.META = {'HTTP_HOST': host}
 | 
			
		||||
            self.assertRaisesMessage(
 | 
			
		||||
                SuspiciousOperation,
 | 
			
		||||
                msg_suggestion % (host, domain),
 | 
			
		||||
                request.get_host
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        for host in [ # Invalid hosts
 | 
			
		||||
            'example.com@evil.tld',
 | 
			
		||||
            'example.com:dr.frankenstein@evil.tld',
 | 
			
		||||
            'example.com:dr.frankenstein@evil.tld:80',
 | 
			
		||||
            'example.com:80/badpath',
 | 
			
		||||
            'example.com: recovermypassword.com',
 | 
			
		||||
        ]:
 | 
			
		||||
            request = HttpRequest()
 | 
			
		||||
            request.META = {'HTTP_HOST': host}
 | 
			
		||||
            self.assertRaisesMessage(
 | 
			
		||||
                SuspiciousOperation,
 | 
			
		||||
                msg_invalid_host % host,
 | 
			
		||||
                request.get_host
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    def test_near_expiration(self):
 | 
			
		||||
        "Cookie will expire when an near expiration time is provided"
 | 
			
		||||
        response = HttpResponse()
 | 
			
		||||
| 
						 | 
				
			
			@ -587,7 +637,7 @@ class RequestsTests(unittest.TestCase):
 | 
			
		|||
            request.body
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@unittest.skipIf(connection.vendor == 'sqlite'
 | 
			
		||||
@skipIf(connection.vendor == 'sqlite'
 | 
			
		||||
        and connection.settings_dict['NAME'] in ('', ':memory:'),
 | 
			
		||||
        "Cannot establish two connections to an in-memory SQLite database.")
 | 
			
		||||
class DatabaseConnectionHandlingTests(TransactionTestCase):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue