Skip to content

Commit d06f645

Browse files
committed
sql30: Allow range based queries.
Change here allows constraints to be based on ranges when performing various queries. Change is central so that all CRUD operations and miscellaneous operations are being able to consume it.
1 parent 58895a3 commit d06f645

File tree

2 files changed

+55
-6
lines changed

2 files changed

+55
-6
lines changed

sql30/db.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,14 @@ def _get_fields(self, tbl_name):
219219
return [key for key, _ in _schema['fields'].items()]
220220

221221
def _form_constraints(self, _separator='and', kwargs=None):
222-
constraints = ['%s=:%s' % (key, key) for key, _ in kwargs.items()]
222+
def is_range(x): return isinstance(x, tuple) or isinstance(x, list)
223+
constraints = []
224+
for key, val in kwargs.items():
225+
if is_range(val):
226+
cparam = '%s BETWEEN %s AND %s' % (key, val[0], val[1])
227+
else:
228+
cparam = '%s=:%s' % (key, key)
229+
constraints.append(cparam)
223230
return (' %s ' % _separator).join(constraints)
224231

225232
def _validate_bfr_write(self, tbl, kwargs):

sql30/tests/test_misc.py

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,53 +37,95 @@ def setUp(self):
3737
self.db = Square()
3838
self.db.table = self.TABLE
3939

40-
# add 3 records
40+
# Populate 3 records for validating queries later.
4141
self.db.write(num=1, square=1)
4242
self.db.write(num=2, square=4)
4343
self.db.write(num=3, square=9)
4444
self.db.commit()
4545

4646
def test_count(self):
4747
"""
48-
Tests for context manager operations.
48+
Tests for COUNT operations.
4949
"""
5050
with Square() as db:
5151
db.table = self.TABLE
52+
# case 1: Total numnber of records should be 3.
5253
self.assertEqual(db.count(), 3)
5354

55+
# Number of records where num is 2 should be 1.
5456
self.assertEqual(db.count(num=2), 1)
57+
58+
# Number of records where square is 9 should be 1.
5559
self.assertEqual(db.count(square=9), 1)
5660

61+
# Number of records where 4 <= square <= 9 should be 1.
62+
self.assertEqual(db.count(square=[4, 9]), 2)
63+
5764
def test_min(self):
5865
"""
59-
Tests for context manager operations.
66+
Tests for MIN operations.
6067
"""
6168
with Square() as db:
6269
db.table = self.TABLE
70+
# Minimum value for num is 1
6371
self.assertEqual(db.min('num'), 1)
72+
# Minimum value for square is 1
6473
self.assertEqual(db.min('square'), 1)
6574

75+
# Minimum value for square when num = 2, is 4
6676
self.assertEqual(db.min('square', num=2), 4)
77+
# Minimum value for square when 3 <= num <= 5, is 9
78+
self.assertEqual(db.min('square', num=[3, 5]), 9)
6779

6880
def test_max(self):
6981
"""
70-
Tests for context manager operations.
82+
Tests for MAX operations.
7183
"""
7284
with Square() as db:
7385
db.table = self.TABLE
86+
# Maximum value for num is 3
7487
self.assertEqual(db.max('num'), 3)
88+
# Maximum value for square is 9
7589
self.assertEqual(db.max('square'), 9)
7690

91+
# Maximum value for square when num = 2, is 4
7792
self.assertEqual(db.max('square', num=2), 4)
7893

94+
# Minimum value for square when 0 <= num <= 5, is 9
95+
self.assertEqual(db.max('square', num=[0, 5]), 9)
96+
7997
def test_avg(self):
8098
"""
81-
Tests for context manager operations.
99+
Tests for AVERAGE operations.
82100
"""
83101
with Square() as db:
84102
db.table = self.TABLE
103+
# Average of all num values is 2.
85104
self.assertEqual(db.avg('num'), 2)
86105
self.assertEqual(db.avg('square', num=3), 9)
87106

107+
# Average value of square when 1 <= num <= 2 should be 2.5
108+
self.assertEqual(db.avg('square', num=(1, 2)), 2.5)
109+
110+
def test_range(self):
111+
"""
112+
Tests for RANGE operations.
113+
"""
114+
with Square() as db:
115+
db.table = self.TABLE
116+
# Read records where : 0 <= num <= 3. There should be 3 records.
117+
self.assertEqual(len(db.read(num=(0, 3))), 3)
118+
119+
# Read records where : 1 <= num <= 2. There should be 2 records.
120+
self.assertEqual(len(db.read(num=(1, 2))), 2)
121+
122+
# Read records where : 2 < quare < 5. There should be 1 record.
123+
# Also check that list based range is expected
124+
self.assertEqual(len(db.read(square=[2, 5])), 1)
125+
126+
# Check that range is inclusive i.e. no records for :
127+
# 2 <= square <= 3.
128+
self.assertEqual(len(db.read(square=[2, 3])), 0)
129+
88130
def tearDown(self):
89131
os.remove(DB_NAME)

0 commit comments

Comments
 (0)