|  | 
| 18 | 18 | import ssl as ssl_module | 
| 19 | 19 | import stat | 
| 20 | 20 | import struct | 
|  | 21 | +import sys | 
| 21 | 22 | import time | 
| 22 | 23 | import typing | 
| 23 | 24 | import urllib.parse | 
| @@ -220,13 +221,35 @@ def _parse_hostlist(hostlist, port, *, unquote=False): | 
| 220 | 221 |  return hosts, port | 
| 221 | 222 | 
 | 
| 222 | 223 | 
 | 
|  | 224 | +def _parse_tls_version(tls_version): | 
|  | 225 | + if not hasattr(ssl_module, 'TLSVersion'): | 
|  | 226 | + raise ValueError( | 
|  | 227 | + "TLSVersion is not supported in this version of Python" | 
|  | 228 | + ) | 
|  | 229 | + if tls_version.startswith('SSL'): | 
|  | 230 | + raise ValueError( | 
|  | 231 | + f"Unsupported TLS version: {tls_version}" | 
|  | 232 | + ) | 
|  | 233 | + try: | 
|  | 234 | + return ssl_module.TLSVersion[tls_version.replace('.', '_')] | 
|  | 235 | + except KeyError: | 
|  | 236 | + raise ValueError( | 
|  | 237 | + f"No such TLS version: {tls_version}" | 
|  | 238 | + ) | 
|  | 239 | + | 
|  | 240 | + | 
|  | 241 | +def _dot_postgresql_path(filename) -> pathlib.Path: | 
|  | 242 | + return (pathlib.Path.home() / '.postgresql' / filename).resolve() | 
|  | 243 | + | 
|  | 244 | + | 
| 223 | 245 | def _parse_connect_dsn_and_args(*, dsn, host, port, user, | 
| 224 | 246 |  password, passfile, database, ssl, | 
| 225 | 247 |  connect_timeout, server_settings): | 
| 226 | 248 |  # `auth_hosts` is the version of host information for the purposes | 
| 227 | 249 |  # of reading the pgpass file. | 
| 228 | 250 |  auth_hosts = None | 
| 229 |  | - sslcert = sslkey = sslrootcert = sslcrl = None | 
|  | 251 | + sslcert = sslkey = sslrootcert = sslcrl = sslpassword = None | 
|  | 252 | + ssl_min_protocol_version = ssl_max_protocol_version = None | 
| 230 | 253 | 
 | 
| 231 | 254 |  if dsn: | 
| 232 | 255 |  parsed = urllib.parse.urlparse(dsn) | 
| @@ -312,24 +335,29 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, | 
| 312 | 335 |  ssl = val | 
| 313 | 336 | 
 | 
| 314 | 337 |  if 'sslcert' in query: | 
| 315 |  | - val = query.pop('sslcert') | 
| 316 |  | - if sslcert is None: | 
| 317 |  | - sslcert = val | 
|  | 338 | + sslcert = query.pop('sslcert') | 
| 318 | 339 | 
 | 
| 319 | 340 |  if 'sslkey' in query: | 
| 320 |  | - val = query.pop('sslkey') | 
| 321 |  | - if sslkey is None: | 
| 322 |  | - sslkey = val | 
|  | 341 | + sslkey = query.pop('sslkey') | 
| 323 | 342 | 
 | 
| 324 | 343 |  if 'sslrootcert' in query: | 
| 325 |  | - val = query.pop('sslrootcert') | 
| 326 |  | - if sslrootcert is None: | 
| 327 |  | - sslrootcert = val | 
|  | 344 | + sslrootcert = query.pop('sslrootcert') | 
| 328 | 345 | 
 | 
