Skip to content

Commit d687544

Browse files
Add Enumerable.TryGetNonEnumeratedCount (Implements #27183) (#48239)
* implement Enumerable.TryGetEnumeratingCount * address feedback * update consistency tests * Replace EnumerableHelpers.TryGetCount with new method * Rename to method name as approved * make method is renamed in all projects
1 parent 2cae582 commit d687544

File tree

9 files changed

+154
-40
lines changed

9 files changed

+154
-40
lines changed

src/libraries/Common/src/System/Collections/Generic/EnumerableHelpers.Linq.cs

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,32 +11,6 @@ namespace System.Collections.Generic
1111
/// </summary>
1212
internal static partial class EnumerableHelpers
1313
{
14-
/// <summary>
15-
/// Tries to get the count of the enumerable cheaply.
16-
/// </summary>
17-
/// <typeparam name="T">The element type of the source enumerable.</typeparam>
18-
/// <param name="source">The enumerable to count.</param>
19-
/// <param name="count">The count of the enumerable, if it could be obtained cheaply.</param>
20-
/// <returns><c>true</c> if the enumerable could be counted cheaply; otherwise, <c>false</c>.</returns>
21-
internal static bool TryGetCount<T>(IEnumerable<T> source, out int count)
22-
{
23-
Debug.Assert(source != null);
24-
25-
if (source is ICollection<T> collection)
26-
{
27-
count = collection.Count;
28-
return true;
29-
}
30-
31-
if (source is IIListProvider<T> provider)
32-
{
33-
return (count = provider.GetCount(onlyIfCheap: true)) >= 0;
34-
}
35-
36-
count = -1;
37-
return false;
38-
}
39-
4014
/// <summary>
4115
/// Copies items from an enumerable to an array.
4216
/// </summary>

src/libraries/Common/src/System/Collections/Generic/SparseArrayBuilder.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ public void Reserve(int count)
190190
public bool ReserveOrAdd(IEnumerable<T> items)
191191
{
192192
int itemCount;
193-
if (EnumerableHelpers.TryGetCount(items, out itemCount))
193+
if (System.Linq.Enumerable.TryGetNonEnumeratedCount(items, out itemCount))
194194
{
195195
if (itemCount > 0)
196196
{

src/libraries/System.Linq.Queryable/tests/Queryable.cs

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -121,16 +121,17 @@ public static void MatchSequencePattern()
121121
typeof(Enumerable),
122122
typeof(Queryable),
123123
new [] {
124-
"ToLookup",
125-
"ToDictionary",
126-
"ToArray",
127-
"AsEnumerable",
128-
"ToList",
124+
nameof(Enumerable.ToLookup),
125+
nameof(Enumerable.ToDictionary),
126+
nameof(Enumerable.ToArray),
127+
nameof(Enumerable.AsEnumerable),
128+
nameof(Enumerable.ToList),
129+
nameof(Enumerable.Append),
130+
nameof(Enumerable.Prepend),
131+
nameof(Enumerable.ToHashSet),
132+
nameof(Enumerable.TryGetNonEnumeratedCount),
129133
"Fold",
130134
"LeftJoin",
131-
"Append",
132-
"Prepend",
133-
"ToHashSet"
134135
}
135136
);
136137

@@ -140,7 +141,7 @@ public static void MatchSequencePattern()
140141
typeof(Queryable),
141142
typeof(Enumerable),
142143
new [] {
143-
"AsQueryable"
144+
nameof(Queryable.AsQueryable)
144145
}
145146
);
146147

src/libraries/System.Linq/Directory.Build.props

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
<PropertyGroup>
44
<StrongNameKeyId>Microsoft</StrongNameKeyId>
55
</PropertyGroup>
6-
</Project>
6+
</Project>

src/libraries/System.Linq/ref/System.Linq.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ public static System.Collections.Generic.IEnumerable<
189189
public static System.Linq.ILookup<TKey, TSource> ToLookup<TSource, TKey>(this System.Collections.Generic.IEnumerable<TSource> source, System.Func<TSource, TKey> keySelector, System.Collections.Generic.IEqualityComparer<TKey>? comparer) { throw null; }
190190
public static System.Linq.ILookup<TKey, TElement> ToLookup<TSource, TKey, TElement>(this System.Collections.Generic.IEnumerable<TSource> source, System.Func<TSource, TKey> keySelector, System.Func<TSource, TElement> elementSelector) { throw null; }
191191
public static System.Linq.ILookup<TKey, TElement> ToLookup<TSource, TKey, TElement>(this System.Collections.Generic.IEnumerable<TSource> source, System.Func<TSource, TKey> keySelector, System.Func<TSource, TElement> elementSelector, System.Collections.Generic.IEqualityComparer<TKey>? comparer) { throw null; }
192+
public static bool TryGetNonEnumeratedCount<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, out int count) { throw null; }
192193
public static System.Collections.Generic.IEnumerable<TSource> Union<TSource>(this System.Collections.Generic.IEnumerable<TSource> first, System.Collections.Generic.IEnumerable<TSource> second) { throw null; }
193194
public static System.Collections.Generic.IEnumerable<TSource> Union<TSource>(this System.Collections.Generic.IEnumerable<TSource> first, System.Collections.Generic.IEnumerable<TSource> second, System.Collections.Generic.IEqualityComparer<TSource>? comparer) { throw null; }
194195
public static System.Collections.Generic.IEnumerable<TSource> Where<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, System.Func<TSource, bool> predicate) { throw null; }

src/libraries/System.Linq/src/System/Linq/Concat.SpeedOpt.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ private sealed partial class Concat2Iterator<TSource> : ConcatIterator<TSource>
1313
public override int GetCount(bool onlyIfCheap)
1414
{
1515
int firstCount, secondCount;
16-
if (!EnumerableHelpers.TryGetCount(_first, out firstCount))
16+
if (!_first.TryGetNonEnumeratedCount(out firstCount))
1717
{
1818
if (onlyIfCheap)
1919
{
@@ -23,7 +23,7 @@ public override int GetCount(bool onlyIfCheap)
2323
firstCount = _first.Count();
2424
}
2525

26-
if (!EnumerableHelpers.TryGetCount(_second, out secondCount))
26+
if (!_second.TryGetNonEnumeratedCount(out secondCount))
2727
{
2828
if (onlyIfCheap)
2929
{

src/libraries/System.Linq/src/System/Linq/Count.cs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,59 @@ public static int Count<TSource>(this IEnumerable<TSource> source, Func<TSource,
7272
return count;
7373
}
7474

75+
/// <summary>
76+
/// Attempts to determine the number of elements in a sequence without forcing an enumeration.
77+
/// </summary>
78+
/// <typeparam name="TSource">The type of the elements of <paramref name="source" />.</typeparam>
79+
/// <param name="source">A sequence that contains elements to be counted.</param>
80+
/// <param name="count">
81+
/// When this method returns, contains the count of <paramref name="source" /> if successful,
82+
/// or zero if the method failed to determine the count.</param>
83+
/// <returns>
84+
/// <see langword="true" /> if the count of <paramref name="source"/> can be determined without enumeration;
85+
/// otherwise, <see langword="false" />.
86+
/// </returns>
87+
/// <remarks>
88+
/// The method performs a series of type tests, identifying common subtypes whose
89+
/// count can be determined without enumerating; this includes <see cref="ICollection{T}"/>,
90+
/// <see cref="ICollection"/> as well as internal types used in the LINQ implementation.
91+
///
92+
/// The method is typically a constant-time operation, but ultimately this depends on the complexity
93+
/// characteristics of the underlying collection implementation.
94+
/// </remarks>
95+
public static bool TryGetNonEnumeratedCount<TSource>(this IEnumerable<TSource> source, out int count)
96+
{
97+
if (source == null)
98+
{
99+
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source);
100+
}
101+
102+
if (source is ICollection<TSource> collectionoft)
103+
{
104+
count = collectionoft.Count;
105+
return true;
106+
}
107+
108+
if (source is IIListProvider<TSource> listProv)
109+
{
110+
int c = listProv.GetCount(onlyIfCheap: true);
111+
if (c >= 0)
112+
{
113+
count = c;
114+
return true;
115+
}
116+
}
117+
118+
if (source is ICollection collection)
119+
{
120+
count = collection.Count;
121+
return true;
122+
}
123+
124+
count = 0;
125+
return false;
126+
}
127+
75128
public static long LongCount<TSource>(this IEnumerable<TSource> source)
76129
{
77130
if (source == null)

src/libraries/System.Linq/tests/ConsistencyTests.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@ private static IEnumerable<string> GetExcludedMethods()
4141
nameof(Enumerable.ToArray),
4242
nameof(Enumerable.AsEnumerable),
4343
nameof(Enumerable.ToList),
44+
nameof(Enumerable.ToHashSet),
45+
nameof(Enumerable.TryGetNonEnumeratedCount),
4446
"Fold",
4547
"LeftJoin",
46-
"ToHashSet"
4748
};
4849

4950
return result;

src/libraries/System.Linq/tests/CountTests.cs

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,5 +126,89 @@ public void NullPredicate_ThrowsArgumentNullException()
126126
Func<int, bool> predicate = null;
127127
AssertExtensions.Throws<ArgumentNullException>("predicate", () => Enumerable.Range(0, 3).Count(predicate));
128128
}
129+
130+
[Fact]
131+
public void NonEnumeratingCount_NullSource_ThrowsArgumentNullException()
132+
{
133+
AssertExtensions.Throws<ArgumentNullException>("source", () => ((IEnumerable<int>)null).TryGetNonEnumeratedCount(out _));
134+
}
135+
136+
[Theory]
137+
[MemberData(nameof(NonEnumeratingCount_SupportedEnumerables))]
138+
public void NonEnumeratingCount_SupportedEnumerables_ShouldReturnExpectedCount<T>(int expectedCount, IEnumerable<T> source)
139+
{
140+
Assert.True(source.TryGetNonEnumeratedCount(out int actualCount));
141+
Assert.Equal(expectedCount, actualCount);
142+
}
143+
144+
[Theory]
145+
[MemberData(nameof(NonEnumeratingCount_UnsupportedEnumerables))]
146+
public void NonEnumeratingCount_UnsupportedEnumerables_ShouldReturnFalse<T>(IEnumerable<T> source)
147+
{
148+
Assert.False(source.TryGetNonEnumeratedCount(out int actualCount));
149+
Assert.Equal(0, actualCount);
150+
}
151+
152+
[Fact]
153+
public void NonEnumeratingCount_ShouldNotEnumerateSource()
154+
{
155+
bool isEnumerated = false;
156+
Assert.False(Source().TryGetNonEnumeratedCount(out int count));
157+
Assert.Equal(0, count);
158+
Assert.False(isEnumerated);
159+
160+
IEnumerable<int> Source()
161+
{
162+
isEnumerated = true;
163+
yield return 42;
164+
}
165+
}
166+
167+
public static IEnumerable<object[]> NonEnumeratingCount_SupportedEnumerables()
168+
{
169+
yield return WrapArgs(4, new int[]{ 1, 2, 3, 4 });
170+
yield return WrapArgs(4, new List<int>(new int[] { 1, 2, 3, 4 }));
171+
yield return WrapArgs(4, new Stack<int>(new int[] { 1, 2, 3, 4 }));
172+
173+
yield return WrapArgs(0, Enumerable.Empty<string>());
174+
175+
if (PlatformDetection.IsSpeedOptimized)
176+
{
177+
yield return WrapArgs(100, Enumerable.Range(1, 100));
178+
yield return WrapArgs(80, Enumerable.Repeat(1, 80));
179+
yield return WrapArgs(50, Enumerable.Range(1, 50).Select(x => x + 1));
180+
yield return WrapArgs(4, new int[] { 1, 2, 3, 4 }.Select(x => x + 1));
181+
yield return WrapArgs(50, Enumerable.Range(1, 50).Select(x => x + 1).Select(x => x - 1));
182+
yield return WrapArgs(7, Enumerable.Range(1, 20).ToLookup(x => x % 7));
183+
yield return WrapArgs(20, Enumerable.Range(1, 20).Reverse());
184+
yield return WrapArgs(20, Enumerable.Range(1, 20).OrderBy(x => -x));
185+
yield return WrapArgs(20, Enumerable.Range(1, 10).Concat(Enumerable.Range(11, 10)));
186+
}
187+
188+
static object[] WrapArgs<T>(int expectedCount, IEnumerable<T> source) => new object[] { expectedCount, source };
189+
}
190+
191+
public static IEnumerable<object[]> NonEnumeratingCount_UnsupportedEnumerables()
192+
{
193+
yield return WrapArgs(Enumerable.Range(1, 100).Where(x => x % 2 == 0));
194+
yield return WrapArgs(Enumerable.Range(1, 100).GroupBy(x => x % 2 == 0));
195+
yield return WrapArgs(new Stack<int>(new int[] { 1, 2, 3, 4 }).Select(x => x + 1));
196+
yield return WrapArgs(Enumerable.Range(1, 100).Distinct());
197+
198+
if (!PlatformDetection.IsSpeedOptimized)
199+
{
200+
yield return WrapArgs(Enumerable.Range(1, 100));
201+
yield return WrapArgs(Enumerable.Repeat(1, 80));
202+
yield return WrapArgs(Enumerable.Range(1, 50).Select(x => x + 1));
203+
yield return WrapArgs(new int[] { 1, 2, 3, 4 }.Select(x => x + 1));
204+
yield return WrapArgs(Enumerable.Range(1, 50).Select(x => x + 1).Select(x => x - 1));
205+
yield return WrapArgs(Enumerable.Range(1, 20).ToLookup(x => x % 7));
206+
yield return WrapArgs(Enumerable.Range(1, 20).Reverse());
207+
yield return WrapArgs(Enumerable.Range(1, 20).OrderBy(x => -x));
208+
yield return WrapArgs(Enumerable.Range(1, 10).Concat(Enumerable.Range(11, 10)));
209+
}
210+
211+
static object[] WrapArgs<T>(IEnumerable<T> source) => new object[] { source };
212+
}
129213
}
130214
}

0 commit comments

Comments
 (0)