diff --git a/src/System.Collections.Concurrent/src/System.Collections.Concurrent.csproj b/src/System.Collections.Concurrent/src/System.Collections.Concurrent.csproj index d9a57c9345f0..be082dfb68af 100644 --- a/src/System.Collections.Concurrent/src/System.Collections.Concurrent.csproj +++ b/src/System.Collections.Concurrent/src/System.Collections.Concurrent.csproj @@ -5,7 +5,7 @@ {96AA2060-C846-4E56-9509-E8CB9C114C8F} System.Collections.Concurrent System.Collections.Concurrent - FEATURE_TRACING + $(DefineConstants);FEATURE_TRACING true .NETStandard,Version=v1.7 diff --git a/src/System.Collections.Concurrent/src/System/Collections/Concurrent/ConcurrentBag.cs b/src/System.Collections.Concurrent/src/System/Collections/Concurrent/ConcurrentBag.cs index 909142333487..e6316a0604e6 100644 --- a/src/System.Collections.Concurrent/src/System/Collections/Concurrent/ConcurrentBag.cs +++ b/src/System.Collections.Concurrent/src/System/Collections/Concurrent/ConcurrentBag.cs @@ -2,23 +2,15 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -// =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ -// -// ConcurrentBag.cs -// -// An unordered collection that allows duplicates and that provides add and get operations. -// =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- - using System.Collections.Generic; using System.Diagnostics; -using System.Runtime.InteropServices; using System.Runtime.Serialization; using System.Threading; namespace System.Collections.Concurrent { /// - /// Represents an thread-safe, unordered collection of objects. + /// Represents a thread-safe, unordered collection of objects. /// /// Specifies the type of elements in the bag. /// @@ -41,29 +33,19 @@ namespace System.Collections.Concurrent [Serializable] public class ConcurrentBag : IProducerConsumerCollection, IReadOnlyCollection { - // ThreadLocalList object that contains the data per thread + /// The per-bag, per-thread work-stealing queues. [NonSerialized] - private ThreadLocal _locals; - - // This head and tail pointers points to the first and last local lists, to allow enumeration on the thread locals objects + private ThreadLocal _locals; + /// The head work stealing queue in a linked list of queues. [NonSerialized] - private volatile ThreadLocalList _headList, _tailList; - - // A flag used to tell the operations thread that it must synchronize the operation, this flag is set/unset within - // GlobalListsLock lock - [NonSerialized] - private bool _needSync; - - // Used for custom serialization. + private volatile WorkStealingQueue _workStealingQueues; + /// Temporary storage of the bag's contents used during serialization. private T[] _serializationArray; - /// - /// Initializes a new instance of the - /// class. - /// + /// Initializes a new instance of the class. public ConcurrentBag() { - Initialize(null); + _locals = new ThreadLocal(); } /// @@ -80,26 +62,13 @@ public ConcurrentBag(IEnumerable collection) { throw new ArgumentNullException(nameof(collection), SR.ConcurrentBag_Ctor_ArgumentNullException); } - Initialize(collection); - } - - /// - /// Local helper function to initialize a new bag object - /// - /// An enumeration containing items with which to initialize this bag. - private void Initialize(IEnumerable collection) - { - _locals = new ThreadLocal(); + _locals = new ThreadLocal(); - // Copy the collection to the bag - if (collection != null) + WorkStealingQueue queue = GetCurrentThreadWorkStealingQueue(forceCreate: true); + foreach (T item in collection) { - ThreadLocalList list = GetThreadList(true); - foreach (T item in collection) - { - list.Add(item, false); - } + queue.LocalPush(item); } } @@ -118,21 +87,21 @@ private void OnSerialized(StreamingContext context) _serializationArray = null; } - /// Construct the stack from a previously seiralized one. + /// Construct the stack from a previously serialized one. [OnDeserialized] private void OnDeserialized(StreamingContext context) { - _locals = new ThreadLocal(); + Debug.Assert(_locals == null); + _locals = new ThreadLocal(); - ThreadLocalList list = GetThreadList(true); + WorkStealingQueue queue = GetCurrentThreadWorkStealingQueue(forceCreate: true); foreach (T item in _serializationArray) { - list.Add(item, false); + queue.LocalPush(item); } - _headList = list; - _tailList = list; - _serializationArray = null; + + _workStealingQueues = queue; } /// @@ -141,46 +110,7 @@ private void OnDeserialized(StreamingContext context) /// The object to be added to the /// . The value can be a null reference /// (Nothing in Visual Basic) for reference types. - public void Add(T item) - { - // Get the local list for that thread, create a new list if this thread doesn't exist - //(first time to call add) - ThreadLocalList list = GetThreadList(true); - AddInternal(list, item); - } - - /// - /// - /// - /// - private void AddInternal(ThreadLocalList list, T item) - { - bool lockTaken = false; - try - { -#pragma warning disable 0420 - Interlocked.Exchange(ref list._currentOp, (int)ListOperation.Add); -#pragma warning restore 0420 - //Synchronization cases: - // if the list count is less than two to avoid conflict with any stealing thread - // if _needSync is set, this means there is a thread that needs to freeze the bag - if (list.Count < 2 || _needSync) - { - // reset it back to zero to avoid deadlock with stealing thread - list._currentOp = (int)ListOperation.None; - Monitor.Enter(list, ref lockTaken); - } - list.Add(item, lockTaken); - } - finally - { - list._currentOp = (int)ListOperation.None; - if (lockTaken) - { - Monitor.Exit(list); - } - } - } + public void Add(T item) => GetCurrentThreadWorkStealingQueue(forceCreate: true).LocalPush(item); /// /// Attempts to add an object to the . @@ -196,8 +126,7 @@ bool IProducerConsumerCollection.TryAdd(T item) } /// - /// Attempts to remove and return an object from the . + /// Attempts to remove and return an object from the . /// /// When this method returns, contains the object /// removed from the or the default value @@ -205,12 +134,12 @@ bool IProducerConsumerCollection.TryAdd(T item) /// true if an object was removed successfully; otherwise, false. public bool TryTake(out T result) { - return TryTakeOrPeek(out result, true); + WorkStealingQueue queue = GetCurrentThreadWorkStealingQueue(forceCreate: false); + return (queue != null && queue.TryLocalPop(out result)) || TrySteal(out result, take: true); } /// - /// Attempts to return an object from the - /// without removing it. + /// Attempts to return an object from the without removing it. /// /// When this method returns, contains an object from /// the or the default value of @@ -218,236 +147,108 @@ public bool TryTake(out T result) /// true if and object was returned successfully; otherwise, false. public bool TryPeek(out T result) { - return TryTakeOrPeek(out result, false); + WorkStealingQueue queue = GetCurrentThreadWorkStealingQueue(forceCreate: false); + return (queue != null && queue.TryLocalPeek(out result)) || TrySteal(out result, take: false); } - /// - /// Local helper function to Take or Peek an item from the bag - /// - /// To receive the item retrieved from the bag - /// True means Take operation, false means Peek operation - /// True if succeeded, false otherwise - private bool TryTakeOrPeek(out T result, bool take) - { - // Get the local list for that thread, return null if the thread doesn't exit - //(this thread never add before) - ThreadLocalList list = GetThreadList(false); - if (list == null || list.Count == 0) - { - return Steal(out result, take); - } + /// Gets the work-stealing queue data structure for the current thread. + /// Whether to create a new queue if this thread doesn't have one. + /// The local queue object, or null if the thread doesn't have one. + private WorkStealingQueue GetCurrentThreadWorkStealingQueue(bool forceCreate) => + _locals.Value ?? + (forceCreate ? CreateWorkStealingQueueForCurrentThread() : null); - bool lockTaken = false; - try + private WorkStealingQueue CreateWorkStealingQueueForCurrentThread() + { + lock (GlobalQueuesLock) // necessary to update _workStealingQueues, so as to synchronize with freezing operations { - if (take) // Take operation - { -#pragma warning disable 0420 - Interlocked.Exchange(ref list._currentOp, (int)ListOperation.Take); -#pragma warning restore 0420 - //Synchronization cases: - // if the list count is less than or equal two to avoid conflict with any stealing thread - // if _needSync is set, this means there is a thread that needs to freeze the bag - if (list.Count <= 2 || _needSync) - { - // reset it back to zero to avoid deadlock with stealing thread - list._currentOp = (int)ListOperation.None; - Monitor.Enter(list, ref lockTaken); + WorkStealingQueue head = _workStealingQueues; - // Double check the count and steal if it became empty - if (list.Count == 0) - { - // Release the lock before stealing - if (lockTaken) - { - try { } - finally - { - lockTaken = false; // reset lockTaken to avoid calling Monitor.Exit again in the finally block - Monitor.Exit(list); - } - } - return Steal(out result, true); - } - } - list.Remove(out result); - } - else - { - if (!list.Peek(out result)) - { - return Steal(out result, false); - } - } - } - finally - { - list._currentOp = (int)ListOperation.None; - if (lockTaken) + WorkStealingQueue queue = head != null ? GetUnownedWorkStealingQueue() : null; + if (queue == null) { - Monitor.Exit(list); + _workStealingQueues = queue = new WorkStealingQueue(head); } - } - return true; - } - - - /// - /// Local helper function to retrieve a thread local list by a thread object - /// - /// Create a new list if the thread does not exist - /// The local list object - private ThreadLocalList GetThreadList(bool forceCreate) - { - ThreadLocalList list = _locals.Value; + _locals.Value = queue; - if (list != null) - { - return list; - } - else if (forceCreate) - { - // Acquire the lock to update the _tailList pointer - lock (GlobalListsLock) - { - if (_headList == null) - { - list = new ThreadLocalList(Environment.CurrentManagedThreadId); - _headList = list; - _tailList = list; - } - else - { - list = GetUnownedList(); - if (list == null) - { - list = new ThreadLocalList(Environment.CurrentManagedThreadId); - _tailList._nextList = list; - _tailList = list; - } - } - _locals.Value = list; - } + return queue; } - else - { - return null; - } - Debug.Assert(list != null); - return list; } /// - /// Try to reuse an unowned list if exist - /// unowned lists are the lists that their owner threads are aborted or terminated - /// this is workaround to avoid memory leaks. + /// Try to reuse an unowned queue. If a thread interacts with the bag and then exits, + /// the bag purposefully retains its queue, as it contains data associated with the bag. /// - /// The list object, null if all lists are owned - private ThreadLocalList GetUnownedList() + /// The queue object, or null if no unowned queue could be gathered. + private WorkStealingQueue GetUnownedWorkStealingQueue() { - //the global lock must be held at this point - Debug.Assert(Monitor.IsEntered(GlobalListsLock)); + Debug.Assert(Monitor.IsEntered(GlobalQueuesLock)); + // Look for a thread that has the same ID as this one. It won't have come from the same thread, + // but if our thread ID is reused, we know that no other thread can have the same ID and thus + // no other thread can be using this queue. int currentThreadId = Environment.CurrentManagedThreadId; - ThreadLocalList currentList = _headList; - while (currentList != null) + for (WorkStealingQueue queue = _workStealingQueues; queue != null; queue = queue._nextQueue) { - if (currentList._ownerThreadId == currentThreadId) + if (queue._ownerThreadId == currentThreadId) { - return currentList; + return queue; } - currentList = currentList._nextList; } + return null; } - - /// - /// Local helper method to steal an item from any other non empty thread - /// + /// Local helper method to steal an item from any other non empty thread. /// To receive the item retrieved from the bag /// Whether to remove or peek. /// True if succeeded, false otherwise. - private bool Steal(out T result, bool take) + private bool TrySteal(out T result, bool take) { #if FEATURE_TRACING if (take) + { CDSCollectionETWBCLProvider.Log.ConcurrentBag_TryTakeSteals(); + } else + { CDSCollectionETWBCLProvider.Log.ConcurrentBag_TryPeekSteals(); + } #endif - bool loop; - List versionsList = new List(); // save the lists version - do + // If there's no local queue for this thread, just start from the head queue + // and try to steal from each queue until we get a result. + WorkStealingQueue localQueue = GetCurrentThreadWorkStealingQueue(forceCreate: false); + if (localQueue == null) { - versionsList.Clear(); //clear the list from the previous iteration - loop = false; - - - ThreadLocalList currentList = _headList; - while (currentList != null) - { - versionsList.Add(currentList._version); - if (currentList._head != null && TrySteal(currentList, out result, take)) - { - return true; - } - currentList = currentList._nextList; - } - - // verify versioning, if other items are added to this list since we last visit it, we should retry - currentList = _headList; - foreach (int version in versionsList) - { - if (version != currentList._version) //oops state changed - { - loop = true; - if (currentList._head != null && TrySteal(currentList, out result, take)) - return true; - } - currentList = currentList._nextList; - } - } while (loop); + return TryStealFromTo(_workStealingQueues, null, out result, take); + } + // If there is a local queue from this thread, then start from the next queue + // after it, and then iterate around back from the head to this queue, not including it. + return + TryStealFromTo(localQueue._nextQueue, null, out result, take) || + TryStealFromTo(_workStealingQueues, localQueue, out result, take); - result = default(T); - return false; + // TODO: Investigate storing the queues in an array instead of a linked list, and then + // randomly choosing a starting location from which to start iterating. } /// - /// local helper function tries to steal an item from given local list + /// Attempts to steal from each queue starting from to . /// - private bool TrySteal(ThreadLocalList list, out T result, bool take) + private bool TryStealFromTo(WorkStealingQueue startInclusive, WorkStealingQueue endExclusive, out T result, bool take) { - lock (list) + for (WorkStealingQueue queue = startInclusive; queue != endExclusive; queue = queue._nextQueue) { - if (CanSteal(list)) + if (queue.TrySteal(out result, take)) { - list.Steal(out result, take); return true; } - result = default(T); - return false; - } - } - /// - /// Local helper function to check the list if it became empty after acquiring the lock - /// and wait if there is unsynchronized Add/Take operation in the list to be done - /// - /// The list to steal - /// True if can steal, false otherwise - private static bool CanSteal(ThreadLocalList list) - { - if (list.Count <= 2 && list._currentOp != (int)ListOperation.None) - { - SpinWait spinner = new SpinWait(); - while (list._currentOp != (int)ListOperation.None) - { - spinner.SpinOnce(); - } } - return list.Count > 0; + + result = default(T); + return false; } /// @@ -478,13 +279,14 @@ public void CopyTo(T[] array, int index) } if (index < 0) { - throw new ArgumentOutOfRangeException - (nameof(index), SR.ConcurrentBag_CopyTo_ArgumentOutOfRangeException); + throw new ArgumentOutOfRangeException(nameof(index), SR.ConcurrentBag_CopyTo_ArgumentOutOfRangeException); } // Short path if the bag is empty - if (_headList == null) + if (_workStealingQueues == null) + { return; + } bool lockTaken = false; try @@ -542,7 +344,6 @@ void ICollection.CopyTo(Array array, int index) } } - /// /// Copies the elements to a new array. /// @@ -551,8 +352,10 @@ void ICollection.CopyTo(Array array, int index) public T[] ToArray() { // Short path if the bag is empty - if (_headList == null) + if (_workStealingQueues == null) + { return Array.Empty(); + } bool lockTaken = false; try @@ -581,8 +384,10 @@ public T[] ToArray() public IEnumerator GetEnumerator() { // Short path if the bag is empty - if (_headList == null) + if (_workStealingQueues == null) + { return ((IEnumerable)Array.Empty()).GetEnumerator(); + } bool lockTaken = false; try @@ -607,10 +412,7 @@ public IEnumerator GetEnumerator() /// of the bag. It does not reflect any update to the collection after /// was called. /// - IEnumerator IEnumerable.GetEnumerator() - { - return ((ConcurrentBag)this).GetEnumerator(); - } + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); /// /// Gets the number of elements contained in the . @@ -626,8 +428,10 @@ public int Count get { // Short path if the bag is empty - if (_headList == null) + if (_workStealingQueues == null) + { return 0; + } bool lockTaken = false; try @@ -650,29 +454,27 @@ public bool IsEmpty { get { - if (_headList == null) - return true; - - bool lockTaken = false; - try + if (_workStealingQueues != null) { - FreezeBag(ref lockTaken); - ThreadLocalList currentList = _headList; - while (currentList != null) + bool lockTaken = false; + try { - if (currentList._head != null) - //at least this list is not empty, we return false + FreezeBag(ref lockTaken); + for (WorkStealingQueue queue = _workStealingQueues; queue != null; queue = queue._nextQueue) { - return false; + if (!queue.IsEmpty) + { + return false; + } } - currentList = currentList._nextList; } - return true; - } - finally - { - UnfreezeBag(lockTaken); + finally + { + UnfreezeBag(lockTaken); + } } + + return true; } } @@ -683,10 +485,7 @@ public bool IsEmpty /// true if access to the is synchronized /// with the SyncRoot; otherwise, false. For , this property always /// returns false. - bool ICollection.IsSynchronized - { - get { return false; } - } + bool ICollection.IsSynchronized => false; /// /// Gets an object that can be used to synchronize access to the The SyncRoot property is not supported. object ICollection.SyncRoot { - get - { - throw new NotSupportedException(SR.ConcurrentCollection_SyncRoot_NotSupported); - } + get { throw new NotSupportedException(SR.ConcurrentCollection_SyncRoot_NotSupported); } } - - /// - /// A global lock object, used in two cases: - /// 1- To maintain the _tailList pointer for each new list addition process ( first time a thread called Add ) - /// 2- To freeze the bag in GetEnumerator, CopyTo, ToArray and Count members - /// - private object GlobalListsLock + /// Global lock used to synchronize the queues pointer and all bag-wide operations (e.g. ToArray, Count, etc.). + private object GlobalQueuesLock { get { @@ -716,111 +507,55 @@ private object GlobalListsLock } } - - #region Freeze bag helper methods - /// - /// Local helper method to freeze all bag operations, it - /// 1- Acquire the global lock to prevent any other thread to freeze the bag, and also new thread can be added - /// to the dictionary - /// 2- Then Acquire all local lists locks to prevent steal and synchronized operations - /// 3- Wait for all un-synchronized operations to be done - /// - /// Retrieve the lock taken result for the global lock, to be passed to Unfreeze method + /// "Freezes" the bag, such that no concurrent operations will be mutating the bag when it returns. + /// true if the global lock was taken; otherwise, false. private void FreezeBag(ref bool lockTaken) { - Debug.Assert(!Monitor.IsEntered(GlobalListsLock)); + // Take the global lock to start freezing the bag. This helps, for example, + // to prevent other threads from joining the bag (adding their local queues) + // while a global operation is in progress. + Debug.Assert(!Monitor.IsEntered(GlobalQueuesLock)); + Monitor.Enter(GlobalQueuesLock, ref lockTaken); + WorkStealingQueue head = _workStealingQueues; // stable at least until GlobalQueuesLock is released in UnfreezeBag - // global lock to be safe against multi threads calls count and corrupt _needSync - Monitor.Enter(GlobalListsLock, ref lockTaken); - - // This will force any future add/take operation to be synchronized - _needSync = true; - - //Acquire all local lists locks - AcquireAllLocks(); - - // Wait for all un-synchronized operation to be done - WaitAllOperations(); - } - - /// - /// Local helper method to unfreeze the bag from a frozen state - /// - /// The lock taken result from the Freeze method - private void UnfreezeBag(bool lockTaken) - { - ReleaseAllLocks(); - _needSync = false; - if (lockTaken) + // Then acquire all local queue locks, noting on each that it's been taken. + for (WorkStealingQueue queue = head; queue != null; queue = queue._nextQueue) { - Monitor.Exit(GlobalListsLock); - } - } - - /// - /// local helper method to acquire all local lists locks - /// - private void AcquireAllLocks() - { - Debug.Assert(Monitor.IsEntered(GlobalListsLock)); - - bool lockTaken = false; - ThreadLocalList currentList = _headList; - while (currentList != null) - { - // Try/Finally block to avoid thread abort between acquiring the lock and setting the taken flag - try - { - Monitor.Enter(currentList, ref lockTaken); - } - finally - { - if (lockTaken) - { - currentList._lockTaken = true; - lockTaken = false; - } - } - currentList = currentList._nextList; + Monitor.Enter(queue, ref queue._frozen); } - } + Interlocked.MemoryBarrier(); // prevent reads of _currentOp from moving before writes to _frozen - /// - /// Local helper method to release all local lists locks - /// - private void ReleaseAllLocks() - { - ThreadLocalList currentList = _headList; - while (currentList != null) + // Finally, wait for all unsynchronized operations on each queue to be done. + for (WorkStealingQueue queue = head; queue != null; queue = queue._nextQueue) { - if (currentList._lockTaken) + if (queue._currentOp != (int)Operation.None) { - currentList._lockTaken = false; - Monitor.Exit(currentList); + var spinner = new SpinWait(); + do { spinner.SpinOnce(); } + while (queue._currentOp != (int)Operation.None); } - currentList = currentList._nextList; } } - /// - /// Local helper function to wait all unsynchronized operations - /// - private void WaitAllOperations() + /// "Unfreezes" a bag frozen with . + /// The result of the method. + private void UnfreezeBag(bool lockTaken) { - Debug.Assert(Monitor.IsEntered(GlobalListsLock)); - - ThreadLocalList currentList = _headList; - while (currentList != null) + Debug.Assert(Monitor.IsEntered(GlobalQueuesLock) == lockTaken); + if (lockTaken) { - if (currentList._currentOp != (int)ListOperation.None) + // Release all of the individual queue locks. + for (WorkStealingQueue queue = _workStealingQueues; queue != null; queue = queue._nextQueue) { - SpinWait spinner = new SpinWait(); - while (currentList._currentOp != (int)ListOperation.None) + if (queue._frozen) { - spinner.SpinOnce(); + queue._frozen = false; + Monitor.Exit(queue); } } - currentList = currentList._nextList; + + // Then release the global lock. + Monitor.Exit(GlobalQueuesLock); } } @@ -830,224 +565,360 @@ private void WaitAllOperations() /// The current bag count private int GetCountInternal() { - Debug.Assert(Monitor.IsEntered(GlobalListsLock)); + Debug.Assert(Monitor.IsEntered(GlobalQueuesLock)); int count = 0; - ThreadLocalList currentList = _headList; - while (currentList != null) + for (WorkStealingQueue queue = _workStealingQueues; queue != null; queue = queue._nextQueue) { - checked - { - count += currentList.Count; - } - currentList = currentList._nextList; + checked { count += queue.Count; } } return count; } /// - /// Local helper function to return the bag item in a list, this is mainly used by CopyTo and ToArray + /// Local helper function to return the bag's contents in a list, this is mainly used by CopyTo and ToArray /// This is not thread safe, should be called in Freeze/UnFreeze bag block /// /// List the contains the bag items private List ToList() { - Debug.Assert(Monitor.IsEntered(GlobalListsLock)); + Debug.Assert(Monitor.IsEntered(GlobalQueuesLock)); - List list = new List(); - ThreadLocalList currentList = _headList; - while (currentList != null) + var list = new List(); + + for (WorkStealingQueue queue = _workStealingQueues; queue != null; queue = queue._nextQueue) { - Node currentNode = currentList._head; - while (currentNode != null) - { - list.Add(currentNode._value); - currentNode = currentNode._next; - } - currentList = currentList._nextList; + queue.AddToList(list); } return list; } - #endregion - - - #region Inner Classes - - /// - /// A class that represents a node in the lock thread list - /// - internal class Node + /// Provides a work-stealing queue data structure stored per thread. + private sealed class WorkStealingQueue { - public Node(T value) + /// Initial size of the queue's array. + private const int InitialSize = 32; + /// Starting index for the head and tail indices. + private const int StartIndex = +#if DEBUG + int.MaxValue; // in debug builds, start at the end so we exercise the index reset logic +#else + 0; +#endif + /// Head index from which to steal. This and'd with the is the index into . + private volatile int _headIndex = StartIndex; + /// Tail index at which local pushes/pops happen. This and'd with the is the index into . + private volatile int _tailIndex = StartIndex; + /// The array storing the queue's data. + private volatile T[] _array = new T[InitialSize]; + /// Mask and'd with and to get an index into . + private volatile int _mask = InitialSize - 1; + /// Numbers of elements in the queue from the local perspective; needs to be combined with to get an actual Count. + private int _addTakeCount; + /// Number of steals; needs to be combined with to get an actual Count. + private int _stealCount; + /// The current queue operation. Used to quiesce before performing operations from one thread onto another. + internal volatile int _currentOp; + /// true if this queue's lock is held as part of a global freeze. + internal bool _frozen; + /// Next queue in the 's set of thread-local queues. + internal readonly WorkStealingQueue _nextQueue; + /// Thread ID that owns this queue. + internal readonly int _ownerThreadId; + + /// Initialize the WorkStealingQueue. + /// The next queue in the linked list of work-stealing queues. + internal WorkStealingQueue(WorkStealingQueue nextQueue) { - _value = value; + _ownerThreadId = Environment.CurrentManagedThreadId; + _nextQueue = nextQueue; } - public readonly T _value; - public Node _next; - public Node _prev; - } - /// - /// A class that represents the lock thread list - /// - internal class ThreadLocalList - { - // Head node in the list, null means the list is empty - internal volatile Node _head; + /// Gets whether the queue is empty. + internal bool IsEmpty => _headIndex >= _tailIndex; - // Tail node for the list - private volatile Node _tail; + /// + /// Add new item to the tail of the queue. + /// + /// The item to add. + internal void LocalPush(T item) + { + Debug.Assert(Environment.CurrentManagedThreadId == _ownerThreadId); + bool lockTaken = false; + try + { + // Full fence to ensure subsequent reads don't get reordered before this + Interlocked.Exchange(ref _currentOp, (int)Operation.Add); + int tail = _tailIndex; - // The current list operation - internal volatile int _currentOp; + // Rare corner case (at most once every 2 billion pushes on this thread): + // We're going to increment the tail; if we'll overflow, then we need to reset our counts + if (tail == int.MaxValue) + { + _currentOp = (int)Operation.None; // set back to None temporarily to avoid a deadlock + lock (this) + { + Debug.Assert(_tailIndex == int.MaxValue, "No other thread should be changing _tailIndex"); + + // Rather than resetting to zero, we'll just mask off the bits we don't care about. + // This way we don't need to rearrange the items already in the queue; they'll be found + // correctly exactly where they are. One subtlety here is that we need to make sure that + // if head is currently < tail, it remains that way. This happens to just fall out from + // the bit-masking, because we only do this if tail == int.MaxValue, meaning that all + // bits are set, so all of the bits we're keeping will also be set. Thus it's impossible + // for the head to end up > than the tail, since you can't set any more bits than all of them. + _headIndex = _headIndex & _mask; + _tailIndex = tail = _tailIndex & _mask; + Debug.Assert(_headIndex <= _tailIndex); + + _currentOp = (int)Operation.Add; + } + } - // The list count from the Add/Take perspective - private int _count; + // We'd like to take the fast path that doesn't require locking, if possible. It's not possible if another + // thread is currently requesting that the whole bag synchronize, e.g. a ToArray operation. It's also + // not possible if there are fewer than two spaces available. One space is necessary for obvious reasons: + // to store the element we're trying to push. The other is necessary due to synchronization with steals. + // A stealing thread first increments _headIndex to reserve the slot at its old value, and then tries to + // read from that slot. We could potentially have a race condition whereby _headIndex is incremented just + // before this check, in which case we could overwrite the element being stolen as that slot would appear + // to be empty. Thus, we only allow the fast path if there are two empty slots. + if (!_frozen && tail < (_headIndex + _mask)) + { + _array[tail & _mask] = item; + _tailIndex = tail + 1; + } + else + { + // We need to contend with foreign operations (e.g. steals, enumeration, etc.), so we lock. + _currentOp = (int)Operation.None; // set back to None to avoid a deadlock + Monitor.Enter(this, ref lockTaken); - // The stealing count - internal int _stealCount; + int head = _headIndex; + int count = _tailIndex - _headIndex; - // Next list in the dictionary values - internal volatile ThreadLocalList _nextList; + // If we're full, expand the array. + if (count >= _mask) + { + // Expand the queue by doubling its size. + var newArray = new T[_array.Length << 1]; + int headIdx = head & _mask; + if (headIdx == 0) + { + Array.Copy(_array, 0, newArray, 0, _array.Length); + } + else + { + Array.Copy(_array, headIdx, newArray, 0, _array.Length - headIdx); + Array.Copy(_array, 0, newArray, _array.Length - headIdx, headIdx); + } - // Set if the locl lock is taken - internal bool _lockTaken; + // Reset the field values + _array = newArray; + _headIndex = 0; + _tailIndex = tail = count; + _mask = (_mask << 1) | 1; + } - // The owner thread for this list - internal int _ownerThreadId; + // Add the element + _array[tail & _mask] = item; + _tailIndex = tail + 1; - // the version of the list, incremented only when the list changed from empty to non empty state - internal volatile int _version; + // Update the count to avoid overflow. We can trust _stealCount here, + // as we're inside the lock and it's only manipulated there. + _addTakeCount -= _stealCount; + _stealCount = 0; + } - /// - /// ThreadLocalList constructor - /// - /// The owner thread for this list - internal ThreadLocalList(int ownerThreadId) - { - _ownerThreadId = ownerThreadId; - } - /// - /// Add new item to head of the list - /// - /// The item to add. - /// Whether to update the count. - internal void Add(T item, bool updateCount) - { - checked - { - _count++; - } - Node node = new Node(item); - if (_head == null) - { - Debug.Assert(_tail == null); - _head = node; - _tail = node; - _version++; // changing from empty state to non empty state - } - else - { - node._next = _head; - _head._prev = node; - _head = node; + // Increment the count from the add/take perspective + checked { _addTakeCount++; } } - if (updateCount) // update the count to avoid overflow if this add is synchronized + finally { - _count = _count - _stealCount; - _stealCount = 0; + _currentOp = (int)Operation.None; + if (lockTaken) + { + Monitor.Exit(this); + } } } - /// - /// Remove an item from the head of the list - /// + /// Remove an item from the tail of the queue. /// The removed item - internal void Remove(out T result) + internal bool TryLocalPop(out T result) { - Debug.Assert(_head != null); - Node head = _head; - _head = _head._next; - if (_head != null) + Debug.Assert(Environment.CurrentManagedThreadId == _ownerThreadId); + + int tail = _tailIndex; + if (_headIndex >= tail) + { + result = default(T); + return false; + } + + bool lockTaken = false; + try { - _head._prev = null; + // Decrement the tail using a full fence to ensure subsequent reads don't reorder before this. + // If the read of _headIndex moved before this write to _tailIndex, we could erroneously end up + // popping an element that's concurrently being stolen, leading to the same element being + // dequeued from the bag twice. + _currentOp = (int)Operation.Take; + Interlocked.Exchange(ref _tailIndex, --tail); + + // If there is no interaction with a steal, we can head down the fast path. + // Note that we use _headIndex < tail rather than _headIndex <= tail to account + // for stealing peeks, which don't increment _headIndex, and which could observe + // the written default(T) in a race condition to peek at the element. + if (!_frozen && _headIndex < tail) + { + int idx = tail & _mask; + result = _array[idx]; + _array[idx] = default(T); + _addTakeCount--; + return true; + } + else + { + // Interaction with steals: 0 or 1 elements left. + _currentOp = (int)Operation.None; // set back to None to avoid a deadlock + Monitor.Enter(this, ref lockTaken); + if (_headIndex <= tail) + { + // Element still available. Take it. + int idx = tail & _mask; + result = _array[idx]; + _array[idx] = default(T); + _addTakeCount--; + return true; + } + else + { + // We encountered a race condition and the element was stolen, restore the tail. + _tailIndex = tail + 1; + result = default(T); + return false; + } + } } - else + finally { - _tail = null; + _currentOp = (int)Operation.None; + if (lockTaken) + { + Monitor.Exit(this); + } } - _count--; - result = head._value; } - /// - /// Peek an item from the head of the list - /// + /// Peek an item from the tail of the queue. /// the peeked item /// True if succeeded, false otherwise - internal bool Peek(out T result) + internal bool TryLocalPeek(out T result) { - Node head = _head; - if (head != null) + Debug.Assert(Environment.CurrentManagedThreadId == _ownerThreadId); + + int tail = _tailIndex; + if (_headIndex < tail) { - result = head._value; - return true; + // It is possible to enable lock-free peeks, following the same general approach + // that's used in TryLocalPop. However, peeks are more complicated as we can't + // do the same kind of index reservation that's done in TryLocalPop; doing so could + // end up making a steal think that no item is available, even when one is. To do + // it correctly, then, we'd need to add spinning to TrySteal in case of a concurrent + // peek happening. With a lock, the common case (no contention with steals) will + // effectively only incur two interlocked operations (entering/exiting the lock) instead + // of one (setting Peek as the _currentOp). Combined with Peeks on a bag being rare, + // for now we'll use the simpler/safer code. + lock (this) + { + if (_headIndex < tail) + { + result = _array[(tail - 1) & _mask]; + return true; + } + } } + result = default(T); return false; } - /// - /// Steal an item from the tail of the list - /// + /// Steal an item from the head of the queue. /// the removed item - /// remove or peek flag - internal void Steal(out T result, bool remove) + /// true to take the item; false to simply peek at it + internal bool TrySteal(out T result, bool take) { - Node tail = _tail; - Debug.Assert(tail != null); - if (remove) // Take operation + // Fast-path check to see if the queue is empty. + if (_headIndex < _tailIndex) { - _tail = _tail._prev; - if (_tail != null) + // Anything other than empty requires synchronization. + lock (this) { - _tail._next = null; - } - else - { - _head = null; + int head = _headIndex; + if (take) + { + // Increment head to tentatively take an element: a full fence is used to ensure the read + // of _tailIndex doesn't move earlier, as otherwise we could potentially end up stealing + // the same element that's being popped locally. + Interlocked.Exchange(ref _headIndex, head + 1); + + // If there's an element to steal, do it. + if (head < _tailIndex) + { + int idx = head & _mask; + result = _array[idx]; + _array[idx] = default(T); + _stealCount++; + return true; + } + else + { + // We contended with the local thread and lost the race, so restore the head. + _headIndex = head; + } + } + else if (head < _tailIndex) + { + // Peek, if there's an element available + result = _array[head & _mask]; + return true; + } } - // Increment the steal count - _stealCount++; } - result = tail._value; - } + // The queue was empty. + result = default(T); + return false; + } - /// - /// Gets the total list count, it's not thread safe, may provide incorrect count if it is called concurrently - /// - internal int Count + /// Add the contents of this queue to the specified list. + internal void AddToList(List list) { - get + Debug.Assert(Monitor.IsEntered(this)); + Debug.Assert(_frozen); + + for (int i = _headIndex; i < _tailIndex; i++) { - return _count - _stealCount; + list.Add(_array[i & _mask]); } } - } - #endregion - } - /// - /// List operations for ConcurrentBag - /// - internal enum ListOperation - { - None, - Add, - Take - }; + /// Gets the total number of items in the queue. + /// + /// This is not thread safe, only providing an accurate result either from the owning + /// thread while its lock is held or from any thread while the bag is frozen. + /// + internal int Count => _addTakeCount - _stealCount; + } + /// Lock-free operations performed on a queue. + internal enum Operation + { + None, + Add, + Take + }; + } } diff --git a/src/System.Collections.Concurrent/tests/ConcurrentBagTests.cs b/src/System.Collections.Concurrent/tests/ConcurrentBagTests.cs index 9c3293dc6ace..ac38c7ac4dcd 100644 --- a/src/System.Collections.Concurrent/tests/ConcurrentBagTests.cs +++ b/src/System.Collections.Concurrent/tests/ConcurrentBagTests.cs @@ -21,253 +21,185 @@ public class ConcurrentBagTests : IEnumerable_Generic_Tests protected override EnumerableOrder Order => EnumerableOrder.Unspecified; protected override bool ResetImplemented => true; - [Fact] - public static void TestBasicScenarios() + [Theory] + [InlineData(0)] + [InlineData(1)] + [InlineData(3)] + [InlineData(1000)] + public static void Ctor_InitializeFromCollection_ContainsExpectedItems(int numItems) { - ConcurrentBag cb = new ConcurrentBag(); - Assert.True(cb.IsEmpty); - Task[] tks = new Task[2]; - tks[0] = Task.Run(() => - { - cb.Add(4); - cb.Add(5); - cb.Add(6); - }); - - // Consume the items in the bag - tks[1] = Task.Run(() => - { - int item; - while (!cb.IsEmpty) - { - bool ret = cb.TryTake(out item); - Assert.True(ret); - // loose check - Assert.Contains(item, new[] { 4, 5, 6 }); - } - }); + var expected = new HashSet(Enumerable.Range(0, numItems)); - Task.WaitAll(tks); - } + var bag = new ConcurrentBag(expected); + Assert.Equal(numItems == 0, bag.IsEmpty); + Assert.Equal(expected.Count, bag.Count); - [Fact] - public static void RTest1_Ctor() - { - ConcurrentBag bag = new ConcurrentBag(new int[] { 1, 2, 3 }); - Assert.False(bag.IsEmpty); - Assert.Equal(3, bag.Count); + int item; + var actual = new HashSet(); + for (int i = 0; i < expected.Count; i++) + { + Assert.Equal(expected.Count - i, bag.Count); + Assert.True(bag.TryTake(out item)); + actual.Add(item); + } - Assert.Throws( () => {bag = new ConcurrentBag(null);} ); + Assert.False(bag.TryTake(out item)); + Assert.Equal(0, item); + Assert.True(bag.IsEmpty); + AssertSetsEqual(expected, actual); } [Fact] - public static void RTest2_Add() + public static void Ctor_InvalidArgs_Throws() { - RTest2_Add(1, 10); - RTest2_Add(3, 100); + Assert.Throws("collection", () => new ConcurrentBag(null)); } [Fact] - [OuterLoop] - public static void RTest2_Add01() + public static void Add_TakeFromAnotherThread_ExpectedItemsTaken() { - RTest2_Add(8, 1000); - } + var cb = new ConcurrentBag(); + Assert.True(cb.IsEmpty); + Assert.Equal(0, cb.Count); - [Fact] - public static void RTest3_TakeOrPeek() - { - ConcurrentBag bag = CreateBag(100); - RTest3_TakeOrPeek(bag, 1, 100, true); + const int NumItems = 100000; - bag = CreateBag(100); - RTest3_TakeOrPeek(bag, 4, 10, false); + Task producer = Task.Run(() => Parallel.For(1, NumItems + 1, cb.Add)); - bag = CreateBag(1000); - RTest3_TakeOrPeek(bag, 11, 100, true); - } + var hs = new HashSet(); + while (hs.Count < NumItems) + { + int item; + if (cb.TryTake(out item)) hs.Add(item); + } - [Fact] - public static void RTest4_AddAndTake() - { - RTest4_AddAndTake(8); - RTest4_AddAndTake(16); + producer.GetAwaiter().GetResult(); + + Assert.True(cb.IsEmpty); + Assert.Equal(0, cb.Count); + AssertSetsEqual(new HashSet(Enumerable.Range(1, NumItems)), hs); } - [Fact] - public static void RTest5_CopyTo() + [Theory] + [InlineData(1, 10)] + [InlineData(3, 100)] + [InlineData(8, 1000)] + public static void AddThenPeek_LatestLocalItemRetuned(int threadsCount, int itemsPerThread) { - const int SIZE = 10; - Array array = new int[SIZE]; - int index = 0; + var bag = new ConcurrentBag(); - ConcurrentBag bag = CreateBag(SIZE); - bag.CopyTo((int[])array, index); + using (var b = new Barrier(threadsCount)) + { + WaitAllOrAnyFailed((Enumerable.Range(0, threadsCount).Select(_ => Task.Factory.StartNew(() => + { + b.SignalAndWait(); + for (int i = 1; i < itemsPerThread + 1; i++) + { + bag.Add(i); + int item; + Assert.True(bag.TryPeek(out item)); + Assert.Equal(i, item); + } + }, CancellationToken.None, TaskCreationOptions.LongRunning, TaskScheduler.Default))).ToArray()); + } - Assert.Throws(() => bag.CopyTo(null, index)); - Assert.Throws(() => bag.CopyTo((int[]) array, -1)); - Assert.Throws(() => bag.CopyTo((int[])array, SIZE)); - Assert.Throws(() => bag.CopyTo((int[])array, SIZE-2)); + Assert.Equal(itemsPerThread * threadsCount, bag.Count); } [Fact] - public static void RTest5_ICollectionCopyTo() + public static void AddOnOneThread_PeekOnAnother_EnsureWeCanTakeOnTheOriginal() { - const int SIZE = 10; - Array array = new int[SIZE]; - int index = 0; - - ConcurrentBag bag = CreateBag(SIZE); - ICollection collection = bag as ICollection; - Assert.NotNull(collection); - collection.CopyTo(array, index); + var bag = new ConcurrentBag(Enumerable.Range(1, 5)); - Assert.Throws(() => collection.CopyTo(null, index)); - Assert.Throws(() => collection.CopyTo((int[])array, -1)); - Assert.Throws(() => collection.CopyTo((int[])array, SIZE)); - Assert.Throws(() => collection.CopyTo((int[])array, SIZE - 2)); - - Array array2 = new int[SIZE, 5]; - Assert.Throws(() => collection.CopyTo(array2, 0)); - } + Task.Factory.StartNew(() => + { + int item; + for (int i = 1; i <= 5; i++) + { + Assert.True(bag.TryPeek(out item)); + Assert.Equal(1, item); + } + }, CancellationToken.None, TaskCreationOptions.LongRunning, TaskScheduler.Default).GetAwaiter().GetResult(); - /// - /// Test bag addition - /// - /// - /// - /// True if succeeded, false otherwise - private static void RTest2_Add(int threadsCount, int itemsPerThread) - { - int failures = 0; - ConcurrentBag bag = new ConcurrentBag(); + Assert.Equal(5, bag.Count); - Task[] threads = new Task[threadsCount]; - for (int i = 0; i < threads.Length; i++) + for (int i = 5; i > 0; i--) { - threads[i] = Task.Run(() => - { - for (int j = 0; j < itemsPerThread; j++) - { - try - { - bag.Add(j); - int item; - if (!bag.TryPeek(out item) || item != j) - { - Interlocked.Increment(ref failures); - } - } - catch - { - Interlocked.Increment(ref failures); - } - } - }); - } + int item; - Task.WaitAll(threads); + Assert.True(bag.TryPeek(out item)); + Assert.Equal(i, item); // ordering implementation detail that's not guaranteed - Assert.Equal(0, failures); - Assert.Equal(itemsPerThread * threadsCount, bag.Count); + Assert.Equal(i, bag.Count); + Assert.True(bag.TryTake(out item)); + Assert.Equal(i - 1, bag.Count); + Assert.Equal(i, item); // ordering implementation detail that's not guaranteed + } } - /// - /// Test bag Take and Peek operations - /// - /// - /// - /// - /// - /// True if succeeded, false otherwise - private static void RTest3_TakeOrPeek(ConcurrentBag bag, int threadsCount, int itemsPerThread, bool take) + [Theory] + [InlineData(100, 1, 100, true)] + [InlineData(100, 4, 10, false)] + [InlineData(1000, 11, 100, true)] + [InlineData(100000, 2, 50000, true)] + public static void Initialize_ThenTakeOrPeekInParallel_ItemsObtainedAsExpected(int numStartingItems, int threadsCount, int itemsPerThread, bool take) { - int bagCount = bag.Count; - int succeeded = 0; - int failures = 0; - Task[] threads = new Task[threadsCount]; - for (int i = 0; i < threads.Length; i++) + var bag = new ConcurrentBag(Enumerable.Range(1, numStartingItems)); + int successes = 0; + + using (var b = new Barrier(threadsCount)) { - threads[i] = Task.Run(() => + WaitAllOrAnyFailed(Enumerable.Range(0, threadsCount).Select(threadNum => Task.Factory.StartNew(() => { + b.SignalAndWait(); for (int j = 0; j < itemsPerThread; j++) { int data; - bool result = false; - if (take) - { - result = bag.TryTake(out data); - } - else - { - result = bag.TryPeek(out data); - } - - if (result) - { - Interlocked.Increment(ref succeeded); - } - else + if (take ? bag.TryTake(out data) : bag.TryPeek(out data)) { - Interlocked.Increment(ref failures); + Interlocked.Increment(ref successes); + Assert.NotEqual(0, data); // shouldn't be default(T) } } - }); + }, CancellationToken.None, TaskCreationOptions.LongRunning, TaskScheduler.Default)).ToArray()); } - Task.WaitAll(threads); - - if (take) - { - Assert.Equal(bagCount - succeeded, bag.Count); - } - else - { - Assert.Equal(0, failures); - } + Assert.Equal( + take ? numStartingItems : threadsCount * itemsPerThread, + successes); } - /// - /// Test parallel Add/Take, insert unique elements in the bag, and each element should be removed once - /// - /// - /// True if succeeded, false otherwise - private static void RTest4_AddAndTake(int threadsCount) + [Theory] + [InlineData(8)] + [InlineData(16)] + public static void AddAndTake_ExpectedValuesTransferred(int threadsCount) { - ConcurrentBag bag = new ConcurrentBag(); + var bag = new ConcurrentBag(); - Task[] threads = new Task[threadsCount]; int start = 0; int end = 10; + Task[] threads = new Task[threadsCount]; int[] validation = new int[(end - start) * threads.Length / 2]; for (int i = 0; i < threads.Length; i += 2) { - Interval v = new Interval(start, end); - threads[i] = Task.Factory.StartNew( - (o) => - { - Interval n = (Interval)o; - Add(bag, n.m_start, n.m_end); - }, v, CancellationToken.None, TaskCreationOptions.DenyChildAttach, TaskScheduler.Default); - - threads[i + 1] = Task.Run(() => Take(bag, end - start - 1, validation)); + int localStart = start, localEnd = end; + threads[i] = Task.Run(() => AddRange(bag, localStart, localEnd)); + threads[i + 1] = Task.Run(() => TakeRange(bag, localEnd - localStart - 1, validation)); int step = end - start; start = end; end += step; } - - Task.WaitAll(threads); - - int value = -1; + WaitAllOrAnyFailed(threads); //validation + int value = -1; for (int i = 0; i < validation.Length; i++) { if (validation[i] == 0) { - Assert.True(bag.TryTake(out value), String.Format("Add/Take failed, the list is not empty and TryTake returned false; thread count={0}", threadsCount)); + Assert.True(bag.TryTake(out value)); } else { @@ -275,174 +207,625 @@ private static void RTest4_AddAndTake(int threadsCount) } } - Assert.False(bag.Count > 0 || bag.TryTake(out value), String.Format("Add/Take failed, this list is not empty after all remove operations; thread count={0}", threadsCount)); + Assert.False(bag.Count > 0 || bag.TryTake(out value)); + } + + [Fact] + public static void AddFromMultipleThreads_ItemsRemainAfterThreadsGoAway() + { + var bag = new ConcurrentBag(); + + for (int i = 0; i < 1000; i += 100) + { + // Create a thread that adds items to the bag + Task.Factory.StartNew(() => + { + for (int j = i; j < i + 100; j++) + { + bag.Add(j); + } + }, CancellationToken.None, TaskCreationOptions.LongRunning, TaskScheduler.Default).GetAwaiter().GetResult(); + + // Allow threads to be collected + GC.Collect(); + GC.WaitForPendingFinalizers(); + GC.Collect(); + } + + AssertSetsEqual(new HashSet(Enumerable.Range(0, 1000)), new HashSet(bag)); + } + + [Fact] + public static void AddManyItems_ThenTakeOnSameThread_ItemsOutputInExpectedOrder() + { + var bag = new ConcurrentBag(Enumerable.Range(0, 100000)); + for (int i = 99999; i >= 0; --i) + { + int item; + Assert.True(bag.TryTake(out item)); + Assert.Equal(i, item); // Testing an implementation detail rather than guaranteed ordering + } + } + + [Fact] + public static void AddManyItems_ThenTakeOnDifferentThread_ItemsOutputInExpectedOrder() + { + var bag = new ConcurrentBag(Enumerable.Range(0, 100000)); + Task.Factory.StartNew(() => + { + for (int i = 0; i < 100000; i++) + { + int item; + Assert.True(bag.TryTake(out item)); + Assert.Equal(i, item); // Testing an implementation detail rather than guaranteed ordering + } + }, CancellationToken.None, TaskCreationOptions.LongRunning, TaskScheduler.Default).GetAwaiter().GetResult(); + } + + [Theory] + [InlineData(0)] + [InlineData(1)] + [InlineData(10)] + [InlineData(33)] + public static void IterativelyAddOnOneThreadThenTakeOnAnother_OrderMaintained(int initialCount) + { + var bag = new ConcurrentBag(Enumerable.Range(0, initialCount)); + + const int Iterations = 100; + using (AutoResetEvent itemConsumed = new AutoResetEvent(false), itemProduced = new AutoResetEvent(false)) + { + Task t = Task.Run(() => + { + for (int i = 0; i < Iterations; i++) + { + itemProduced.WaitOne(); + int item; + Assert.True(bag.TryTake(out item)); + Assert.Equal(i, item); // Testing an implementation detail rather than guaranteed ordering + itemConsumed.Set(); + } + }); + + for (int i = initialCount; i < Iterations + initialCount; i++) + { + bag.Add(i); + itemProduced.Set(); + itemConsumed.WaitOne(); + } + + t.GetAwaiter().GetResult(); + } + + Assert.Equal(initialCount, bag.Count); } [Fact] - public static void RTest6_GetEnumerator() + public static void Peek_SucceedsOnEmptyBagThatWasOnceNonEmpty() { - ConcurrentBag bag = new ConcurrentBag(); + var bag = new ConcurrentBag(); + int item; - // Empty bag should not enumerate - Assert.Empty(bag); + Assert.False(bag.TryPeek(out item)); + Assert.Equal(0, item); + + bag.Add(42); + for (int i = 0; i < 2; i++) + { + Assert.True(bag.TryPeek(out item)); + Assert.Equal(42, item); + } + + Assert.True(bag.TryTake(out item)); + Assert.Equal(42, item); + + Assert.False(bag.TryPeek(out item)); + Assert.Equal(0, item); + } + + [Fact] + public static void CopyTo_Empty_NothingCopied() + { + var bag = new ConcurrentBag(); + bag.CopyTo(new int[0], 0); + bag.CopyTo(new int[10], 10); + } + + [Fact] + public static void CopyTo_ExpectedElementsCopied() + { + const int Size = 10; + int[] dest; + + // Copy to array in which data fits perfectly + dest = new int[Size]; + var bag = new ConcurrentBag(Enumerable.Range(0, Size)); + bag.CopyTo(dest, 0); + Assert.Equal(Enumerable.Range(0, Size), dest.OrderBy(i => i)); + + // Copy to non-0 index in array where the data fits + dest = new int[Size + 3]; + bag = new ConcurrentBag(Enumerable.Range(0, Size)); + bag.CopyTo(dest, 1); + var results = new int[Size]; + Array.Copy(dest, 1, results, 0, results.Length); + Assert.Equal(Enumerable.Range(0, Size), results.OrderBy(i => i)); + } + + [Fact] + public static void CopyTo_InvalidArgs_Throws() + { + var bag = new ConcurrentBag(Enumerable.Range(0, 10)); + int[] dest = new int[10]; + + Assert.Throws("array", () => bag.CopyTo(null, 0)); + Assert.Throws("index", () => bag.CopyTo(dest, -1)); + Assert.Throws(() => bag.CopyTo(dest, dest.Length)); + Assert.Throws(() => bag.CopyTo(dest, dest.Length - 2)); + } + + [Fact] + public static void ICollectionCopyTo_ExpectedElementsCopied() + { + const int Size = 10; + int[] dest; + + // Copy to array in which data fits perfectly + dest = new int[Size]; + ICollection c = new ConcurrentBag(Enumerable.Range(0, Size)); + c.CopyTo(dest, 0); + Assert.Equal(Enumerable.Range(0, Size), dest.OrderBy(i => i)); + + // Copy to non-0 index in array where the data fits + dest = new int[Size + 3]; + c = new ConcurrentBag(Enumerable.Range(0, Size)); + c.CopyTo(dest, 1); + var results = new int[Size]; + Array.Copy(dest, 1, results, 0, results.Length); + Assert.Equal(Enumerable.Range(0, Size), results.OrderBy(i => i)); + } + + [Fact] + public static void ICollectionCopyTo_InvalidArgs_Throws() + { + ICollection bag = new ConcurrentBag(Enumerable.Range(0, 10)); + Array dest = new int[10]; - for (int i = 0; i < 100; i++) + Assert.Throws("array", () => bag.CopyTo(null, 0)); + Assert.Throws("dstIndex", () => bag.CopyTo(dest, -1)); + Assert.Throws(() => bag.CopyTo(dest, dest.Length)); + Assert.Throws(() => bag.CopyTo(dest, dest.Length - 2)); + } + + [Theory] + [InlineData(0, true)] + [InlineData(1, true)] + [InlineData(1, false)] + [InlineData(10, true)] + [InlineData(100, true)] + [InlineData(100, false)] + public static async Task GetEnumerator_Generic_ExpectedElementsYielded(int numItems, bool consumeFromSameThread) + { + var bag = new ConcurrentBag(); + using (var e = bag.GetEnumerator()) + { + Assert.False(e.MoveNext()); + } + + // Add, and validate enumeration after each item added + for (int i = 1; i <= numItems; i++) { bag.Add(i); + Assert.Equal(i, bag.Count); + Assert.Equal(i, bag.Distinct().Count()); } - int count = 0; - foreach (int x in bag) + // Take, and validate enumerate after each item removed. + Action consume = () => { - count++; + for (int i = 1; i <= numItems; i++) + { + int item; + Assert.True(bag.TryTake(out item)); + Assert.Equal(numItems - i, bag.Count); + Assert.Equal(numItems - i, bag.Distinct().Count()); + } + }; + if (consumeFromSameThread) + { + consume(); } + else + { + await Task.Factory.StartNew(consume, CancellationToken.None, TaskCreationOptions.LongRunning, TaskScheduler.Default); + } + } + + [Fact] + public static void GetEnumerator_NonGeneric() + { + var bag = new ConcurrentBag(); + var ebag = (IEnumerable)bag; + + Assert.False(ebag.GetEnumerator().MoveNext()); - Assert.Equal(count, bag.Count); + bag.Add(42); + bag.Add(84); + + var hs = new HashSet(ebag.Cast()); + Assert.Equal(2, hs.Count); + Assert.Contains(42, hs); + Assert.Contains(84, hs); } [Fact] - public static void RTest7_BugFix575975() + public static void GetEnumerator_EnumerationsAreSnapshots() { - BlockingCollection bc = new BlockingCollection(new ConcurrentBag()); - bool succeeded = true; - Task[] threads = new Task[4]; - for (int t = 0; t < threads.Length; t++) + var bag = new ConcurrentBag(); + Assert.Empty(bag); + + using (IEnumerator e1 = bag.GetEnumerator()) { - threads[t] = Task.Factory.StartNew((obj) => + bag.Add(1); + using (IEnumerator e2 = bag.GetEnumerator()) { - int index = (int)obj; - for (int i = 0; i < 100000; i++) + bag.Add(2); + using (IEnumerator e3 = bag.GetEnumerator()) { - if (index < threads.Length / 2) - { - int k = 0; - for (int j = 0; j < 1000; j++) - k++; - bc.Add(i); - } - else + int item; + Assert.True(bag.TryTake(out item)); + using (IEnumerator e4 = bag.GetEnumerator()) { - try - { - bc.Take(); - } - catch // Take must not fail - { - succeeded = false; - break; - } + Assert.False(e1.MoveNext()); + + Assert.True(e2.MoveNext()); + Assert.False(e2.MoveNext()); + + Assert.True(e3.MoveNext()); + Assert.True(e3.MoveNext()); + Assert.False(e3.MoveNext()); + + Assert.True(e4.MoveNext()); + Assert.False(e4.MoveNext()); } } - - }, t, CancellationToken.None, TaskCreationOptions.DenyChildAttach, TaskScheduler.Default); + } } + } + + [Theory] + [InlineData(100, 1, 10)] + [InlineData(4, 100000, 10)] + public static void BlockingCollection_WrappingBag_ExpectedElementsTransferred(int numThreadsPerConsumerProducer, int numItemsPerThread, int producerSpin) + { + var bc = new BlockingCollection(new ConcurrentBag()); + long dummy = 0; + + Task[] producers = Enumerable.Range(0, numThreadsPerConsumerProducer).Select(_ => Task.Factory.StartNew(() => + { + for (int i = 1; i <= numItemsPerThread; i++) + { + for (int j = 0; j < producerSpin; j++) dummy *= j; // spin a little + bc.Add(i); + } + }, CancellationToken.None, TaskCreationOptions.LongRunning, TaskScheduler.Default)).ToArray(); + + Task[] consumers = Enumerable.Range(0, numThreadsPerConsumerProducer).Select(_ => Task.Factory.StartNew(() => + { + for (int i = 0; i < numItemsPerThread; i++) + { + const int TimeoutMs = 100000; + int item; + Assert.True(bc.TryTake(out item, TimeoutMs), $"Couldn't get {i}th item after {TimeoutMs}ms"); + Assert.NotEqual(0, item); + } + }, CancellationToken.None, TaskCreationOptions.LongRunning, TaskScheduler.Default)).ToArray(); - Task.WaitAll(threads); - Assert.True(succeeded); + WaitAllOrAnyFailed(producers); + WaitAllOrAnyFailed(consumers); } [Fact] - public static void RTest8_Interfaces() - { - ConcurrentBag bag = new ConcurrentBag(); - //IPCC - IProducerConsumerCollection ipcc = bag as IProducerConsumerCollection; - Assert.False(ipcc == null, "RTest8_Interfaces: ConcurrentBag doesn't implement IPCC"); - Assert.True(ipcc.TryAdd(1), "RTest8_Interfaces: IPCC.TryAdd failed"); - Assert.Equal(1, bag.Count); - - int result = -1; - Assert.True(ipcc.TryTake(out result), "RTest8_Interfaces: IPCC.TryTake failed"); - Assert.True(1 == result, "RTest8_Interfaces: IPCC.TryTake failed"); - Assert.Equal(0, bag.Count); + public static void IProducerConsumerCollection_TryAdd_TryTake_ToArray() + { + IProducerConsumerCollection bag = new ConcurrentBag(); - //ICollection - ICollection collection = bag as ICollection; - Assert.False(collection == null, "RTest8_Interfaces: ConcurrentBag doesn't implement ICollection"); - Assert.False(collection.IsSynchronized, "RTest8_Interfaces: IsSynchronized returned true"); + Assert.True(bag.TryAdd(42)); + Assert.Equal(new[] { 42 }, bag.ToArray()); - //IEnumerable - IEnumerable enumerable = bag as IEnumerable; - Assert.False(enumerable == null, "RTest8_Interfaces: ConcurrentBag doesn't implement IEnumerable"); - // Empty bag shouldn't enumerate. - Assert.Empty(enumerable); + Assert.True(bag.TryAdd(84)); + Assert.Equal(new[] { 42, 84 }, bag.ToArray().OrderBy(i => i)); + + int item; + + Assert.True(bag.TryTake(out item)); + int remainingItem = item == 42 ? 84 : 42; + Assert.Equal(new[] { remainingItem }, bag.ToArray()); + Assert.True(bag.TryTake(out item)); + Assert.Equal(remainingItem, item); + Assert.Empty(bag.ToArray()); } [Fact] - public static void RTest8_Interfaces_Negative() + public static void ICollection_IsSynchronized_SyncRoot() { - ConcurrentBag bag = new ConcurrentBag(); - //IPCC - IProducerConsumerCollection ipcc = bag as IProducerConsumerCollection; - ICollection collection = bag as ICollection; - Assert.Throws(() => { object obj = collection.SyncRoot; }); + ICollection bag = new ConcurrentBag(); + Assert.False(bag.IsSynchronized); + Assert.Throws(() => bag.SyncRoot); } [Fact] - public static void RTest9_ToArray() + public static void ToArray_ParallelInvocations_Succeed() { var bag = new ConcurrentBag(); + Assert.Empty(bag.ToArray()); + + const int NumItems = 10000; + + Parallel.For(0, NumItems, bag.Add); + Assert.Equal(NumItems, bag.Count); + + Parallel.For(0, 10, i => + { + var hs = new HashSet(bag.ToArray()); + Assert.Equal(NumItems, hs.Count); + }); + } + + [OuterLoop("Runs for several seconds")] + [Fact] + public static void ManyConcurrentAddsTakes_BagRemainsConsistent_LongRunning() => + ManyConcurrentAddsTakes_BagRemainsConsistent(3.0); + + [Theory] + [InlineData(0.5)] + public static void ManyConcurrentAddsTakes_BagRemainsConsistent(double seconds) + { + var bag = new ConcurrentBag(); + + DateTime end = DateTime.UtcNow + TimeSpan.FromSeconds(seconds); + + // Thread that adds + Task> adds = Task.Run(() => + { + var added = new HashSet(); + long i = long.MinValue; + while (DateTime.UtcNow < end) + { + i++; + bag.Add(i); + added.Add(i); + } + return added; + }); + + // Thread that adds and takes + Task,HashSet>> addsAndTakes = Task.Run(() => + { + var added = new HashSet(); + var taken = new HashSet(); + + long i = 1; // avoid 0 as default(T), to detect accidentally reading a default value + while (DateTime.UtcNow < end) + { + i++; + bag.Add(i); + added.Add(i); - Assert.NotNull(bag.ToArray()); - Assert.Equal(0, bag.ToArray().Length); + long item; + if (bag.TryTake(out item)) + { + Assert.NotEqual(0, item); + taken.Add(item); + } + } - int[] allItems = new int[10000]; - for (int i = 0; i < allItems.Length; i++) - allItems[i] = i; + return new KeyValuePair, HashSet>(added, taken); + }); - bag = new ConcurrentBag(allItems); - int failCount = 0; - Task[] tasks = new Task[10]; - for (int i = 0; i < tasks.Length; i++) + // Thread that just takes + Task> takes = Task.Run(() => { - tasks[i] = Task.Run(() => + var taken = new HashSet(); + while (DateTime.UtcNow < end) + { + long item; + if (bag.TryTake(out item)) { - int[] array = bag.ToArray(); - if (array == null || array.Length != 10000) - Interlocked.Increment(ref failCount); - }); - } + Assert.NotEqual(0, item); + taken.Add(item); + } + } + return taken; + }); + + // Wait for them all to finish + WaitAllOrAnyFailed(adds, addsAndTakes, takes); - Task.WaitAll(tasks); - Assert.True(0 == failCount, "RTest9_ToArray: One or more thread failed to get the correct bag items from ToArray"); + // Combine everything they added and remove everything they took + var total = new HashSet(adds.Result); + total.UnionWith(addsAndTakes.Result.Key); + total.ExceptWith(addsAndTakes.Result.Value); + total.ExceptWith(takes.Result); + + // What's left should match what's in the bag + Assert.Equal(total.OrderBy(i => i), bag.OrderBy(i => i)); } + [OuterLoop("Runs for several seconds")] [Fact] - public static void RTest10_DebuggerAttributes() + public static void ManyConcurrentAddsTakesPeeks_ForceContentionWithSteals_LongRunning() => + ManyConcurrentAddsTakesPeeks_ForceContentionWithSteals(3.0); + + [Theory] + [InlineData(0.5)] + public static void ManyConcurrentAddsTakesPeeks_ForceContentionWithSteals(double seconds) { - DebuggerAttributes.ValidateDebuggerDisplayReferences(new ConcurrentBag()); - DebuggerAttributes.ValidateDebuggerTypeProxyProperties(new ConcurrentBag()); + var bag = new ConcurrentBag(); + const int MaxCount = 4; + + DateTime end = DateTime.UtcNow + TimeSpan.FromSeconds(seconds); + + Task addsTakes = Task.Factory.StartNew(() => + { + long total = 0; + while (DateTime.UtcNow < end) + { + for (int i = 1; i <= MaxCount; i++) + { + bag.Add(i); + total++; + } + + int item; + if (bag.TryPeek(out item)) + { + Assert.InRange(item, 1, MaxCount); + } + + for (int i = 1; i <= MaxCount; i++) + { + if (bag.TryTake(out item)) + { + total--; + Assert.InRange(item, 1, MaxCount); + } + } + } + return total; + }, CancellationToken.None, TaskCreationOptions.LongRunning, TaskScheduler.Default); + + Task steals = Task.Factory.StartNew(() => + { + long total = 0; + int item; + while (DateTime.UtcNow < end) + { + if (bag.TryTake(out item)) + { + total++; + Assert.InRange(item, 1, MaxCount); + } + } + return total; + }, CancellationToken.None, TaskCreationOptions.LongRunning, TaskScheduler.Default); + + WaitAllOrAnyFailed(addsTakes, steals); + long remaining = addsTakes.Result - steals.Result; + Assert.InRange(remaining, 0, long.MaxValue); + Assert.Equal(remaining, bag.Count); } - #region Helper Methods / Classes + [OuterLoop("Runs for several seconds")] + [Fact] + public static void ManyConcurrentAddsTakesPeeks_ForceContentionWithStealingPeeks_LongRunning() => + ManyConcurrentAddsTakesPeeks_ForceContentionWithStealingPeeks(3.0); - private struct Interval + [Theory] + [InlineData(0.5)] + public static void ManyConcurrentAddsTakesPeeks_ForceContentionWithStealingPeeks(double seconds) { - public Interval(int start, int end) + var bag = new ConcurrentBag(); + const int MaxCount = 4; + + DateTime end = DateTime.UtcNow + TimeSpan.FromSeconds(seconds); + + Task addsTakes = Task.Factory.StartNew(() => { - m_start = start; - m_end = end; - } - internal int m_start; - internal int m_end; + long total = 0; + while (DateTime.UtcNow < end) + { + for (int i = 1; i <= MaxCount; i++) + { + bag.Add(i); + total++; + } + + int item; + Assert.True(bag.TryPeek(out item)); + Assert.Equal(MaxCount, item); + + for (int i = 1; i <= MaxCount; i++) + { + if (bag.TryTake(out item)) + { + total--; + Assert.InRange(item, 1, MaxCount); + } + } + } + return total; + }, CancellationToken.None, TaskCreationOptions.LongRunning, TaskScheduler.Default); + + Task steals = Task.Factory.StartNew(() => + { + int item; + while (DateTime.UtcNow < end) + { + if (bag.TryPeek(out item)) + { + Assert.InRange(item, 1, MaxCount); + } + } + }, CancellationToken.None, TaskCreationOptions.LongRunning, TaskScheduler.Default); + + WaitAllOrAnyFailed(addsTakes, steals); + Assert.Equal(0, addsTakes.Result); + Assert.Equal(0, bag.Count); } - /// - /// Create a ComcurrentBag object - /// - /// number of the elements in the bag - /// The bag object - private static ConcurrentBag CreateBag(int numbers) + [OuterLoop("Runs for several seconds")] + [Fact] + public static void ManyConcurrentAddsTakes_ForceContentionWithFreezing_LongRunning() => + ManyConcurrentAddsTakes_ForceContentionWithFreezing(3.0); + + [Theory] + [InlineData(0.5)] + public static void ManyConcurrentAddsTakes_ForceContentionWithFreezing(double seconds) { - ConcurrentBag bag = new ConcurrentBag(); - for (int i = 0; i < numbers; i++) + var bag = new ConcurrentBag(); + const int MaxCount = 4; + + DateTime end = DateTime.UtcNow + TimeSpan.FromSeconds(seconds); + + Task addsTakes = Task.Factory.StartNew(() => { - bag.Add(i); + while (DateTime.UtcNow < end) + { + for (int i = 1; i <= MaxCount; i++) + { + bag.Add(i); + } + for (int i = 1; i <= MaxCount; i++) + { + int item; + Assert.True(bag.TryTake(out item)); + Assert.InRange(item, 1, MaxCount); + } + } + }, CancellationToken.None, TaskCreationOptions.LongRunning, TaskScheduler.Default); + + while (DateTime.UtcNow < end) + { + int[] arr = bag.ToArray(); + Assert.InRange(arr.Length, 0, MaxCount); + Assert.DoesNotContain(0, arr); // make sure we didn't get default(T) } - return bag; + + addsTakes.GetAwaiter().GetResult(); + Assert.Equal(0, bag.Count); + } + + [Fact] + public static void ValidateDebuggerAttributes() + { + DebuggerAttributes.ValidateDebuggerDisplayReferences(new ConcurrentBag()); + DebuggerAttributes.ValidateDebuggerTypeProxyProperties(new ConcurrentBag()); + + DebuggerAttributes.ValidateDebuggerDisplayReferences(new ConcurrentBag(Enumerable.Range(0, 10))); + DebuggerAttributes.ValidateDebuggerTypeProxyProperties(new ConcurrentBag(Enumerable.Range(0, 10))); } - private static void Add(ConcurrentBag bag, int start, int end) + private static void AddRange(ConcurrentBag bag, int start, int end) { for (int i = start; i < end; i++) { @@ -450,19 +833,56 @@ private static void Add(ConcurrentBag bag, int start, int end) } } - private static void Take(ConcurrentBag bag, int count, int[] validation) + private static void TakeRange(ConcurrentBag bag, int count, int[] validation) { for (int i = 0; i < count; i++) { - int value = -1; - - if (bag.TryTake(out value) && validation != null) + int value; + if (bag.TryTake(out value)) { Interlocked.Increment(ref validation[value]); } } } - #endregion + private static void AssertSetsEqual(HashSet expected, HashSet actual) + { + Assert.Equal(expected.Count, actual.Count); + Assert.Subset(expected, actual); + Assert.Subset(actual, expected); + } + + private static void WaitAllOrAnyFailed(params Task[] tasks) + { + if (tasks.Length == 0) + { + return; + } + + int remaining = tasks.Length; + var mres = new ManualResetEventSlim(); + + foreach (Task task in tasks) + { + task.ContinueWith(t => + { + if (Interlocked.Decrement(ref remaining) == 0 || t.IsFaulted) + { + mres.Set(); + } + }, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); + } + + mres.Wait(); + + // Either all tasks are completed or at least one failed + foreach (Task t in tasks) + { + if (t.IsFaulted) + { + t.GetAwaiter().GetResult(); // propagate for the first one that failed + } + } + } } }