Skip to content
Prev Previous commit
Next Next commit
[add] added supported for filter expressions on aggregations. include…
… more examples on test_builder.py
  • Loading branch information
filipecosta90 committed Oct 3, 2019
commit 7041f4e4f84c70efcaefdc91a802bdd6f59ed89a
25 changes: 23 additions & 2 deletions redisearch/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def build_args(self):
for reducer in self.reducers:
ret += ['REDUCE', reducer.NAME, str(len(reducer.args))]
ret.extend(reducer.args)
if reducer._alias:
if reducer._alias is not None:
ret += ['AS', reducer._alias]
return ret

Expand Down Expand Up @@ -172,6 +172,7 @@ def __init__(self, query='*'):
self._query = query
self._aggregateplan = []
self._loadfields = []
self._filters = []
self._limit = Limit()
self._max = 0
self._with_schema = False
Expand Down Expand Up @@ -218,7 +219,7 @@ def apply(self, **kwexpr):
expression itself, for example `apply(square_root="sqrt(@foo)")`
"""
for alias, expr in kwexpr.items():
projection = Projection(alias, expr)
projection = Projection(expr, alias )
self._aggregateplan.extend(projection.build_args())

return self
Expand Down Expand Up @@ -306,6 +307,22 @@ def sort_by(self, *fields, **kwargs):
self._aggregateplan.extend(sortby.build_args())
return self

def filter(self, expressions):
"""
Specify filter for post-query results using predicates relating to values in the result set.

### Parameters

- **fields**: Fields to group by. This can either be a single string,
or a list of strings.
"""
if isinstance(expressions, (string_types)):
expressions = [expressions]

self._filters.extend(expressions)

return self

def with_schema(self):
"""
If set, the `schema` property will contain a list of `[field, type]`
Expand Down Expand Up @@ -353,6 +370,10 @@ def build_args(self):

ret.extend(self._aggregateplan)

for expr in self._filters:
ret.append('FILTER')
ret.append(expr)

ret += self._limit.build_args()

return ret
Expand Down
8 changes: 8 additions & 0 deletions test/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ def testAggRequest(self):
req = a.AggregateRequest().group_by('@foo', r.count())
self.assertEqual(['*', 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0'], req.build_args())

# Test with group_by and alias on reducer
req = a.AggregateRequest().group_by('@foo', r.count().alias('foo_count'))
self.assertEqual(['*', 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0', 'AS', 'foo_count'], req.build_args())

# Test with limit
req = a.AggregateRequest(). \
group_by('@foo', r.count()). \
Expand All @@ -77,6 +81,10 @@ def testAggRequest(self):
self.assertEqual(['*', 'APPLY', '@bar / 2', 'AS', 'foo', 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0'],
req.build_args())

# Test with filter
req = a.AggregateRequest().group_by('@foo', r.count()).filter( "@foo=='bar'")
self.assertEqual(['*', 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0', 'FILTER', "@foo=='bar'" ], req.build_args())

# Test with sort_by
req = a.AggregateRequest().group_by('@foo', r.count()).sort_by('@date')
# print req.build_args()
Expand Down