From c06b63c8cb3dabda7f34d02b7de575f88addce44 Mon Sep 17 00:00:00 2001 From: "utkarsh.arya@zomato.com" Date: Sat, 15 Nov 2025 23:10:01 +0000 Subject: [PATCH] Add support for PostgreSQL client certificates in dbshell Enable dbshell to use client SSL certificates and keys by passing PGSSLMODE, PGSSLROOTCERT, PGSSLCERT, and PGSSLKEY environment variables. This facilitates mutual TLS connections for improved database security. Addresses issue #28322. --- django/db/backends/postgresql/client.py | 12 ++ tests/dbshell/test_postgresql.py | 183 ++++++++++++++++++------ 2 files changed, 153 insertions(+), 42 deletions(-) diff --git a/django/db/backends/postgresql/client.py b/django/db/backends/postgresql/client.py index cf4df76882..0efe0d47f0 100644 --- a/django/db/backends/postgresql/client.py +++ b/django/db/backends/postgresql/client.py @@ -17,6 +17,10 @@ class DatabaseClient(BaseDatabaseClient): dbname = conn_params.get('database', '') user = conn_params.get('user', '') passwd = conn_params.get('password', '') + sslmode = conn_params.get('sslmode', '') + sslrootcert = conn_params.get('sslrootcert', '') + sslcert = conn_params.get('sslcert', '') + sslkey = conn_params.get('sslkey', '') if user: args += ['-U', user] @@ -30,6 +34,14 @@ class DatabaseClient(BaseDatabaseClient): subprocess_env = os.environ.copy() if passwd: subprocess_env['PGPASSWORD'] = str(passwd) + if sslmode: + subprocess_env['PGSSLMODE'] = str(sslmode) + if sslrootcert: + subprocess_env['PGSSLROOTCERT'] = str(sslrootcert) + if sslcert: + subprocess_env['PGSSLCERT'] = str(sslcert) + if sslkey: + subprocess_env['PGSSLKEY'] = str(sslkey) try: # Allow SIGINT to pass to psql to abort queries. signal.signal(signal.SIGINT, signal.SIG_IGN) diff --git a/tests/dbshell/test_postgresql.py b/tests/dbshell/test_postgresql.py index a33e7f6482..d4ddccb964 100644 --- a/tests/dbshell/test_postgresql.py +++ b/tests/dbshell/test_postgresql.py @@ -12,74 +12,105 @@ class PostgreSqlDbshellCommandTestCase(SimpleTestCase): def _run_it(self, dbinfo): """ That function invokes the runshell command, while mocking - subprocess.run(). It returns a 2-tuple with: + subprocess.run(). It returns a tuple with: - The command line list - - The the value of the PGPASSWORD environment variable, or None. + - The value of the PGPASSWORD environment variable, or None. + - A dict of SSL-related environment variables. """ def _mock_subprocess_run(*args, env=os.environ, **kwargs): self.subprocess_args = list(*args) self.pgpassword = env.get('PGPASSWORD') + self.ssl_env = { + 'PGSSLMODE': env.get('PGSSLMODE'), + 'PGSSLROOTCERT': env.get('PGSSLROOTCERT'), + 'PGSSLCERT': env.get('PGSSLCERT'), + 'PGSSLKEY': env.get('PGSSLKEY'), + } return subprocess.CompletedProcess(self.subprocess_args, 0) with mock.patch('subprocess.run', new=_mock_subprocess_run): DatabaseClient.runshell_db(dbinfo) - return self.subprocess_args, self.pgpassword + return self.subprocess_args, self.pgpassword, self.ssl_env def test_basic(self): + args, password, ssl_env = self._run_it({ + 'database': 'dbname', + 'user': 'someuser', + 'password': 'somepassword', + 'host': 'somehost', + 'port': '444', + }) self.assertEqual( - self._run_it({ - 'database': 'dbname', - 'user': 'someuser', - 'password': 'somepassword', - 'host': 'somehost', - 'port': '444', - }), ( - ['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname'], - 'somepassword', - ) + args, + ['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname'] ) + self.assertEqual(password, 'somepassword') + self.assertEqual(ssl_env, { + 'PGSSLMODE': None, + 'PGSSLROOTCERT': None, + 'PGSSLCERT': None, + 'PGSSLKEY': None, + }) def test_nopass(self): + args, password, ssl_env = self._run_it({ + 'database': 'dbname', + 'user': 'someuser', + 'host': 'somehost', + 'port': '444', + }) self.assertEqual( - self._run_it({ - 'database': 'dbname', - 'user': 'someuser', - 'host': 'somehost', - 'port': '444', - }), ( - ['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname'], - None, - ) + args, + ['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname'] ) + self.assertEqual(password, None) + self.assertEqual(ssl_env, { + 'PGSSLMODE': None, + 'PGSSLROOTCERT': None, + 'PGSSLCERT': None, + 'PGSSLKEY': None, + }) def test_column(self): + args, password, ssl_env = self._run_it({ + 'database': 'dbname', + 'user': 'some:user', + 'password': 'some:password', + 'host': '::1', + 'port': '444', + }) self.assertEqual( - self._run_it({ - 'database': 'dbname', - 'user': 'some:user', - 'password': 'some:password', - 'host': '::1', - 'port': '444', - }), ( - ['psql', '-U', 'some:user', '-h', '::1', '-p', '444', 'dbname'], - 'some:password', - ) + args, + ['psql', '-U', 'some:user', '-h', '::1', '-p', '444', 'dbname'] ) + self.assertEqual(password, 'some:password') + self.assertEqual(ssl_env, { + 'PGSSLMODE': None, + 'PGSSLROOTCERT': None, + 'PGSSLCERT': None, + 'PGSSLKEY': None, + }) def test_accent(self): username = 'rôle' password = 'sésame' + args, passwd, ssl_env = self._run_it({ + 'database': 'dbname', + 'user': username, + 'password': password, + 'host': 'somehost', + 'port': '444', + }) self.assertEqual( - self._run_it({ - 'database': 'dbname', - 'user': username, - 'password': password, - 'host': 'somehost', - 'port': '444', - }), ( - ['psql', '-U', username, '-h', 'somehost', '-p', '444', 'dbname'], - password, - ) + args, + ['psql', '-U', username, '-h', 'somehost', '-p', '444', 'dbname'] ) + self.assertEqual(passwd, password) + self.assertEqual(ssl_env, { + 'PGSSLMODE': None, + 'PGSSLROOTCERT': None, + 'PGSSLCERT': None, + 'PGSSLKEY': None, + }) def test_sigint_handler(self): """SIGINT is ignored in Python and passed to psql to abort quries.""" @@ -94,3 +125,71 @@ class PostgreSqlDbshellCommandTestCase(SimpleTestCase): DatabaseClient.runshell_db({}) # dbshell restores the original handler. self.assertEqual(sigint_handler, signal.getsignal(signal.SIGINT)) + + def test_ssl_certificate(self): + """SSL certificates are passed via environment variables.""" + args, password, ssl_env = self._run_it({ + 'database': 'dbname', + 'user': 'someuser', + 'password': 'somepassword', + 'host': 'somehost', + 'port': '444', + 'sslmode': 'verify-ca', + 'sslrootcert': '/path/to/ca.crt', + 'sslcert': '/path/to/client.crt', + 'sslkey': '/path/to/client.key', + }) + self.assertEqual( + args, + ['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname'] + ) + self.assertEqual(password, 'somepassword') + self.assertEqual(ssl_env, { + 'PGSSLMODE': 'verify-ca', + 'PGSSLROOTCERT': '/path/to/ca.crt', + 'PGSSLCERT': '/path/to/client.crt', + 'PGSSLKEY': '/path/to/client.key', + }) + + def test_ssl_mode_only(self): + """Only sslmode can be specified without other SSL options.""" + args, password, ssl_env = self._run_it({ + 'database': 'dbname', + 'user': 'someuser', + 'host': 'somehost', + 'port': '444', + 'sslmode': 'require', + }) + self.assertEqual( + args, + ['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname'] + ) + self.assertEqual(password, None) + self.assertEqual(ssl_env, { + 'PGSSLMODE': 'require', + 'PGSSLROOTCERT': None, + 'PGSSLCERT': None, + 'PGSSLKEY': None, + }) + + def test_ssl_partial_options(self): + """Partial SSL options can be specified.""" + args, password, ssl_env = self._run_it({ + 'database': 'dbname', + 'user': 'someuser', + 'host': 'somehost', + 'port': '444', + 'sslrootcert': '/path/to/ca.crt', + 'sslcert': '/path/to/client.crt', + }) + self.assertEqual( + args, + ['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname'] + ) + self.assertEqual(password, None) + self.assertEqual(ssl_env, { + 'PGSSLMODE': None, + 'PGSSLROOTCERT': '/path/to/ca.crt', + 'PGSSLCERT': '/path/to/client.crt', + 'PGSSLKEY': None, + })