| 329 | 346 |  if 'sslcrl' in query: | 
| 330 |  | - val = query.pop('sslcrl') | 
| 331 |  | - if sslcrl is None: | 
| 332 |  | - sslcrl = val | 
|  | 347 | + sslcrl = query.pop('sslcrl') | 
|  | 348 | + | 
|  | 349 | + if 'sslpassword' in query: | 
|  | 350 | + sslpassword = query.pop('sslpassword') | 
|  | 351 | + | 
|  | 352 | + if 'ssl_min_protocol_version' in query: | 
|  | 353 | + ssl_min_protocol_version = query.pop( | 
|  | 354 | + 'ssl_min_protocol_version' | 
|  | 355 | + ) | 
|  | 356 | + | 
|  | 357 | + if 'ssl_max_protocol_version' in query: | 
|  | 358 | + ssl_max_protocol_version = query.pop( | 
|  | 359 | + 'ssl_max_protocol_version' | 
|  | 360 | + ) | 
| 333 | 361 | 
 | 
| 334 | 362 |  if query: | 
| 335 | 363 |  if server_settings is None: | 
| @@ -451,34 +479,97 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, | 
| 451 | 479 |  if sslmode < SSLMode.allow: | 
| 452 | 480 |  ssl = False | 
| 453 | 481 |  else: | 
| 454 |  | - ssl = ssl_module.create_default_context( | 
| 455 |  | - ssl_module.Purpose.SERVER_AUTH) | 
|  | 482 | + ssl = ssl_module.SSLContext(ssl_module.PROTOCOL_TLS_CLIENT) | 
| 456 | 483 |  ssl.check_hostname = sslmode >= SSLMode.verify_full | 
| 457 |  | - ssl.verify_mode = ssl_module.CERT_REQUIRED | 
| 458 |  | - if sslmode <= SSLMode.require: | 
|  | 484 | + if sslmode < SSLMode.require: | 
| 459 | 485 |  ssl.verify_mode = ssl_module.CERT_NONE | 
|  | 486 | + else: | 
|  | 487 | + if sslrootcert is None: | 
|  | 488 | + sslrootcert = os.getenv('PGSSLROOTCERT') | 
|  | 489 | + if sslrootcert: | 
|  | 490 | + ssl.load_verify_locations(cafile=sslrootcert) | 
|  | 491 | + ssl.verify_mode = ssl_module.CERT_REQUIRED | 
|  | 492 | + else: | 
|  | 493 | + sslrootcert = _dot_postgresql_path('root.crt') | 
|  | 494 | + try: | 
|  | 495 | + ssl.load_verify_locations(cafile=sslrootcert) | 
|  | 496 | + except FileNotFoundError: | 
|  | 497 | + if sslmode > SSLMode.require: | 
|  | 498 | + raise ValueError( | 
|  | 499 | + f'root certificate file "{sslrootcert}" does ' | 
|  | 500 | + f'not exist\nEither provide the file or ' | 
|  | 501 | + f'change sslmode to disable server ' | 
|  | 502 | + f'certificate verification.' | 
|  | 503 | + ) | 
|  | 504 | + elif sslmode == SSLMode.require: | 
|  | 505 | + ssl.verify_mode = ssl_module.CERT_NONE | 
|  | 506 | + else: | 
|  | 507 | + assert False, 'unreachable' | 
|  | 508 | + else: | 
|  | 509 | + ssl.verify_mode = ssl_module.CERT_REQUIRED | 
| 460 | 510 | 
 | 
| 461 |  | - if sslcert is None: | 
| 462 |  | - sslcert = os.getenv('PGSSLCERT') | 
|  | 511 | + if sslcrl is None: | 
|  | 512 | + sslcrl = os.getenv('PGSSLCRL') | 
|  | 513 | + if sslcrl: | 
|  | 514 | + ssl.load_verify_locations(cafile=sslcrl) | 
|  | 515 | + ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN | 
|  | 516 | + else: | 
|  | 517 | + sslcrl = _dot_postgresql_path('root.crl') | 
|  | 518 | + try: | 
|  | 519 | + ssl.load_verify_locations(cafile=sslcrl) | 
|  | 520 | + except FileNotFoundError: | 
|  | 521 | + pass | 
|  | 522 | + else: | 
|  | 523 | + ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN | 
| 463 | 524 | 
 | 
