From 7e624e3bac3962b0e217b19d579257620003faa4 Mon Sep 17 00:00:00 2001
From: Emik
Date: Thu, 16 May 2024 14:29:11 +0200
Subject: [PATCH] Use iterator classes for Enumerable.Index/Select((T, int)) (#102252)

* Allows both methods to have their lengths obtained from Enumerable.TryGetNonEnumeratedCount().

* Improves performance of Enumerable.Index with an Array/IList/List source.
---
Review notes (not part of the commit message):
 * TryGetElementAt now reports the requested index and TryGetLast the real
   index of the last element; both previously produced wrong indices.
 * The context lines of the System.Linq.csproj hunk were lost when this patch
   was extracted, and the blob ids on the "index" lines predate these fixes;
   regenerate the patch from a working tree before applying it.

 .../System.Linq/src/System.Linq.csproj        |   1 +
 .../src/System/Linq/Index.SpeedOpt.cs         | 216 ++++++++++++++++++
 .../System.Linq/src/System/Linq/Index.cs      |  23 +-
 .../src/System/Linq/Select.SpeedOpt.cs        | 149 ++++++++++++
 .../System.Linq/src/System/Linq/Select.cs     |  79 ++++--
 .../System.Linq/src/System/Linq/Utilities.cs  |  14 ++
 6 files changed, 451 insertions(+), 31 deletions(-)
 create mode 100644 src/libraries/System.Linq/src/System/Linq/Index.SpeedOpt.cs

diff --git a/src/libraries/System.Linq/src/System.Linq.csproj b/src/libraries/System.Linq/src/System.Linq.csproj
index 68b88631587ace..60d926e02917ad 100644
--- a/src/libraries/System.Linq/src/System.Linq.csproj
+++ b/src/libraries/System.Linq/src/System.Linq.csproj
@@ -24,6 +24,7 @@
+    <Compile Include="System\Linq\Index.SpeedOpt.cs" />
diff --git a/src/libraries/System.Linq/src/System/Linq/Index.SpeedOpt.cs b/src/libraries/System.Linq/src/System/Linq/Index.SpeedOpt.cs
new file mode 100644
index 00000000000000..8393e0c3d6a258
--- /dev/null
+++ b/src/libraries/System.Linq/src/System/Linq/Index.SpeedOpt.cs
@@ -0,0 +1,216 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections.Generic;
+using System.Diagnostics;
+using static System.Linq.Utilities;
+
+namespace System.Linq
+{
+    public static partial class Enumerable
+    {
+        /// <summary>
+        /// An iterator that pairs each item of an <see cref="IEnumerable{TSource}"/> with its zero-based index.
+        /// </summary>
+        /// <typeparam name="TSource">The type of the source enumerable.</typeparam>
+        private sealed partial class IEnumerableIndexIterator<TSource> : Iterator<(int Index, TSource Item)>
+        {
+            private readonly IEnumerable<TSource> _source;
+            private int _index;
+            private IEnumerator<TSource>? _enumerator;
+
+            public IEnumerableIndexIterator(IEnumerable<TSource> source)
+            {
+                Debug.Assert(source is not null);
+                _source = source;
+            }
+
+            private protected override Iterator<(int Index, TSource Item)> Clone() =>
+                new IEnumerableIndexIterator<TSource>(_source);
+
+            public override void Dispose()
+            {
+                if (_enumerator is not null)
+                {
+                    _enumerator.Dispose();
+                    _enumerator = null;
+                }
+
+                base.Dispose();
+            }
+
+            public override bool MoveNext()
+            {
+                switch (_state)
+                {
+                    case 1:
+                        _enumerator = _source.GetEnumerator();
+                        _index = -1;
+                        _state = 2;
+                        goto case 2;
+                    case 2:
+                        Debug.Assert(_enumerator is not null);
+
+                        if (_enumerator.MoveNext())
+                        {
+                            _current = (checked(++_index), _enumerator.Current);
+                            return true;
+                        }
+
+                        Dispose();
+                        break;
+                }
+
+                return false;
+            }
+
+            public override IEnumerable<TResult2> Select<TResult2>(Func<(int Index, TSource Item), TResult2> selector) =>
+                new IEnumerableSelect2Iterator<TSource, TResult2>(_source, CombineSelectors((TSource x, int i) => (i, x), selector));
+
+            public override (int Index, TSource Item)[] ToArray()
+            {
+                int index = -1;
+
+                // Size known up front: fill an exactly-sized array directly.
+                if (_source.TryGetNonEnumeratedCount(out int known))
+                {
+                    var array = new (int Index, TSource Item)[known];
+
+                    foreach (TSource item in _source)
+                    {
+                        array[checked(++index)] = (index, item);
+                    }
+
+                    return array;
+                }
+
+                // Size unknown: accumulate into a builder and materialize the array at the end.
+                SegmentedArrayBuilder<(int Index, TSource Item)>.ScratchBuffer scratch = default;
+                SegmentedArrayBuilder<(int Index, TSource Item)> builder = new(scratch);
+
+                foreach (TSource item in _source)
+                {
+                    builder.Add((checked(++index), item));
+                }
+
+                (int Index, TSource Item)[] result = builder.ToArray();
+                builder.Dispose();
+                return result;
+            }
+
+            public override List<(int Index, TSource Item)> ToList()
+            {
+                List<(int Index, TSource Item)> list = _source.TryGetNonEnumeratedCount(out int known) ? new(known) : [];
+                int index = -1;
+
+                foreach (TSource item in _source)
+                {
+                    list.Add((checked(++index), item));
+                }
+
+                return list;
+            }
+
+            public override int GetCount(bool onlyIfCheap)
+            {
+                // In case someone uses Count() to force evaluation of
+                // the source, enumerate it provided `onlyIfCheap` is false.
+                if (onlyIfCheap)
+                {
+                    return _source.TryGetNonEnumeratedCount(out int known) ? known : -1;
+                }
+
+                int count = 0;
+
+                foreach (TSource item in _source)
+                {
+                    checked
+                    {
+                        count++;
+                    }
+                }
+
+                return count;
+            }
+
+            public override (int Index, TSource Item) TryGetElementAt(int index, out bool found)
+            {
+                // The element at position `index` of the source is paired with that same index.
+                if (_source is Iterator<TSource> iterator)
+                {
+                    return iterator.TryGetElementAt(index, out found) is var element && found ? (index, element!) : default;
+                }
+
+                if (index >= 0)
+                {
+                    using IEnumerator<TSource> e = _source.GetEnumerator();
+
+                    // Count down on a copy so the requested index is still available for the result.
+                    int remaining = index;
+
+                    while (e.MoveNext())
+                    {
+                        if (remaining == 0)
+                        {
+                            found = true;
+                            return (index, e.Current);
+                        }
+
+                        remaining--;
+                    }
+                }
+
+                found = false;
+                return default;
+            }
+
+            public override (int Index, TSource Item) TryGetFirst(out bool found)
+            {
+                if (_source is Iterator<TSource> iterator)
+                {
+                    return iterator.TryGetFirst(out found) is var first && found ? (0, first!) : default;
+                }
+
+                using IEnumerator<TSource> e = _source.GetEnumerator();
+
+                if (e.MoveNext())
+                {
+                    found = true;
+                    return (0, e.Current);
+                }
+
+                found = false;
+                return default;
+            }
+
+            public override (int Index, TSource Item) TryGetLast(out bool found)
+            {
+                // A source iterator with a cheap count yields the last index without enumerating here.
+                if (_source is Iterator<TSource> iterator && iterator.GetCount(true) is not -1 and var count)
+                {
+                    return iterator.TryGetLast(out found) is var last && found ? (count - 1, last!) : default;
+                }
+
+                using IEnumerator<TSource> e = _source.GetEnumerator();
+
+                if (e.MoveNext())
+                {
+                    found = true;
+                    TSource lastElement = e.Current;
+                    int lastIndex = 0; // The element just read sits at index 0.
+
+                    while (e.MoveNext())
+                    {
+                        lastElement = e.Current;
+                        lastIndex = checked(lastIndex + 1);
+                    }
+
+                    return (lastIndex, lastElement);
+                }
+
+                found = false;
+                return default;
+            }
+        }
+    }
+}
diff --git a/src/libraries/System.Linq/src/System/Linq/Index.cs b/src/libraries/System.Linq/src/System/Linq/Index.cs
index 49339b03d1ad39..11658cdafadcf4 100644
--- a/src/libraries/System.Linq/src/System/Linq/Index.cs
+++ b/src/libraries/System.Linq/src/System/Linq/Index.cs
@@ -2,6 +2,8 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System.Collections.Generic;
+using System.Diagnostics;
+using static System.Linq.Utilities;
 
 namespace System.Linq
 {
@@ -22,22 +24,11 @@ public static partial class Enumerable
             {
                 return [];
             }
-
-            return IndexIterator(source);
-        }
-
-        private static IEnumerable<(int Index, TSource Item)> IndexIterator<TSource>(IEnumerable<TSource> source)
-        {
-            int index = -1;
-            foreach (TSource element in source)
-            {
-                checked
-                {
-                    index++;
-                }
-
-                yield return (index, element);
-            }
+#if OPTIMIZE_FOR_SIZE
+            return new IEnumerableSelect2Iterator<TSource, (int Index, TSource Item)>(source, (x, i) => (i, x));
+#else
+            return new IEnumerableIndexIterator<TSource>(source);
+#endif
         }
     }
 }
