Skip to content
8 changes: 7 additions & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,13 @@ jobs:
name: run tests
command: |
. venv/bin/activate
REDIS_PORT=6379 python test/test.py
REDIS_PORT=6379 python test/test.py

- run:
name: run query builder tests
command: |
. venv/bin/activate
python test/test.py

# no need for store_artifacts on nightly builds

Expand Down
105 changes: 76 additions & 29 deletions redisearch/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,58 @@ def __init__(self, fields, reducers):
self.limit = Limit()

def build_args(self):
ret = [str(len(self.fields))]
ret = ['GROUPBY', str(len(self.fields))]
ret.extend(self.fields)
for reducer in self.reducers:
ret += ['REDUCE', reducer.NAME, str(len(reducer.args))]
ret.extend(reducer.args)
if reducer._alias:
if reducer._alias is not None:
ret += ['AS', reducer._alias]
return ret

class Projection(object):
"""
This object automatically created in the `AggregateRequest.apply()`
"""

def __init__(self, projector, alias=None ):

self.alias = alias
self.projector = projector

def build_args(self):
ret = ['APPLY', self.projector]
if self.alias is not None:
ret += ['AS', self.alias]

return ret

class SortBy(object):
"""
This object automatically created in the `AggregateRequest.sort_by()`
"""

def __init__(self, fields, max=0):
self.fields = fields
self.max = max



def build_args(self):
fields_args = []
for f in self.fields:
if isinstance(f, SortDirection):
fields_args += [f.field, f.DIRSTRING]
else:
fields_args += [f]

ret = ['SORTBY', str(len(fields_args))]
ret.extend(fields_args)
if self.max > 0:
ret += ['MAX', str(self.max)]

return ret


class AggregateRequest(object):
"""
Expand All @@ -127,11 +170,9 @@ def __init__(self, query='*'):
return the object itself, making them useful for chaining.
"""
self._query = query
self._groups = []
self._projections = []
self._aggregateplan = []
self._loadfields = []
self._limit = Limit()
self._sortby = []
self._max = 0
self._with_schema = False
self._verbatim = False
Expand Down Expand Up @@ -162,7 +203,7 @@ def group_by(self, fields, *reducers):
`aggregation` module.
"""
group = Group(fields, reducers)
self._groups.append(group)
self._aggregateplan.extend(group.build_args())

return self

Expand All @@ -177,7 +218,8 @@ def apply(self, **kwexpr):
expression itself, for example `apply(square_root="sqrt(@foo)")`
"""
for alias, expr in kwexpr.items():
self._projections.append([alias, expr])
projection = Projection(expr, alias )
self._aggregateplan.extend(projection.build_args())

return self

Expand Down Expand Up @@ -224,10 +266,7 @@ def limit(self, offset, num):

"""
limit = Limit(offset, num)
if self._groups:
self._groups[-1].limit = limit
else:
self._limit = limit
self._limit = limit
return self

def sort_by(self, *fields, **kwargs):
Expand Down Expand Up @@ -258,16 +297,34 @@ def sort_by(self, *fields, **kwargs):
.sort_by(Desc('@paid'), max=10)
```
"""
self._max = kwargs.get('max', 0)
if isinstance(fields, (string_types, SortDirection)):
fields = [fields]
for f in fields:
if isinstance(f, SortDirection):
self._sortby += [f.field, f.DIRSTRING]
else:
self._sortby.append(f)

max = kwargs.get('max', 0)
sortby = SortBy(fields, max)

self._aggregateplan.extend(sortby.build_args())
return self

def filter(self, expressions):
"""
Specify filter for post-query results using predicates relating to values in the result set.

### Parameters

- **fields**: Fields to group by. This can either be a single string,
or a list of strings.
"""
if isinstance(expressions, (string_types)):
expressions = [expressions]

for expression in expressions:
self._aggregateplan.extend(['FILTER', expression])

return self



