Skip to content

Commit b89d277

Browse files
authored
Add Custom SQL Query Support with @query Decorator (#7)
1 parent b48e405 commit b89d277

File tree

6 files changed

+337
-76
lines changed

6 files changed

+337
-76
lines changed

README.md

Lines changed: 202 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Features
1010
- Automatic CRUD Repository: PySpringModel automatically generates a CRUD repository for each of your SQLModel entities, providing common database operations such as Create, Read, Update, and Delete.
1111
- Managed Sessions: PySpringModel provides a context manager for database sessions, automatically handling session commit and rollback to ensure data consistency.
1212
- Dynamic Query Generation: PySpringModel can dynamically generate and execute SQL queries based on method names in your repositories.
13+
- Custom SQL Queries: PySpringModel supports custom SQL queries using the `@Query` decorator for complex database operations.
1314
- RESTful API Integration: PySpringModel integrates with the PySpring framework to automatically generate basic table CRUD APIs for your SQLModel entities.
1415

1516
Installation
@@ -29,31 +30,51 @@ from sqlmodel import Field
2930
class User(PySpringModel, table=True):
3031
id: int = Field(default=None, primary_key=True)
3132
name: str = Field()
32-
email: str = Field()`
33+
email: str = Field()
34+
age: int = Field()
35+
status: str = Field()
3336
```
3437

35-
1. Define a repository for your model by subclassing `CrudRepository`:
38+
2. Define a repository for your model by subclassing `CrudRepository`:
3639

3740
```py
38-
from py_spring_model import CrudRepository
41+
from py_spring_model import CrudRepository, Query
42+
from typing import Optional, List
3943

4044
class UserRepository(CrudRepository[int, User]):
41-
# Implementation will be auto generated based on method name
42-
def find_by_name(self, name: str) -> User: ...
43-
def find_by_name_and_email(self, name: str, email: str) -> User: ...
45+
# Dynamic method-based queries (auto-implemented)
46+
def find_by_name(self, name: str) -> Optional[User]: ...
47+
def find_by_email(self, email: str) -> Optional[User]: ...
48+
def find_by_name_and_email(self, name: str, email: str) -> Optional[User]: ...
49+
def find_by_name_or_email(self, name: str, email: str) -> Optional[User]: ...
50+
def find_all_by_status(self, status: str) -> List[User]: ...
51+
def find_all_by_age_and_status(self, age: int, status: str) -> List[User]: ...
4452

53+
# Custom SQL queries using @Query decorator
54+
@Query("SELECT * FROM user WHERE age > {min_age}")
55+
def find_users_older_than(self, min_age: int) -> List[User]: ...
56+
57+
@Query("SELECT * FROM user WHERE email LIKE '%{domain}%'")
58+
def find_users_by_email_domain(self, domain: str) -> List[User]: ...
59+
60+
@Query("SELECT * FROM user WHERE age BETWEEN {min_age} AND {max_age}")
61+
def find_users_by_age_range(self, min_age: int, max_age: int) -> List[User]: ...
4562
```
4663

47-
1. Use your repository in your service or controller:
64+
3. Use your repository in your service or controller:
4865

4966
```py
5067
class UserService:
5168
user_repository: UserRepository
5269

53-
def get_user_by_name(self, name: str) -> User:
70+
def get_user_by_name(self, name: str) -> Optional[User]:
5471
return self.user_repository.find_by_name(name)
72+
73+
def get_active_users_older_than(self, min_age: int) -> List[User]:
74+
return self.user_repository.find_users_older_than(min_age)
5575
```
56-
1. Run your application with `PySpringApplication`:
76+
77+
4. Run your application with `PySpringApplication`:
5778

5879
```py
5980
from py_spring_core import PySpringApplication
@@ -63,4 +84,175 @@ PySpringApplication(
6384
"./app-config.json",
6485
entity_providers=[provide_py_spring_model()]
6586
).run()
66-
```
87+
```
88+
89+
Query Examples
90+
--------------
91+
92+
### Dynamic Method-Based Queries
93+
94+
PySpringModel automatically implements query methods based on their names. The method names follow a specific pattern:
95+
96+
#### Single Result Queries (returns Optional[Model])
97+
```py
98+
# Find by single field
99+
def find_by_name(self, name: str) -> Optional[User]: ...
100+
def get_by_email(self, email: str) -> Optional[User]: ...
101+
102+
# Find by multiple fields with AND condition
103+
def find_by_name_and_email(self, name: str, email: str) -> Optional[User]: ...
104+
def get_by_age_and_status(self, age: int, status: str) -> Optional[User]: ...
105+
106+
# Find by multiple fields with OR condition
107+
def find_by_name_or_email(self, name: str, email: str) -> Optional[User]: ...
108+
def get_by_status_or_age(self, status: str, age: int) -> Optional[User]: ...
109+
```
110+
111+
#### Multiple Result Queries (returns List[Model])
112+
```py
113+
# Find all by single field
114+
def find_all_by_status(self, status: str) -> List[User]: ...
115+
def get_all_by_age(self, age: int) -> List[User]: ...
116+
117+
# Find all by multiple fields with AND condition
118+
def find_all_by_age_and_status(self, age: int, status: str) -> List[User]: ...
119+
def get_all_by_name_and_email(self, name: str, email: str) -> List[User]: ...
120+
121+
# Find all by multiple fields with OR condition
122+
def find_all_by_status_or_age(self, status: str, age: int) -> List[User]: ...
123+
def get_all_by_name_or_email(self, name: str, email: str) -> List[User]: ...
124+
```
125+
126+
### Custom SQL Queries
127+
128+
For complex queries that can't be expressed through method names, use the `@Query` decorator:
129+
130+
#### Basic Custom Queries
131+
```py
132+
@Query("SELECT * FROM user WHERE age > {min_age}")
133+
def find_users_older_than(self, min_age: int) -> List[User]: ...
134+
135+
@Query("SELECT * FROM user WHERE age < {max_age}")
136+
def find_users_younger_than(self, max_age: int) -> List[User]: ...
137+
138+
@Query("SELECT * FROM user WHERE email LIKE '%{domain}%'")
139+
def find_users_by_email_domain(self, domain: str) -> List[User]: ...
140+
```
141+
142+
#### Complex Custom Queries
143+
```py
144+
@Query("SELECT * FROM user WHERE age BETWEEN {min_age} AND {max_age} AND status = {status}")
145+
def find_users_by_age_range_and_status(self, min_age: int, max_age: int, status: str) -> List[User]: ...
146+
147+
@Query("SELECT * FROM user WHERE name LIKE %{name_pattern}% OR email LIKE %{email_pattern}%")
148+
def search_users_by_name_or_email(self, name_pattern: str, email_pattern: str) -> List[User]: ...
149+
150+
@Query("SELECT * FROM user ORDER BY age DESC LIMIT {limit}")
151+
def find_oldest_users(self, limit: int) -> List[User]: ...
152+
```
153+
154+
#### Single Result Custom Queries
155+
```py
156+
@Query("SELECT * FROM user WHERE email = {email} LIMIT 1")
157+
def get_user_by_email(self, email: str) -> Optional[User]: ...
158+
159+
@Query("SELECT * FROM user WHERE name = {name} AND status = {status} LIMIT 1")
160+
def get_user_by_name_and_status(self, name: str, status: str) -> Optional[User]: ...
161+
```
162+
163+
### Built-in CRUD Operations
164+
165+
The `CrudRepository` provides these built-in methods:
166+
167+
```py
168+
# Read operations
169+
user_repository.find_by_id(1) # Find by primary key
170+
user_repository.find_all_by_ids([1, 2, 3]) # Find multiple by IDs
171+
user_repository.find_all() # Find all records
172+
173+
# Write operations
174+
user_repository.save(user) # Save single entity
175+
user_repository.save_all([user1, user2]) # Save multiple entities
176+
user_repository.upsert(user, {"email": "..."}) # Insert or update
177+
178+
# Delete operations
179+
user_repository.delete(user) # Delete single entity
180+
user_repository.delete_by_id(1) # Delete by ID
181+
user_repository.delete_all([user1, user2]) # Delete multiple entities
182+
user_repository.delete_all_by_ids([1, 2, 3]) # Delete multiple by IDs
183+
```
184+
185+
### Complete Example
186+
187+
Here's a complete example showing all query types:
188+
189+
```py
190+
from py_spring_model import PySpringModel, CrudRepository, Query
191+
from sqlmodel import Field
192+
from typing import Optional, List
193+
194+
# Model definition
195+
class User(PySpringModel, table=True):
196+
id: int = Field(default=None, primary_key=True)
197+
name: str = Field()
198+
email: str = Field()
199+
age: int = Field()
200+
status: str = Field()
201+
202+
# Repository with all query types
203+
class UserRepository(CrudRepository[int, User]):
204+
# Dynamic queries
205+
def find_by_name(self, name: str) -> Optional[User]: ...
206+
def find_by_email(self, email: str) -> Optional[User]: ...
207+
def find_by_name_and_status(self, name: str, status: str) -> Optional[User]: ...
208+
def find_all_by_status(self, status: str) -> List[User]: ...
209+
def find_all_by_age_and_status(self, age: int, status: str) -> List[User]: ...
210+
211+
# Custom SQL queries
212+
@Query("SELECT * FROM user WHERE age > {min_age}")
213+
def find_users_older_than(self, min_age: int) -> List[User]: ...
214+
215+
@Query("SELECT * FROM user WHERE email LIKE '%{domain}%'")
216+
def find_users_by_email_domain(self, domain: str) -> List[User]: ...
217+
218+
@Query("SELECT * FROM user WHERE age BETWEEN {min_age} AND {max_age}")
219+
def find_users_by_age_range(self, min_age: int, max_age: int) -> List[User]: ...
220+
221+
# Usage example
222+
class UserService:
223+
user_repository: UserRepository
224+
225+
def get_user_by_name(self, name: str) -> Optional[User]:
226+
return self.user_repository.find_by_name(name)
227+
228+
def get_active_users_older_than(self, min_age: int) -> List[User]:
229+
return self.user_repository.find_users_older_than(min_age)
230+
231+
def get_users_by_email_domain(self, domain: str) -> List[User]:
232+
return self.user_repository.find_users_by_email_domain(domain)
233+
234+
def get_users_in_age_range(self, min_age: int, max_age: int) -> List[User]:
235+
return self.user_repository.find_users_by_age_range(min_age, max_age)
236+
```
237+
238+
### Method Naming Conventions
239+
240+
The dynamic query generation follows these naming conventions:
241+
242+
- **Prefixes**: `find_by_`, `get_by_`, `find_all_by_`, `get_all_by_`
243+
- **Single field**: `find_by_name``WHERE name = ?`
244+
- **Multiple fields with AND**: `find_by_name_and_email``WHERE name = ? AND email = ?`
245+
- **Multiple fields with OR**: `find_by_name_or_email``WHERE name = ? OR email = ?`
246+
- **Return types**:
247+
- `find_by_*` and `get_by_*` return `Optional[Model]`
248+
- `find_all_by_*` and `get_all_by_*` return `List[Model]`
249+
250+
### Query Decorator Features
251+
252+
The `@Query` decorator supports:
253+
254+
- **Parameter substitution**: Use `{parameter_name}` in SQL
255+
- **Type safety**: Method parameters must match SQL parameters
256+
- **Return type inference**: Automatically handles `Optional[Model]` and `List[Model]`
257+
- **Error handling**: Validates required parameters and types
258+
- **SQL injection protection**: Parameters are properly escaped

py_spring_model/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22
from py_spring_model.py_spring_model_provider import provide_py_spring_model
33
from py_spring_model.repository.crud_repository import CrudRepository
44
from py_spring_model.repository.repository_base import RepositoryBase
5-
from py_spring_model.py_spring_model_rest.service.curd_repository_implementation_service.crud_repository_implementation_service import Query, SkipAutoImplmentation
5+
from py_spring_model.py_spring_model_rest.service.curd_repository_implementation_service.crud_repository_implementation_service import SkipAutoImplmentation
6+
from py_spring_model.py_spring_model_rest.service.query_service.query import Query

py_spring_model/py_spring_model_provider.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,8 @@ def provide_py_spring_model() -> EntityProvider:
151151
# PySpringModelRestController
152152
],
153153
component_classes=[
154-
# PySpringModelRestService,
155-
# CrudRepositoryImplementationService,
154+
PySpringModelRestService,
155+
CrudRepositoryImplementationService,
156156
],
157157
properties_classes=[
158158
PySpringModelProperties

py_spring_model/py_spring_model_rest/service/curd_repository_implementation_service/crud_repository_implementation_service.py

Lines changed: 19 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,19 @@
11
import copy
2-
import functools
3-
from collections.abc import Iterable
42
from typing import (
53
Any,
64
Callable,
75
ClassVar,
86
Type,
97
TypeVar,
108
Union,
11-
cast,
12-
get_args,
13-
get_origin,
149
ParamSpec
1510
)
1611

1712
from loguru import logger
1813
from py_spring_core import Component
1914
from pydantic import BaseModel
20-
from sqlalchemy import ColumnElement, text
15+
from enum import Enum
16+
from sqlalchemy import ColumnElement
2117
from sqlalchemy.sql import and_, or_
2218
from sqlmodel import select
2319
from sqlmodel.sql.expression import SelectOfScalar
@@ -31,6 +27,11 @@
3127

3228
PySpringModelT = TypeVar("PySpringModelT", bound=PySpringModel)
3329

30+
31+
class ConditionNotation(str,Enum):
32+
AND = "_and_"
33+
OR = "_or_"
34+
3435
class CrudRepositoryImplementationService(Component):
3536
"""
3637
The `CrudRepositoryImplementationService` class is responsible for implementing the query logic for the `CrudRepository` inheritors.
@@ -139,9 +140,9 @@ def _get_sql_statement(
139140
right_condition = filter_condition_stack.pop(0)
140141
left_condition = filter_condition_stack.pop(0)
141142
match notation:
142-
case "_and_":
143+
case ConditionNotation.AND:
143144
filter_condition_stack.append(and_(left_condition, right_condition))
144-
case "_or_":
145+
case ConditionNotation.OR:
145146
filter_condition_stack.append(or_(left_condition, right_condition))
146147

147148
query = select(model_type)
@@ -167,62 +168,19 @@ def post_construct(self) -> None:
167168
T = TypeVar("T", bound=BaseModel)
168169
RT = TypeVar("RT", bound=Union[T, None, list[T]]) # type: ignore
169170

170-
def Query(query_template: str) -> Callable[[Callable[P, RT]], Callable[P, RT]]:
171-
def decorator(func: Callable[P, RT]) -> Callable[P, RT]:
172-
func_full_name = func.__qualname__
173-
CrudRepositoryImplementationService.add_skip_function(func_full_name)
174-
@functools.wraps(func)
175-
def wrapper(*args: P.args, **kwargs: P.kwargs) -> RT:
176-
nonlocal query_template
177-
RETURN = "return"
178-
179-
annotations = func.__annotations__
180-
if RETURN not in annotations:
181-
raise ValueError(f"Missing return annotation for function: {func.__name__}")
182-
183-
return_type = annotations[RETURN]
184-
for key, value_type in annotations.items():
185-
if key == RETURN:
186-
continue
187-
if key not in kwargs or kwargs[key] is None:
188-
raise ValueError(f"Missing required argument: {key}")
189-
if value_type != type(kwargs[key]):
190-
raise TypeError(f"Invalid type for argument {key}. Expected {value_type}, got {type(kwargs[key])}")
191-
192-
sql = query_template.format(**kwargs)
193-
with PySpringModel.create_session() as session: # Replace with your actual session mechanism
194-
reutrn_origin = get_origin(return_type)
195-
return_args = get_args(return_type)
196-
197-
198-
# Handle None or list[T]
199-
if type(None) in return_args:
200-
actual_type = [arg for arg in return_args if arg is not type(None)].pop()
201-
else:
202-
if len(return_args) != 0:
203-
actual_type = return_args[0]
204-
else:
205-
actual_type = return_type
206-
207-
if reutrn_origin in {list, Iterable} and return_args:
208-
if not issubclass(actual_type, BaseModel):
209-
raise ValueError(f"Invalid return type: {return_type}, expected Iterable[BaseModel]")
210-
211-
result = session.execute(text(sql)).fetchall()
212-
return cast(RT, [actual_type.model_validate(row._asdict()) for row in result])
213-
214-
elif issubclass(actual_type, BaseModel):
215-
result = session.execute(text(sql)).first()
216-
if result is None:
217-
return cast(RT, None)
218-
return cast(RT, actual_type.model_validate(result._asdict()))
219-
else:
220-
raise ValueError(f"Invalid return type: {actual_type}")
221-
return wrapper
222-
return decorator
223171

224172

225173
def SkipAutoImplmentation(func: Callable[P, RT]) -> Callable[P, RT]:
174+
"""
175+
Decorator to skip the auto implementation for a method.
176+
The method will not be implemented automatically by the `CrudRepositoryImplementationService`.
177+
The method should have the following signature:
178+
```python
179+
@SkipAutoImplmentation
180+
def get_user_by_email(self, email: str) -> Optional[UserRead]:
181+
...
182+
```
183+
"""
226184
func_name = func.__qualname__
227185
logger.info(f"Skipping auto implementation for function: {func_name}")
228186
CrudRepositoryImplementationService.add_skip_function(func_name)

0 commit comments

Comments
 (0)