Skip to content
134 changes: 126 additions & 8 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from __future__ import annotations

import asyncio
import configparser
import collections
from collections.abc import Callable
import enum
Expand Down Expand Up @@ -87,6 +88,9 @@ class SSLNegotiation(compat.StrEnum):
PGPASSFILE = '.pgpass'


PG_SERVICEFILE = '.pg_service.conf'


def _read_password_file(passfile: pathlib.Path) \
-> typing.List[typing.Tuple[str, ...]]:

Expand Down Expand Up @@ -268,7 +272,7 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:


def _parse_connect_dsn_and_args(*, dsn, host, port, user,
password, passfile, database, ssl,
password, passfile, database, ssl, service,
direct_tls, server_settings,
target_session_attrs, krbsrvname, gsslib):
# `auth_hosts` is the version of host information for the purposes
Expand All @@ -281,6 +285,118 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if dsn:
parsed = urllib.parse.urlparse(dsn)

query = None
if parsed.query:
query = urllib.parse.parse_qs(parsed.query, strict_parsing=True)
for key, val in query.items():
if isinstance(val, list):
query[key] = val[-1]

if 'service' in query:
val = query.pop('service')
if not service and val:
service = val

connection_service_file = os.getenv('PGSERVICEFILE')
if connection_service_file is None:
homedir = compat.get_pg_home_directory()
if homedir:
connection_service_file = homedir / PG_SERVICEFILE
else:
connection_service_file = None
else:
connection_service_file = pathlib.Path(connection_service_file)

if connection_service_file is not None and service is not None:
pg_service = configparser.ConfigParser()
pg_service.read(connection_service_file)
if service in pg_service.sections():
service_params = pg_service[service]
if 'port' in service_params:
val = service_params.pop('port')
if not port and val:
port = [int(p) for p in val.split(',')]

if 'host' in service_params:
val = service_params.pop('host')
if not host and val:
host, port = _parse_hostlist(val, port)

if 'dbname' in service_params:
val = service_params.pop('dbname')
if database is None:
database = val

if 'database' in service_params:
val = service_params.pop('database')
if database is None:
database = val

if 'user' in service_params:
val = service_params.pop('user')
if user is None:
user = val

if 'password' in service_params:
val = service_params.pop('password')
if password is None:
password = val

if 'passfile' in service_params:
val = service_params.pop('passfile')
if passfile is None:
passfile = val

if 'sslmode' in service_params:
val = service_params.pop('sslmode')
if ssl is None:
ssl = val

if 'sslcert' in service_params:
sslcert = service_params.pop('sslcert')

if 'sslkey' in service_params:
sslkey = service_params.pop('sslkey')

if 'sslrootcert' in service_params:
sslrootcert = service_params.pop('sslrootcert')

if 'sslnegotiation' in service_params:
sslnegotiation = service_params.pop('sslnegotiation')

if 'sslcrl' in service_params:
sslcrl = service_params.pop('sslcrl')

if 'sslpassword' in service_params:
sslpassword = service_params.pop('sslpassword')

if 'ssl_min_protocol_version' in service_params:
ssl_min_protocol_version = service_params.pop(
'ssl_min_protocol_version'
)

if 'ssl_max_protocol_version' in service_params:
ssl_max_protocol_version = service_params.pop(
'ssl_max_protocol_version'
)

if 'target_session_attrs' in service_params:
dsn_target_session_attrs = service_params.pop(
'target_session_attrs'
)
if target_session_attrs is None:
target_session_attrs = dsn_target_session_attrs

if 'krbsrvname' in service_params:
val = service_params.pop('krbsrvname')
if krbsrvname is None:
krbsrvname = val

if 'gsslib' in service_params:
val = service_params.pop('gsslib')
if gsslib is None:
gsslib = val

if parsed.scheme not in {'postgresql', 'postgres'}:
raise exceptions.ClientConfigurationError(
'invalid DSN: scheme is expected to be either '
Expand Down Expand Up @@ -315,11 +431,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if password is None and dsn_password:
password = urllib.parse.unquote(dsn_password)

if parsed.query:
query = urllib.parse.parse_qs(parsed.query, strict_parsing=True)
for key, val in query.items():
if isinstance(val, list):
query[key] = val[-1]
if query:

if 'port' in query:
val = query.pop('port')
Expand Down Expand Up @@ -406,6 +518,11 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if gsslib is None:
gsslib = val

if 'service' in query:
val = query.pop('service')
if service is None:
service = val

if query:
if server_settings is None:
server_settings = query
Expand Down Expand Up @@ -724,7 +841,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
max_cached_statement_lifetime,
max_cacheable_statement_size,
ssl, direct_tls, server_settings,
target_session_attrs, krbsrvname, gsslib):
target_session_attrs, krbsrvname, gsslib,
service):
local_vars = locals()
for var_name in {'max_cacheable_statement_size',
'max_cached_statement_lifetime',
Expand Down Expand Up @@ -754,7 +872,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
direct_tls=direct_tls, database=database,
server_settings=server_settings,
target_session_attrs=target_session_attrs,
krbsrvname=krbsrvname, gsslib=gsslib)
krbsrvname=krbsrvname, gsslib=gsslib, service=service)

