Skip to content

Commit f4cef37

Browse files
authored
PYTHON-3064 Add typings to test package (mongodb#844)
1 parent 561ee7c commit f4cef37

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+542
-261
lines changed

.github/workflows/test-python.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,5 @@ jobs:
4646
pip install -e ".[zstd, srv]"
4747
- name: Run mypy
4848
run: |
49-
mypy --install-types --non-interactive bson gridfs tools
49+
mypy --install-types --non-interactive bson gridfs tools pymongo
50+
mypy --install-types --non-interactive --disable-error-code var-annotated --disable-error-code attr-defined --disable-error-code union-attr --disable-error-code assignment --disable-error-code no-redef --disable-error-code index test

bson/son.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
# This is essentially the same as re._pattern_type
2929
RE_TYPE: Type[Pattern[Any]] = type(re.compile(""))
3030

31-
_Key = TypeVar("_Key", bound=str)
31+
_Key = TypeVar("_Key")
3232
_Value = TypeVar("_Value")
3333
_T = TypeVar("_T")
3434

mypy.ini

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ warn_unused_configs = true
1111
warn_unused_ignores = true
1212
warn_redundant_casts = true
1313

14+
[mypy-gevent.*]
15+
ignore_missing_imports = True
16+
1417
[mypy-kerberos.*]
1518
ignore_missing_imports = True
1619

@@ -29,5 +32,12 @@ ignore_missing_imports = True
2932
[mypy-snappy.*]
3033
ignore_missing_imports = True
3134

35+
[mypy-test.*]
36+
allow_redefinition = true
37+
allow_untyped_globals = true
38+
3239
[mypy-winkerberos.*]
3340
ignore_missing_imports = True
41+
42+
[mypy-xmlrunner.*]
43+
ignore_missing_imports = True

pymongo/socket_checker.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,8 @@
1616

1717
import errno
1818
import select
19-
import socket
2019
import sys
21-
from typing import Any, Optional
20+
from typing import Any, Optional, Union
2221

2322
# PYTHON-2320: Jython does not fully support poll on SSL sockets,
2423
# https://bugs.jython.org/issue2900
@@ -43,7 +42,7 @@ def __init__(self) -> None:
4342
else:
4443
self._poller = None
4544

46-
def select(self, sock: Any, read: bool = False, write: bool = False, timeout: int = 0) -> bool:
45+
def select(self, sock: Any, read: bool = False, write: bool = False, timeout: Optional[float] = 0) -> bool:
4746
"""Select for reads or writes with a timeout in seconds (or None).
4847
4948
Returns True if the socket is readable/writable, False on timeout.

pymongo/srv_resolver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def maybe_decode(text):
3939
def _resolve(*args, **kwargs):
4040
if hasattr(resolver, 'resolve'):
4141
# dnspython >= 2
42-
return resolver.resolve(*args, **kwargs) # type: ignore
42+
return resolver.resolve(*args, **kwargs)
4343
# dnspython 1.X
4444
return resolver.query(*args, **kwargs)
4545

pymongo/typings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
"""Type aliases used by PyMongo"""
1616
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, MutableMapping, Optional,
17-
Tuple, Type, TypeVar, Union)
17+
Sequence, Tuple, Type, TypeVar, Union)
1818

1919
if TYPE_CHECKING:
2020
from bson.raw_bson import RawBSONDocument
@@ -25,5 +25,5 @@
2525
_Address = Tuple[str, Optional[int]]
2626
_CollationIn = Union[Mapping[str, Any], "Collation"]
2727
_DocumentIn = Union[MutableMapping[str, Any], "RawBSONDocument"]
28-
_Pipeline = List[Mapping[str, Any]]
28+
_Pipeline = Sequence[Mapping[str, Any]]
2929
_DocumentType = TypeVar('_DocumentType', Mapping[str, Any], MutableMapping[str, Any], Dict[str, Any])

test/__init__.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040

4141
from contextlib import contextmanager
4242
from functools import wraps
43+
from typing import Dict, no_type_check
4344
from unittest import SkipTest
4445

4546
import pymongo
@@ -48,7 +49,9 @@
4849
from bson.son import SON
4950
from pymongo import common, message
5051
from pymongo.common import partition_node
52+
from pymongo.database import Database
5153
from pymongo.hello import HelloCompat
54+
from pymongo.mongo_client import MongoClient
5255
from pymongo.server_api import ServerApi
5356
from pymongo.ssl_support import HAVE_SSL, _ssl
5457
from pymongo.uri_parser import parse_uri
@@ -86,7 +89,7 @@
8689
os.path.join(CERT_PATH, 'client.pem'))
8790
CA_PEM = os.environ.get('CA_PEM', os.path.join(CERT_PATH, 'ca.pem'))
8891

89-
TLS_OPTIONS = dict(tls=True)
92+
TLS_OPTIONS: Dict = dict(tls=True)
9093
if CLIENT_PEM:
9194
TLS_OPTIONS['tlsCertificateKeyFile'] = CLIENT_PEM
9295
if CA_PEM:
@@ -102,13 +105,13 @@
102105
# Remove after PYTHON-2712
103106
from pymongo import pool
104107
pool._MOCK_SERVICE_ID = True
105-
res = parse_uri(SINGLE_MONGOS_LB_URI)
108+
res = parse_uri(SINGLE_MONGOS_LB_URI or "")
106109
host, port = res['nodelist'][0]
107110
db_user = res['username'] or db_user
108111
db_pwd = res['password'] or db_pwd
109112
elif TEST_SERVERLESS:
110113
TEST_LOADBALANCER = True
111-
res = parse_uri(SINGLE_MONGOS_LB_URI)
114+
res = parse_uri(SINGLE_MONGOS_LB_URI or "")
112115
host, port = res['nodelist'][0]
113116
db_user = res['username'] or db_user
114117
db_pwd = res['password'] or db_pwd
@@ -184,6 +187,7 @@ def enable(self):
184187
def __enter__(self):
185188
self.enable()
186189

190+
@no_type_check
187191
def disable(self):
188192
common.HEARTBEAT_FREQUENCY = self.old_heartbeat_frequency
189193
common.MIN_HEARTBEAT_INTERVAL = self.old_min_heartbeat_interval
@@ -224,6 +228,8 @@ def _all_users(db):
224228

225229

226230
class ClientContext(object):
231+
client: MongoClient
232+
227233
MULTI_MONGOS_LB_URI = MULTI_MONGOS_LB_URI
228234

229235
def __init__(self):
@@ -247,9 +253,9 @@ def __init__(self):
247253
self.tls = False
248254
self.tlsCertificateKeyFile = False
249255
self.server_is_resolvable = is_server_resolvable()
250-
self.default_client_options = {}
256+
self.default_client_options: Dict = {}
251257
self.sessions_enabled = False
252-
self.client = None
258+
self.client = None # type: ignore
253259
self.conn_lock = threading.Lock()
254260
self.is_data_lake = False
255261
self.load_balancer = TEST_LOADBALANCER
@@ -340,6 +346,7 @@ def _init_client(self):
340346
try:
341347
self.cmd_line = self.client.admin.command('getCmdLineOpts')
342348
except pymongo.errors.OperationFailure as e:
349+
assert e.details is not None
343350
msg = e.details.get('errmsg', '')
344351
if e.code == 13 or 'unauthorized' in msg or 'login' in msg:
345352
# Unauthorized.
@@ -418,6 +425,7 @@ def _init_client(self):
418425
else:
419426
self.server_parameters = self.client.admin.command(
420427
'getParameter', '*')
428+
assert self.cmd_line is not None
421429
if 'enableTestCommands=1' in self.cmd_line['argv']:
422430
self.test_commands_enabled = True
423431
elif 'parsed' in self.cmd_line:
@@ -436,7 +444,8 @@ def _init_client(self):
436444
self.mongoses.append(address)
437445
if not self.serverless:
438446
# Check for another mongos on the next port.
439-
next_address = address[0], address[1] + 1
447+
assert address is not None
448+
next_address = address[0], address[1] + 1
440449
mongos_client = self._connect(
441450
*next_address, **self.default_client_options)
442451
if mongos_client:
@@ -496,6 +505,7 @@ def _check_user_provided(self):
496505
try:
497506
return db_user in _all_users(client.admin)
498507
except pymongo.errors.OperationFailure as e:
508+
assert e.details is not None
499509
msg = e.details.get('errmsg', '')
500510
if e.code == 18 or 'auth fails' in msg:
501511
# Auth failed.
@@ -505,6 +515,7 @@ def _check_user_provided(self):
505515

506516
def _server_started_with_auth(self):
507517
# MongoDB >= 2.0
518+
assert self.cmd_line is not None
508519
if 'parsed' in self.cmd_line:
509520
parsed = self.cmd_line['parsed']
510521
# MongoDB >= 2.6
@@ -525,6 +536,7 @@ def _server_started_with_ipv6(self):
525536
if not socket.has_ipv6:
526537
return False
527538

539+
assert self.cmd_line is not None
528540
if 'parsed' in self.cmd_line:
529541
if not self.cmd_line['parsed'].get('net', {}).get('ipv6'):
530542
return False
@@ -932,6 +944,9 @@ def fail_point(self, command_args):
932944

933945
class IntegrationTest(PyMongoTestCase):
934946
"""Base class for TestCases that need a connection to MongoDB to pass."""
947+
client: MongoClient
948+
db: Database
949+
credentials: Dict[str, str]
935950

936951
@classmethod
937952
@client_context.require_connection
@@ -1073,7 +1088,7 @@ def run(self, test):
10731088

10741089

10751090
if HAVE_XML:
1076-
class PymongoXMLTestRunner(XMLTestRunner):
1091+
class PymongoXMLTestRunner(XMLTestRunner): # type: ignore[misc]
10771092
def run(self, test):
10781093
setup()
10791094
result = super(PymongoXMLTestRunner, self).run(test)

test/auth_aws/test_auth_aws.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727

2828
class TestAuthAWS(unittest.TestCase):
29+
uri: str
2930

3031
@classmethod
3132
def setUpClass(cls):

test/mockupdb/test_cursor_namespace.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121

2222

2323
class TestCursorNamespace(unittest.TestCase):
24+
server: MockupDB
25+
client: MongoClient
26+
2427
@classmethod
2528
def setUpClass(cls):
2629
cls.server = MockupDB(auto_ismaster={'maxWireVersion': 6})
@@ -69,6 +72,9 @@ def op():
6972

7073

7174
class TestKillCursorsNamespace(unittest.TestCase):
75+
server: MockupDB
76+
client: MongoClient
77+
7278
@classmethod
7379
def setUpClass(cls):
7480
cls.server = MockupDB(auto_ismaster={'maxWireVersion': 6})

test/mockupdb/test_getmore_sharded.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def test_getmore_sharded(self):
2727
servers = [MockupDB(), MockupDB()]
2828

2929
# Collect queries to either server in one queue.
30-
q = Queue()
30+
q: Queue = Queue()
3131
for server in servers:
3232
server.subscribe(q.put)
3333
server.autoresponds('ismaster', ismaster=True, msg='isdbgrid',

0 commit comments

Comments
 (0)