Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import shutil
import subprocess
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, List, Union
from typing import Any, Callable, Dict, Generator, List, Union
from unittest.mock import patch

import pytest
from pydantic import BaseModel
Expand All @@ -26,7 +28,7 @@ def clear_sqlmodel() -> Any:


@pytest.fixture()
def cov_tmp_path(tmp_path: Path):
def cov_tmp_path(tmp_path: Path) -> Generator[Path, None, None]:
yield tmp_path
for coverage_path in tmp_path.glob(".coverage*"):
coverage_destiny_path = top_level_path / coverage_path.name
Expand All @@ -53,8 +55,8 @@ def coverage_run(*, module: str, cwd: Union[str, Path]) -> subprocess.CompletedP
def get_testing_print_function(
calls: List[List[Union[str, Dict[str, Any]]]],
) -> Callable[..., Any]:
def new_print(*args):
data = []
def new_print(*args: Any) -> None:
data: List[Any] = []
for arg in args:
if isinstance(arg, BaseModel):
data.append(arg.model_dump())
Expand All @@ -71,6 +73,19 @@ def new_print(*args):
return new_print


@dataclass
class PrintMock:
calls: List[Any] = field(default_factory=list)


@pytest.fixture(name="print_mock")
def print_mock_fixture() -> Generator[PrintMock, None, None]:
print_mock = PrintMock()
new_print = get_testing_print_function(print_mock.calls)
with patch("builtins.print", new=new_print):
yield print_mock


needs_pydanticv2 = pytest.mark.skipif(not IS_PYDANTIC_V2, reason="requires Pydantic v2")
needs_pydanticv1 = pytest.mark.skipif(IS_PYDANTIC_V2, reason="requires Pydantic v1")

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import importlib
from types import ModuleType
from typing import Any, Dict, List, Union
from unittest.mock import patch

import pytest
from sqlmodel import create_engine

from tests.conftest import get_testing_print_function
from tests.conftest import PrintMock, needs_py310


def check_calls(calls: List[List[Union[str, Dict[str, Any]]]]):
def check_calls(calls: List[List[Union[str, Dict[str, Any]]]]) -> None:
assert calls[0] == ["Before interacting with the database"]
assert calls[1] == [
"Hero 1:",
Expand Down Expand Up @@ -133,29 +135,25 @@ def check_calls(calls: List[List[Union[str, Dict[str, Any]]]]):
]


def test_tutorial_001():
from docs_src.tutorial.automatic_id_none_refresh import tutorial001 as mod
@pytest.fixture(
name="module",
params=[
"tutorial001",
"tutorial002",
pytest.param("tutorial001_py310", marks=needs_py310),
pytest.param("tutorial002_py310", marks=needs_py310),
],
)
def get_module(request: pytest.FixtureRequest) -> ModuleType:
module = importlib.import_module(
f"docs_src.tutorial.automatic_id_none_refresh.{request.param}"
)
module.sqlite_url = "sqlite://"
module.engine = create_engine(module.sqlite_url)

mod.sqlite_url = "sqlite://"
mod.engine = create_engine(mod.sqlite_url)
calls = []
return module

new_print = get_testing_print_function(calls)

with patch("builtins.print", new=new_print):
mod.main()
check_calls(calls)


def test_tutorial_002():
from docs_src.tutorial.automatic_id_none_refresh import tutorial002 as mod

mod.sqlite_url = "sqlite://"
mod.engine = create_engine(mod.sqlite_url)
calls = []

new_print = get_testing_print_function(calls)

with patch("builtins.print", new=new_print):
mod.main()
check_calls(calls)
def test_tutorial_001_tutorial_002(print_mock: PrintMock, module: ModuleType) -> None:
module.main()
check_calls(print_mock.calls)
Loading