Skip to content

Commit cbc7cc3

Browse files
authored
PYTHON-3073 Copy the unit tests from pymongo-stubs into pymongo (mongodb#859)
1 parent ddb6614 commit cbc7cc3

File tree

5 files changed

+149
-1
lines changed

5 files changed

+149
-1
lines changed

.github/workflows/test-python.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ jobs:
3737
mongodb-version: 4.4
3838
- name: Run tests
3939
run: |
40+
pip install mypy
4041
python setup.py test
4142
4243
mypytest:
@@ -59,4 +60,4 @@ jobs:
5960
- name: Run mypy
6061
run: |
6162
mypy --install-types --non-interactive bson gridfs tools pymongo
62-
mypy --install-types --non-interactive --disable-error-code var-annotated --disable-error-code attr-defined --disable-error-code union-attr --disable-error-code assignment --disable-error-code no-redef --disable-error-code index test
63+
mypy --install-types --non-interactive --disable-error-code var-annotated --disable-error-code attr-defined --disable-error-code union-attr --disable-error-code assignment --disable-error-code no-redef --disable-error-code index --exclude "test/mypy_fails/*.*" test
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from pymongo import MongoClient
2+
3+
client = MongoClient()
4+
client.test.test.insert_many(
5+
{"a": 1}
6+
) # error: Dict entry 0 has incompatible type "str": "int"; expected "Mapping[str, Any]": "int"

test/mypy_fails/insert_one_list.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from pymongo import MongoClient
2+
3+
client = MongoClient()
4+
client.test.test.insert_one(
5+
[{}]
6+
) # error: Argument 1 to "insert_one" of "Collection" has incompatible type "List[Dict[<nothing>, <nothing>]]"; expected "Mapping[str, Any]"

test/test_bson.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,6 +1117,16 @@ def test_int64_pickling(self):
11171117
)
11181118
self.round_trip_pickle(i64, pickled_with_3)
11191119

1120+
def test_bson_encode_decode(self) -> None:
1121+
doc = {"_id": ObjectId()}
1122+
encoded = bson.encode(doc)
1123+
decoded = bson.decode(encoded)
1124+
encoded = bson.encode(decoded)
1125+
decoded = bson.decode(encoded)
1126+
# Documents returned from decode are mutable.
1127+
decoded["new_field"] = 1
1128+
self.assertTrue(decoded["_id"].generation_time)
1129+
11201130

11211131
if __name__ == "__main__":
11221132
unittest.main()

test/test_mypy.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# Copyright 2020-present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Test that each file in mypy_fails/ actually fails mypy, and test some
16+
sample client code that uses PyMongo typings."""
17+
18+
import os
19+
import sys
20+
import unittest
21+
from typing import Any, Dict, Iterable, List
22+
23+
try:
24+
from mypy import api
25+
except ImportError:
26+
api = None
27+
28+
from bson.son import SON
29+
from pymongo.collection import Collection
30+
from pymongo.errors import ServerSelectionTimeoutError
31+
from pymongo.mongo_client import MongoClient
32+
from pymongo.operations import InsertOne
33+
34+
TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mypy_fails")
35+
36+
37+
def get_tests() -> Iterable[str]:
38+
for dirpath, _, filenames in os.walk(TEST_PATH):
39+
for filename in filenames:
40+
yield os.path.join(dirpath, filename)
41+
42+
43+
class TestMypyFails(unittest.TestCase):
44+
def ensure_mypy_fails(self, filename: str) -> None:
45+
if api is None:
46+
raise unittest.SkipTest("Mypy is not installed")
47+
stdout, stderr, exit_status = api.run([filename])
48+
self.assertTrue(exit_status, msg=stdout)
49+
50+
def test_mypy_failures(self) -> None:
51+
for filename in get_tests():
52+
with self.subTest(filename=filename):
53+
self.ensure_mypy_fails(filename)
54+
55+
56+
class TestPymongo(unittest.TestCase):
57+
client: MongoClient
58+
coll: Collection
59+
60+
@classmethod
61+
def setUpClass(cls) -> None:
62+
cls.client = MongoClient(serverSelectionTimeoutMS=250, directConnection=False)
63+
cls.coll = cls.client.test.test
64+
try:
65+
cls.client.admin.command("ping")
66+
except ServerSelectionTimeoutError as exc:
67+
raise unittest.SkipTest(f"Could not connect to MongoDB: {exc}")
68+
69+
@classmethod
70+
def tearDownClass(cls) -> None:
71+
cls.client.close()
72+
73+
def test_insert_find(self) -> None:
74+
doc = {"my": "doc"}
75+
coll2 = self.client.test.test2
76+
result = self.coll.insert_one(doc)
77+
self.assertEqual(result.inserted_id, doc["_id"])
78+
retreived = self.coll.find_one({"_id": doc["_id"]})
79+
if retreived:
80+
# Documents returned from find are mutable.
81+
retreived["new_field"] = 1
82+
result2 = coll2.insert_one(retreived)
83+
self.assertEqual(result2.inserted_id, result.inserted_id)
84+
85+
def test_cursor_iterable(self) -> None:
86+
def to_list(iterable: Iterable[Dict[str, Any]]) -> List[Dict[str, Any]]:
87+
return list(iterable)
88+
89+
self.coll.insert_one({})
90+
cursor = self.coll.find()
91+
docs = to_list(cursor)
92+
self.assertTrue(docs)
93+
94+
def test_bulk_write(self) -> None:
95+
self.coll.insert_one({})
96+
requests = [InsertOne({})]
97+
result = self.coll.bulk_write(requests)
98+
self.assertTrue(result.acknowledged)
99+
100+
def test_aggregate_pipeline(self) -> None:
101+
coll3 = self.client.test.test3
102+
coll3.insert_many(
103+
[
104+
{"x": 1, "tags": ["dog", "cat"]},
105+
{"x": 2, "tags": ["cat"]},
106+
{"x": 2, "tags": ["mouse", "cat", "dog"]},
107+
{"x": 3, "tags": []},
108+
]
109+
)
110+
111+
class mydict(Dict[str, Any]):
112+
pass
113+
114+
result = coll3.aggregate(
115+
[
116+
mydict({"$unwind": "$tags"}),
117+
{"$group": {"_id": "$tags", "count": {"$sum": 1}}},
118+
{"$sort": SON([("count", -1), ("_id", -1)])},
119+
]
120+
)
121+
self.assertTrue(len(list(result)))
122+
123+
124+
if __name__ == "__main__":
125+
unittest.main()

0 commit comments

Comments
 (0)