Add support for PostgreSQL client certificates in dbshell

Enable dbshell to use client SSL certificates and keys by passing PGSSLMODE, PGSSLROOTCERT, PGSSLCERT, and PGSSLKEY environment variables. This facilitates mutual TLS connections for improved database security.

Addresses issue #28322.
This commit is contained in:
utkarsh.arya@zomato.com 2025-11-15 23:10:01 +00:00
parent d87bd29c4f
commit c06b63c8cb
2 changed files with 153 additions and 42 deletions

View file

@ -17,6 +17,10 @@ class DatabaseClient(BaseDatabaseClient):
dbname = conn_params.get('database', '')
user = conn_params.get('user', '')
passwd = conn_params.get('password', '')
sslmode = conn_params.get('sslmode', '')
sslrootcert = conn_params.get('sslrootcert', '')
sslcert = conn_params.get('sslcert', '')
sslkey = conn_params.get('sslkey', '')
if user:
args += ['-U', user]
@ -30,6 +34,14 @@ class DatabaseClient(BaseDatabaseClient):
subprocess_env = os.environ.copy()
if passwd:
subprocess_env['PGPASSWORD'] = str(passwd)
if sslmode:
subprocess_env['PGSSLMODE'] = str(sslmode)
if sslrootcert:
subprocess_env['PGSSLROOTCERT'] = str(sslrootcert)
if sslcert:
subprocess_env['PGSSLCERT'] = str(sslcert)
if sslkey:
subprocess_env['PGSSLKEY'] = str(sslkey)
try:
# Allow SIGINT to pass to psql to abort queries.
signal.signal(signal.SIGINT, signal.SIG_IGN)

View file