def with_schema(self):
"""
If set, the `schema` property will contain a list of `[field, type]`
Expand Down Expand Up @@ -312,18 +369,8 @@ def build_args(self):
ret.append('LOAD')
ret.append(str(len(self._loadfields)))
ret.extend(self._loadfields)
for group in self._groups:
ret += ['GROUPBY'] + group.build_args() + group.limit.build_args()
for alias, projector in self._projections:
ret += ['APPLY', projector]
if alias:
ret += ['AS', alias]

if self._sortby:
ret += ['SORTBY', str(len(self._sortby))]
ret += self._sortby
if self._max:
ret += ['MAX', str(self._max)]

ret.extend(self._aggregateplan)

ret += self._limit.build_args()

Expand Down
45 changes: 37 additions & 8 deletions test/test_builder.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from unittest import TestCase
import unittest
import redisearch.aggregation as a
import redisearch.querystring as q
import redisearch.reducers as r

class QueryBuilderTest(TestCase):
class QueryBuilderTest(unittest.TestCase):
def testBetween(self):
b = q.between(1, 10)
self.assertEqual('[1 10]', str(b))
Expand Down Expand Up @@ -42,16 +42,16 @@ def testGroup(self):
# Single field, single reducer
g = a.Group('foo', r.count())
ret = g.build_args()
self.assertEqual(['1', 'foo', 'REDUCE', 'COUNT', '0'], ret)
self.assertEqual(['GROUPBY', '1', 'foo', 'REDUCE', 'COUNT', '0'], ret)

# Multiple fields, single reducer
g = a.Group(['foo', 'bar'], r.count())
self.assertEqual(['2', 'foo', 'bar', 'REDUCE', 'COUNT', '0'],
self.assertEqual(['GROUPBY', '2', 'foo', 'bar', 'REDUCE', 'COUNT', '0'],
g.build_args())

# Multiple fields, multiple reducers
g = a.Group(['foo', 'bar'], [r.count(), r.count_distinct('@fld1')])
self.assertEqual(['2', 'foo', 'bar', 'REDUCE', 'COUNT', '0', 'REDUCE', 'COUNT_DISTINCT', '1', '@fld1'],
self.assertEqual(['GROUPBY', '2', 'foo', 'bar', 'REDUCE', 'COUNT', '0', 'REDUCE', 'COUNT_DISTINCT', '1', '@fld1'],
g.build_args())

def testAggRequest(self):
Expand All @@ -62,13 +62,38 @@ def testAggRequest(self):
req = a.AggregateRequest().group_by('@foo', r.count())
self.assertEqual(['*', 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0'], req.build_args())

# Test with group_by and alias on reducer
req = a.AggregateRequest().group_by('@foo', r.count().alias('foo_count'))
self.assertEqual(['*', 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0', 'AS', 'foo_count'], req.build_args())

# Test with limit
req = a.AggregateRequest().\
group_by('@foo', r.count()).\
req = a.AggregateRequest(). \
group_by('@foo', r.count()). \
sort_by('@foo')
self.assertEqual(['*', 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0', 'SORTBY', '1',
'@foo'], req.build_args())

# Test with apply
req = a.AggregateRequest(). \
apply(foo="@bar / 2"). \
group_by('@foo', r.count())

self.assertEqual(['*', 'APPLY', '@bar / 2', 'AS', 'foo', 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0'],
req.build_args())

# Test with filter
req = a.AggregateRequest().group_by('@foo', r.count()).filter( "@foo=='bar'")
self.assertEqual(['*', 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0', 'FILTER', "@foo=='bar'" ], req.build_args())

# Test with filter on different state of the pipeline
req = a.AggregateRequest().filter("@foo=='bar'").group_by('@foo', r.count())
self.assertEqual(['*', 'FILTER', "@foo=='bar'", 'GROUPBY', '1', '@foo','REDUCE', 'COUNT', '0' ], req.build_args())

# Test with filter on different state of the pipeline
req = a.AggregateRequest().filter(["@foo=='bar'","@foo2=='bar2'"]).group_by('@foo', r.count())
self.assertEqual(['*', 'FILTER', "@foo=='bar'", 'FILTER', "@foo2=='bar2'", 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0'],
req.build_args())

# Test with sort_by
req = a.AggregateRequest().group_by('@foo', r.count()).sort_by('@date')
# print req.build_args()
Expand Down Expand Up @@ -105,4 +130,8 @@ def test_reducers(self):
self.assertEqual(('f1', 'BY', 'f2', 'ASC'), r.first_value('f1', a.Asc('f2')).args)
self.assertEqual(('f1', 'BY', 'f1', 'ASC'), r.first_value('f1', a.Asc).args)

self.assertEqual(('f1', '50'), r.random_sample('f1', 50).args)
self.assertEqual(('f1', '50'), r.random_sample('f1', 50).args)

if __name__ == '__main__':

unittest.main()