| 464 | 525 |  if sslkey is None: | 
| 465 | 526 |  sslkey = os.getenv('PGSSLKEY') | 
| 466 |  | - | 
| 467 |  | - if sslrootcert is None: | 
| 468 |  | - sslrootcert = os.getenv('PGSSLROOTCERT') | 
| 469 |  | - | 
| 470 |  | - if sslcrl is None: | 
| 471 |  | - sslcrl = os.getenv('PGSSLCRL') | 
| 472 |  | - | 
|  | 527 | + if not sslkey: | 
|  | 528 | + sslkey = _dot_postgresql_path('postgresql.key') | 
|  | 529 | + if not sslkey.exists(): | 
|  | 530 | + sslkey = None | 
|  | 531 | + if not sslpassword: | 
|  | 532 | + sslpassword = '' | 
|  | 533 | + if sslcert is None: | 
|  | 534 | + sslcert = os.getenv('PGSSLCERT') | 
| 473 | 535 |  if sslcert: | 
| 474 |  | - ssl.load_cert_chain(sslcert, keyfile=sslkey) | 
| 475 |  | - | 
| 476 |  | - if sslrootcert: | 
| 477 |  | - ssl.load_verify_locations(cafile=sslrootcert) | 
| 478 |  | - | 
| 479 |  | - if sslcrl: | 
| 480 |  | - ssl.load_verify_locations(cafile=sslcrl) | 
| 481 |  | - ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN | 
|  | 536 | + ssl.load_cert_chain( | 
|  | 537 | + sslcert, keyfile=sslkey, password=lambda: sslpassword | 
|  | 538 | + ) | 
|  | 539 | + else: | 
|  | 540 | + sslcert = _dot_postgresql_path('postgresql.crt') | 
|  | 541 | + try: | 
|  | 542 | + ssl.load_cert_chain( | 
|  | 543 | + sslcert, keyfile=sslkey, password=lambda: sslpassword | 
|  | 544 | + ) | 
|  | 545 | + except FileNotFoundError: | 
|  | 546 | + pass | 
|  | 547 | + | 
|  | 548 | + # OpenSSL 1.1.1 keylog file, copied from create_default_context() | 
|  | 549 | + if hasattr(ssl, 'keylog_filename'): | 
|  | 550 | + keylogfile = os.environ.get('SSLKEYLOGFILE') | 
|  | 551 | + if keylogfile and not sys.flags.ignore_environment: | 
|  | 552 | + ssl.keylog_filename = keylogfile | 
|  | 553 | + | 
|  | 554 | + if ssl_min_protocol_version is None: | 
|  | 555 | + ssl_min_protocol_version = os.getenv('PGSSLMINPROTOCOLVERSION') | 
|  | 556 | + if ssl_min_protocol_version: | 
|  | 557 | + ssl.minimum_version = _parse_tls_version( | 
|  | 558 | + ssl_min_protocol_version | 
|  | 559 | + ) | 
|  | 560 | + else: | 
|  | 561 | + try: | 
|  | 562 | + ssl.minimum_version = _parse_tls_version('TLSv1.2') | 
|  | 563 | + except ValueError: | 
|  | 564 | + # Python 3.6 does not have ssl.TLSVersion | 
|  | 565 | + pass | 
|  | 566 | + | 
|  | 567 | + if ssl_max_protocol_version is None: | 
|  | 568 | + ssl_max_protocol_version = os.getenv('PGSSLMAXPROTOCOLVERSION') | 
|  | 569 | + if ssl_max_protocol_version: | 
|  | 570 | + ssl.maximum_version = _parse_tls_version( | 
|  | 571 | + ssl_max_protocol_version | 
|  | 572 | + ) | 
| 482 | 573 | 
 | 
| 483 | 574 |  elif ssl is True: | 
| 484 | 575 |  ssl = ssl_module.create_default_context() | 
|  | 
0 commit comments