@ -12,74 +12,105 @@ class PostgreSqlDbshellCommandTestCase(SimpleTestCase):
def _run_it(self, dbinfo):
"""
That function invokes the runshell command, while mocking
subprocess.run(). It returns a 2-tuple with:
subprocess.run(). It returns a tuple with:
- The command line list
- The the value of the PGPASSWORD environment variable, or None.
- The value of the PGPASSWORD environment variable, or None.
- A dict of SSL-related environment variables.
"""
def _mock_subprocess_run(*args, env=os.environ, **kwargs):
self.subprocess_args = list(*args)
self.pgpassword = env.get('PGPASSWORD')
self.ssl_env = {
'PGSSLMODE': env.get('PGSSLMODE'),
'PGSSLROOTCERT': env.get('PGSSLROOTCERT'),
'PGSSLCERT': env.get('PGSSLCERT'),
'PGSSLKEY': env.get('PGSSLKEY'),
}
return subprocess.CompletedProcess(self.subprocess_args, 0)
with mock.patch('subprocess.run', new=_mock_subprocess_run):
DatabaseClient.runshell_db(dbinfo)
return self.subprocess_args, self.pgpassword
return self.subprocess_args, self.pgpassword, self.ssl_env
def test_basic(self):
args, password, ssl_env = self._run_it({
'database': 'dbname',
'user': 'someuser',
'password': 'somepassword',
'host': 'somehost',
'port': '444',
})
self.assertEqual(
self._run_it({
'database': 'dbname',
'user': 'someuser',
'password': 'somepassword',
'host': 'somehost',
'port': '444',
}), (
['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname'],
'somepassword',
)
args,
['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname']
)
self.assertEqual(password, 'somepassword')
self.assertEqual(ssl_env, {
'PGSSLMODE': None,
'PGSSLROOTCERT': None,
'PGSSLCERT': None,
'PGSSLKEY': None,
})
def test_nopass(self):
args, password, ssl_env = self._run_it({
'database': 'dbname',
'user': 'someuser',
'host': 'somehost',
'port': '444',
})
self.assertEqual(
self._run_it({
'database': 'dbname',
'user': 'someuser',
'host': 'somehost',
'port': '444',
}), (
['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname'],
None,
)
args,
['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname']
)
self.assertEqual(password, None)
self.assertEqual(ssl_env, {
'PGSSLMODE': None,
'PGSSLROOTCERT': None,
'PGSSLCERT': None,
'PGSSLKEY': None,
})
def test_column(self):
args, password, ssl_env = self._run_it({
'database': 'dbname',
'user': 'some:user',
'password': 'some:password',
'host': '::1',
'port': '444',
})
self.assertEqual(
self._run_it({
'database': 'dbname',
'user': 'some:user',
'password': 'some:password',
'host': '::1',
'port': '444',
}), (
['psql', '-U', 'some:user', '-h', '::1', '-p', '444', 'dbname'],
'some:password',
)
args,
['psql', '-U', 'some:user', '-h', '::1', '-p', '444', 'dbname']
)
self.assertEqual(password, 'some:password')
self.assertEqual(ssl_env, {
'PGSSLMODE': None,
'PGSSLROOTCERT': None,
'PGSSLCERT': None,
'PGSSLKEY': None,
})
def test_accent(self):
username = 'rôle'
password = 'sésame'
args, passwd, ssl_env = self._run_it({
'database': 'dbname',
'user': username,
'password': password,
'host': 'somehost',
'port': '444',
})
self.assertEqual(
self._run_it({
'database': 'dbname',
'user': username,
'password': password,
'host': 'somehost',
'port': '444',
}), (
['psql', '-U', username, '-h', 'somehost', '-p', '444', 'dbname'],
password,
)
args,
['psql', '-U', username, '-h', 'somehost', '-p', '444', 'dbname']
)
self.assertEqual(passwd, password)
self.assertEqual(ssl_env, {
'PGSSLMODE': None,
'PGSSLROOTCERT': None,
'PGSSLCERT': None,
'PGSSLKEY': None,
})
def test_sigint_handler(self):
"""SIGINT is ignored in Python and passed to psql to abort quries."""
@ -94,3 +125,71 @@ class PostgreSqlDbshellCommandTestCase(SimpleTestCase):
DatabaseClient.runshell_db({})
# dbshell restores the original handler.
self.assertEqual(sigint_handler, signal.getsignal(signal.SIGINT))
def test_ssl_certificate(self):
"""SSL certificates are passed via environment variables."""
args, password, ssl_env = self._run_it({
'database': 'dbname',
'user': 'someuser',
'password': 'somepassword',
'host': 'somehost',
'port': '444',
'sslmode': 'verify-ca',
'sslrootcert': '/path/to/ca.crt',
'sslcert': '/path/to/client.crt',
'sslkey': '/path/to/client.key',
})
self.assertEqual(
args,
['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname']
)
self.assertEqual(password, 'somepassword')
self.assertEqual(ssl_env, {
'PGSSLMODE': 'verify-ca',
'PGSSLROOTCERT': '/path/to/ca.crt',
'PGSSLCERT': '/path/to/client.crt',
'PGSSLKEY': '/path/to/client.key',
})
def test_ssl_mode_only(self):
"""Only sslmode can be specified without other SSL options."""
args, password, ssl_env = self._run_it({
'database': 'dbname',
'user': 'someuser',
'host': 'somehost',
'port': '444',
'sslmode': 'require',
})
self.assertEqual(
args,
['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname']
)
self.assertEqual(password, None)
self.assertEqual(ssl_env, {
'PGSSLMODE': 'require',
'PGSSLROOTCERT': None,
'PGSSLCERT': None,
'PGSSLKEY': None,
})
def test_ssl_partial_options(self):
"""Partial SSL options can be specified."""
args, password, ssl_env = self._run_it({
'database': 'dbname',
'user': 'someuser',
'host': 'somehost',
'port': '444',
'sslrootcert': '/path/to/ca.crt',
'sslcert': '/path/to/client.crt',
})
self.assertEqual(
args,
['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname']
)
self.assertEqual(password, None)
self.assertEqual(ssl_env, {
'PGSSLMODE': None,
'PGSSLROOTCERT': '/path/to/ca.crt',
'PGSSLCERT': '/path/to/client.crt',
'PGSSLKEY': None,
})