diff --git a/src/libraries/System.Linq/src/System/Linq/Select.SpeedOpt.cs b/src/libraries/System.Linq/src/System/Linq/Select.SpeedOpt.cs
index f491f1f0de015a..afb239ebeb16bb 100644
--- a/src/libraries/System.Linq/src/System/Linq/Select.SpeedOpt.cs
+++ b/src/libraries/System.Linq/src/System/Linq/Select.SpeedOpt.cs
@@ -129,6 +129,155 @@ public override int GetCount(bool onlyIfCheap)
             }
         }
 
+        private sealed partial class IEnumerableSelect2Iterator<TSource, TResult>
+        {
+            public override TResult[] ToArray()
+            {
+                Func<TSource, int, TResult> selector = _selector;
+                int index = -1;
+
+                // Size known up front: fill an exactly-sized array directly.
+                if (_source.TryGetNonEnumeratedCount(out int known))
+                {
+                    var array = new TResult[known];
+
+                    foreach (TSource item in _source)
+                    {
+                        array[checked(++index)] = selector(item, index);
+                    }
+
+                    return array;
+                }
+
+                // Size unknown: accumulate into a builder and materialize the array at the end.
+                SegmentedArrayBuilder<TResult>.ScratchBuffer scratch = default;
+                SegmentedArrayBuilder<TResult> builder = new(scratch);
+
+                foreach (TSource item in _source)
+                {
+                    builder.Add(selector(item, checked(++index)));
+                }
+
+                TResult[] result = builder.ToArray();
+                builder.Dispose();
+                return result;
+            }
+
+            public override List<TResult> ToList()
+            {
+                List<TResult> list = _source.TryGetNonEnumeratedCount(out int known) ? new(known) : [];
+                Func<TSource, int, TResult> selector = _selector;
+                int index = -1;
+
+                foreach (TSource item in _source)
+                {
+                    list.Add(selector(item, checked(++index)));
+                }
+
+                return list;
+            }
+
+            public override int GetCount(bool onlyIfCheap)
+            {
+                // In case someone uses Count() to force evaluation of
+                // the selector, run it provided `onlyIfCheap` is false.
+                if (onlyIfCheap)
+                {
+                    return _source.TryGetNonEnumeratedCount(out int known) ? known : -1;
+                }
+
+                int count = 0;
+
+                foreach (TSource item in _source)
+                {
+                    _selector(item, checked(count++));
+                }
+
+                return count;
+            }
+
+            public override TResult? TryGetElementAt(int index, out bool found)
+            {
+                // The element at position `index` of the source is mapped together with that same index.
+                if (_source is Iterator<TSource> iterator)
+                {
+                    return iterator.TryGetElementAt(index, out found) is var element && found
+                        ? _selector(element!, index)
+                        : default;
+                }
+
+                if (index >= 0)
+                {
+                    using IEnumerator<TSource> e = _source.GetEnumerator();
+
+                    // Count down on a copy so the requested index is still available for the selector.
+                    int remaining = index;
+
+                    while (e.MoveNext())
+                    {
+                        if (remaining == 0)
+                        {
+                            found = true;
+                            return _selector(e.Current, index);
+                        }
+
+                        remaining--;
+                    }
+                }
+
+                found = false;
+                return default;
+            }
+
+            public override TResult? TryGetFirst(out bool found)
+            {
+                if (_source is Iterator<TSource> iterator)
+                {
+                    return iterator.TryGetFirst(out found) is var first && found ? _selector(first!, 0) : default;
+                }
+
+                using IEnumerator<TSource> e = _source.GetEnumerator();
+
+                if (e.MoveNext())
+                {
+                    found = true;
+                    return _selector(e.Current, 0);
+                }
+
+                found = false;
+                return default;
+            }
+
+            public override TResult? TryGetLast(out bool found)
+            {
+                // A source iterator with a cheap count yields the last index without enumerating here.
+                if (_source is Iterator<TSource> iterator && iterator.GetCount(true) is not -1 and var count)
+                {
+                    return iterator.TryGetLast(out found) is var last && found ? _selector(last!, count - 1) : default;
+                }
+
+                using IEnumerator<TSource> e = _source.GetEnumerator();
+
+                if (e.MoveNext())
+                {
+                    found = true;
+                    TSource lastElement = e.Current;
+                    int lastIndex = 0; // The element just read sits at index 0.
+
+                    while (e.MoveNext())
+                    {
+                        lastElement = e.Current;
+                        lastIndex = checked(lastIndex + 1);
+                    }
+
+                    return _selector(lastElement, lastIndex);
+                }
+
+                found = false;
+                return default;
+            }
+        }
+
         private sealed partial class ArraySelectIterator<TSource, TResult>
         {
             public override TResult[] ToArray()
diff --git a/src/libraries/System.Linq/src/System/Linq/Select.cs b/src/libraries/System.Linq/src/System/Linq/Select.cs
index 37a41c450f873c..683ee62ab0fcf2 100644
--- a/src/libraries/System.Linq/src/System/Linq/Select.cs
+++ b/src/libraries/System.Linq/src/System/Linq/Select.cs
@@ -68,21 +68,7 @@ public static IEnumerable<TResult> Select<TSource, TResult>(this IEnumerable<TSo
                 return [];
             }
 
-            return SelectIterator(source, selector);
-        }
-
-        private static IEnumerable<TResult> SelectIterator<TSource, TResult>(IEnumerable<TSource> source, Func<TSource, int, TResult> selector)
-        {
-            int index = -1;
-            foreach (TSource element in source)
-            {
-                checked
-                {
-                    index++;
-                }
-
-                yield return selector(element, index);
-            }
+            return new IEnumerableSelect2Iterator<TSource, TResult>(source, selector);
         }
 
         /// <summary>
