Skip to content

Commit fc052f4

Browse files
committed
CSHARP-4880: Support SequenceEqual in aggregation expressions.
1 parent 25dd800 commit fc052f4

File tree

4 files changed

+133
-0
lines changed

4 files changed

+133
-0
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/QueryableMethod.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ internal static class QueryableMethod
9393
private static readonly MethodInfo __selectManyWithCollectionSelectorTakingIndexAndResultSelector;
9494
private static readonly MethodInfo __selectManyWithSelectorTakingIndex;
9595
private static readonly MethodInfo __selectWithSelectorTakingIndex;
96+
private static readonly MethodInfo __sequenceEqual;
9697
private static readonly MethodInfo __single;
9798
private static readonly MethodInfo __singleOrDefault;
9899
private static readonly MethodInfo __singleOrDefaultWithPredicate;
@@ -198,6 +199,7 @@ static QueryableMethod()
198199
__selectManyWithCollectionSelectorTakingIndexAndResultSelector = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, int, IEnumerable<object>>> collectionSelector, Expression<Func<object, object, object>> resultSelector) => source.SelectMany(collectionSelector, resultSelector));
199200
__selectManyWithSelectorTakingIndex = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, int, IEnumerable<object>>> selector) => source.SelectMany(selector));
200201
__selectWithSelectorTakingIndex = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, int, object>> selector) => source.Select(selector));
202+
__sequenceEqual = ReflectionInfo.Method((IQueryable<object> source1, IEnumerable<object> source2) => source1.SequenceEqual(source2));
201203
__single = ReflectionInfo.Method((IQueryable<object> source) => source.Single());
202204
__singleOrDefault = ReflectionInfo.Method((IQueryable<object> source) => source.SingleOrDefault());
203205
__singleOrDefaultWithPredicate = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, bool>> predicate) => source.SingleOrDefault(predicate));
@@ -302,6 +304,7 @@ static QueryableMethod()
302304
public static MethodInfo SelectManyWithCollectionSelectorTakingIndexAndResultSelector => __selectManyWithCollectionSelectorTakingIndexAndResultSelector;
303305
public static MethodInfo SelectManyWithSelectorTakingIndex => __selectManyWithSelectorTakingIndex;
304306
public static MethodInfo SelectWithSelectorTakingIndex => __selectWithSelectorTakingIndex;
307+
public static MethodInfo SequenceEqual => __sequenceEqual;
305308
public static MethodInfo Single => __single;
306309
public static MethodInfo SingleOrDefault => __singleOrDefault;
307310
public static MethodInfo SingleOrDefaultWithPredicate => __singleOrDefaultWithPredicate;

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodCallExpressionToAggregationExpressionTranslator.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ public static AggregationExpression Translate(TranslationContext context, Method
7373
case "Round": return RoundMethodToAggregationExpressionTranslator.Translate(context, expression);
7474
case "Select": return SelectMethodToAggregationExpressionTranslator.Translate(context, expression);
7575
case "SelectMany": return SelectManyMethodToAggregationExpressionTranslator.Translate(context, expression);
76+
case "SequenceEqual": return SequenceEqualMethodToAggregationExpressionTranslator.Translate(context, expression);
7677
case "SetEquals": return SetEqualsMethodToAggregationExpressionTranslator.Translate(context, expression);
7778
case "Shift": return ShiftMethodToAggregationExpressionTranslator.Translate(context, expression);
7879
case "Split": return SplitMethodToAggregationExpressionTranslator.Translate(context, expression);
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System.Linq.Expressions;
17+
using MongoDB.Bson.Serialization.Serializers;
18+
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
19+
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
20+
using MongoDB.Driver.Linq.Linq3Implementation.Reflection;
21+
22+
namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators
23+
{
24+
internal static class SequenceEqualMethodToAggregationExpressionTranslator
25+
{
26+
public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression)
27+
{
28+
var method = expression.Method;
29+
var arguments = expression.Arguments;
30+
31+
if (method.IsOneOf(EnumerableMethod.SequenceEqual, QueryableMethod.SequenceEqual))
32+
{
33+
var firstExpression = arguments[0];
34+
var secondExpression = arguments[1];
35+
36+
var firstTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, firstExpression);
37+
var secondTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, secondExpression);
38+
39+
var (firstVarBinding, firstAst) = AstExpression.UseVarIfNotSimple("first", firstTranslation.Ast);
40+
var (secondVarBinding, secondAst) = AstExpression.UseVarIfNotSimple("second", secondTranslation.Ast);
41+
var pairVar = AstExpression.Var("pair");
42+
43+
var ast = AstExpression.Let(
44+
firstVarBinding,
45+
secondVarBinding,
46+
@in : AstExpression.And(
47+
AstExpression.Eq(AstExpression.Type(firstAst), "array"),
48+
AstExpression.Eq(AstExpression.Type(secondAst), "array"),
49+
AstExpression.Eq(AstExpression.Size(firstAst), AstExpression.Size(secondAst)),
50+
AstExpression.AllElementsTrue(
51+
AstExpression.Map(
52+
input: AstExpression.Zip([firstAst, secondAst]),
53+
@as: pairVar,
54+
@in : AstExpression.Eq(AstExpression.ArrayElemAt(pairVar, 0), AstExpression.ArrayElemAt(pairVar, 1)))))
55+
);
56+
57+
return new AggregationExpression(expression, ast, new BooleanSerializer());
58+
}
59+
60+
throw new ExpressionNotSupportedException(expression);
61+
}
62+
}
63+
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System.Linq;
17+
using FluentAssertions;
18+
using MongoDB.Bson;
19+
using MongoDB.TestHelpers.XunitExtensions;
20+
using Xunit;
21+
22+
namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira
23+
{
24+
public class CSharp4880Tests : Linq3IntegrationTest
25+
{
26+
[Theory]
27+
[ParameterAttributeData]
28+
public void Select_SequenceEqual_should_work(
29+
[Values(false, true)] bool withNestedAsQueryable)
30+
{
31+
var collection = GetCollection();
32+
33+
var queryable = withNestedAsQueryable ?
34+
collection.AsQueryable().Select(x => x.A.AsQueryable().SequenceEqual(x.B)) :
35+
collection.AsQueryable().Select(x => x.A.SequenceEqual(x.B));
36+
37+
var stages = Translate(collection, queryable);
38+
AssertStages(stages, "{ $project : { _v : { $and : [{ $eq : [{ $type : '$A' }, 'array'] }, { $eq : [{ $type : '$B' }, 'array'] }, { $eq : [{ $size : '$A' }, { $size : '$B' }] }, { $allElementsTrue : { $map : { input : { $zip : { inputs : ['$A', '$B'] } }, as : 'pair', in: { $eq : [{ $arrayElemAt : ['$$pair', 0] }, { $arrayElemAt : ['$$pair', 1] }] } } } }] }, _id : 0 } }");
39+
40+
var results = queryable.ToList();
41+
results.Should().Equal(false, false, false, false, true, false, false);
42+
}
43+
44+
private IMongoCollection<C> GetCollection()
45+
{
46+
var collection = GetCollection<C>("test");
47+
CreateCollection(
48+
collection.Database.GetCollection<BsonDocument>(collection.CollectionNamespace.CollectionName),
49+
BsonDocument.Parse("{ _id : 1, A : null, B : null }"),
50+
BsonDocument.Parse("{ _id : 2, A : null, B : [1, 2, 3] }"),
51+
BsonDocument.Parse("{ _id : 3, A : [1, 2, 3], B : null }"),
52+
BsonDocument.Parse("{ _id : 4, A : [1, 2, 3], B : [1, 2] }"),
53+
BsonDocument.Parse("{ _id : 5, A : [1, 2, 3], B : [1, 2, 3] }"),
54+
BsonDocument.Parse("{ _id : 6, A : [1, 2, 3], B : [4, 5, 6] }"),
55+
BsonDocument.Parse("{ _id : 7, A : [1, 2, 3], B : [1, 2, 3, 4] }"));
56+
return collection;
57+
}
58+
59+
private class C
60+
{
61+
public int Id { get; set; }
62+
public int[] A { get; set; }
63+
public int[] B { get; set; }
64+
}
65+
}
66+
}

0 commit comments

Comments
 (0)