config = _ClientConfiguration(
command_timeout=command_timeout,
Expand Down
6 changes: 6 additions & 0 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2074,6 +2074,7 @@ async def _do_execute(
async def connect(dsn=None, *,
host=None, port=None,
user=None, password=None, passfile=None,
service=None,
database=None,
loop=None,
timeout=60,
Expand Down Expand Up @@ -2183,6 +2184,10 @@ async def connect(dsn=None, *,
(defaults to ``~/.pgpass``, or ``%APPDATA%\postgresql\pgpass.conf``
on Windows).

:param service:
The name of the postgres connection service stored in the postgres
connection service file.

:param loop:
An asyncio event loop instance. If ``None``, the default
event loop will be used.
Expand Down Expand Up @@ -2428,6 +2433,7 @@ async def connect(dsn=None, *,
user=user,
password=password,
passfile=passfile,
service=service,
ssl=ssl,
direct_tls=direct_tls,
database=database,
Expand Down
71 changes: 69 additions & 2 deletions tests/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,7 +1116,8 @@ def run_testcase(self, testcase):
env = testcase.get('env', {})
test_env = {'PGHOST': None, 'PGPORT': None,
'PGUSER': None, 'PGPASSWORD': None,
'PGDATABASE': None, 'PGSSLMODE': None}
'PGDATABASE': None, 'PGSSLMODE': None,
'PGSERVICE': None, }
test_env.update(env)

dsn = testcase.get('dsn')
Expand All @@ -1132,6 +1133,7 @@ def run_testcase(self, testcase):
target_session_attrs = testcase.get('target_session_attrs')
krbsrvname = testcase.get('krbsrvname')
gsslib = testcase.get('gsslib')
service = testcase.get('service')

expected = testcase.get('result')
expected_error = testcase.get('error')
Expand All @@ -1157,7 +1159,7 @@ def run_testcase(self, testcase):
direct_tls=direct_tls,
server_settings=server_settings,
target_session_attrs=target_session_attrs,
krbsrvname=krbsrvname, gsslib=gsslib)
krbsrvname=krbsrvname, gsslib=gsslib, service=service)

params = {
k: v for k, v in params._asdict().items()
Expand Down Expand Up @@ -1236,6 +1238,71 @@ def test_connect_params(self):
for testcase in self.TESTS:
self.run_testcase(testcase)

def test_connect_connection_service_file(self):
connection_service_file = tempfile.NamedTemporaryFile(
'w+t', delete=False)
connection_service_file.write(textwrap.dedent('''
[test_service_dbname]
port=5433
host=somehost
dbname=test_dbname
user=admin
password=test_password
target_session_attrs=primary
krbsrvname=fakekrbsrvname
gsslib=sspi

[test_service_database]
port=5433
host=somehost
database=test_dbname
user=admin
password=test_password
target_session_attrs=primary
krbsrvname=fakekrbsrvname
gsslib=sspi
'''))
connection_service_file.close()
os.chmod(connection_service_file.name, stat.S_IWUSR | stat.S_IRUSR)
try:
# passfile path in env
self.run_testcase({
'dsn': 'postgresql://?service=test_service_dbname',
'env': {
'PGSERVICEFILE': connection_service_file.name
},
'result': (
[('somehost', 5433)],
{
'user': 'admin',
'password': 'test_password',
'database': 'test_dbname',
'target_session_attrs': 'primary',
'krbsrvname': 'fakekrbsrvname',
'gsslib': 'sspi',
}
)
})
self.run_testcase({
'dsn': 'postgresql://?service=test_service_database',
'env': {
'PGSERVICEFILE': connection_service_file.name
},
'result': (
[('somehost', 5433)],
{
'user': 'admin',
'password': 'test_password',
'database': 'test_dbname',
'target_session_attrs': 'primary',
'krbsrvname': 'fakekrbsrvname',
'gsslib': 'sspi',
}
)
})
finally:
os.unlink(connection_service_file.name)

def test_connect_pgpass_regular(self):
passfile = tempfile.NamedTemporaryFile('w+t', delete=False)
passfile.write(textwrap.dedent(R'''
Expand Down
Loading