@@ -145,6 +131,69 @@ public override IEnumerable<TResult2> Select<TResult2>(Func<TResult, TResult2> s
                 new IEnumerableSelectIterator<TSource, TResult2>(_source, CombineSelectors(_selector, selector));
         }
 
+        /// <summary>
+        /// An iterator that maps each item of an <see cref="IEnumerable{TSource}"/>, passing the item's index to the selector.
+        /// </summary>
+        /// <typeparam name="TSource">The type of the source enumerable.</typeparam>
+        /// <typeparam name="TResult">The type of the mapped items.</typeparam>
+        private sealed partial class IEnumerableSelect2Iterator<TSource, TResult> : Iterator<TResult>
+        {
+            private readonly IEnumerable<TSource> _source;
+            private readonly Func<TSource, int, TResult> _selector;
+            private int _index;
+            private IEnumerator<TSource>? _enumerator;
+
+            public IEnumerableSelect2Iterator(IEnumerable<TSource> source, Func<TSource, int, TResult> selector)
+            {
+                Debug.Assert(source is not null);
+                Debug.Assert(selector is not null);
+                _source = source;
+                _selector = selector;
+            }
+
+            private protected override Iterator<TResult> Clone() =>
+                new IEnumerableSelect2Iterator<TSource, TResult>(_source, _selector);
+
+            public override void Dispose()
+            {
+                if (_enumerator is not null)
+                {
+                    _enumerator.Dispose();
+                    _enumerator = null;
+                }
+
+                base.Dispose();
+            }
+
+            public override bool MoveNext()
+            {
+                switch (_state)
+                {
+                    case 1:
+                        _enumerator = _source.GetEnumerator();
+                        _index = -1;
+                        _state = 2;
+                        goto case 2;
+                    case 2:
+                        Debug.Assert(_enumerator is not null);
+
+                        if (_enumerator.MoveNext())
+                        {
+                            _current = _selector(_enumerator.Current, checked(++_index));
+                            return true;
+                        }
+
+                        Dispose();
+                        break;
+                }
+
+                return false;
+            }
+
+            public override IEnumerable<TResult2> Select<TResult2>(Func<TResult, TResult2> selector) =>
+                new IEnumerableSelect2Iterator<TSource, TResult2>(_source, CombineSelectors(_selector, selector));
+        }
+
         /// <summary>
         /// An iterator that maps each item of an array.
         /// </summary>
diff --git a/src/libraries/System.Linq/src/System/Linq/Utilities.cs b/src/libraries/System.Linq/src/System/Linq/Utilities.cs
index 987d0004d4ec3a..9cc8245e17f1a2 100644
--- a/src/libraries/System.Linq/src/System/Linq/Utilities.cs
+++ b/src/libraries/System.Linq/src/System/Linq/Utilities.cs
@@ -69,5 +69,19 @@ public static Func<TSource, bool> CombinePredicates<TSource>(Func<TSource, bool>
         /// </returns>
         public static Func<TSource, TResult> CombineSelectors<TSource, TMiddle, TResult>(Func<TSource, TMiddle> selector1, Func<TMiddle, TResult> selector2) =>
             x => selector2(selector1(x));
+
+        /// <summary>
+        /// Combines two selectors.
+        /// </summary>
+        /// <typeparam name="TSource">The type of the first selector's argument.</typeparam>
+        /// <typeparam name="TMiddle">The type of the second selector's argument.</typeparam>
+        /// <typeparam name="TResult">The type of the second selector's return value.</typeparam>
+        /// <param name="selector1">The first selector to run.</param>
+        /// <param name="selector2">The second selector to run.</param>
+        /// <returns>
+        /// A new selector that represents the composition of the first selector with the second selector.
+        /// </returns>
+        public static Func<TSource, int, TResult> CombineSelectors<TSource, TMiddle, TResult>(Func<TSource, int, TMiddle> selector1, Func<TMiddle, TResult> selector2) =>
+            (x, i) => selector2(selector1(x, i));
     }
 }