Fixing potential race condition in ListRepository. Now internally implemented as a concurrent dictionary.

This commit is contained in:
ryanbodrug-microsoft 2020-07-08 15:12:36 -07:00
parent 2c45956030
commit 9ff8246a9d
3 changed files with 121 additions and 20 deletions

View File

@ -3,7 +3,10 @@ using Moq;
using NUnit.Framework; using NUnit.Framework;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq;
using System.Text; using System.Text;
using System.Threading.Tasks;
using Windows.Media.Capture;
using Wox.Infrastructure.Storage; using Wox.Infrastructure.Storage;
namespace Microsoft.Plugin.Program.UnitTests.Storage namespace Microsoft.Plugin.Program.UnitTests.Storage
@ -17,8 +20,7 @@ namespace Microsoft.Plugin.Program.UnitTests.Storage
{ {
//Arrange //Arrange
var itemName = "originalItem1"; var itemName = "originalItem1";
var mockStorage = new Mock<IStorage<IList<string>>>(); IRepository<string> repository = new ListRepository<string>() { itemName };
IRepository<string> repository = new ListRepository<string>(mockStorage.Object) { itemName };
//Act //Act
var result = repository.Contains(itemName); var result = repository.Contains(itemName);
@ -31,8 +33,7 @@ namespace Microsoft.Plugin.Program.UnitTests.Storage
public void Contains_ShouldReturnTrue_WhenListIsUpdatedWithAdd() public void Contains_ShouldReturnTrue_WhenListIsUpdatedWithAdd()
{ {
//Arrange //Arrange
var mockStorage = new Mock<IStorage<IList<string>>>(); IRepository<string> repository = new ListRepository<string>();
IRepository<string> repository = new ListRepository<string>(mockStorage.Object);
//Act //Act
var itemName = "newItem"; var itemName = "newItem";
@ -48,8 +49,7 @@ namespace Microsoft.Plugin.Program.UnitTests.Storage
{ {
//Arrange //Arrange
var itemName = "originalItem1"; var itemName = "originalItem1";
var mockStorage = new Mock<IStorage<IList<string>>>(); IRepository<string> repository = new ListRepository<string>() { itemName };
IRepository<string> repository = new ListRepository<string>(mockStorage.Object) { itemName };
//Act //Act
repository.Remove(itemName); repository.Remove(itemName);
@ -58,5 +58,91 @@ namespace Microsoft.Plugin.Program.UnitTests.Storage
//Assert //Assert
Assert.IsFalse(result); Assert.IsFalse(result);
} }
[Test]
public async Task Add_ShouldNotThrow_WhenBeingIterated()
{
//Arrange
ListRepository<string> repository = new ListRepository<string>();
var numItems = 1000;
for(var i=0; i<numItems;++i)
{
repository.Add($"OriginalItem_{i}");
}
//Act - Begin iterating on one thread
var iterationTask = Task.Run(() =>
{
var remainingIterations = 10000;
while (remainingIterations > 0)
{
foreach (var item in repository)
{
//keep iterating
}
--remainingIterations;
}
});
//Act - Insert on another thread
var addTask = Task.Run(() =>
{
for (var i = 0; i < numItems; ++i)
{
repository.Add($"NewItem_{i}");
}
});
//Assert that this does not throw. Collections that aren't syncronized will throw an invalidoperatioexception if the list is modified while enumerating
Assert.DoesNotThrowAsync(async () =>
{
await Task.WhenAll(new Task[] { iterationTask, addTask });
});
}
[Test]
public async Task Remove_ShouldNotThrow_WhenBeingIterated()
{
//Arrange
ListRepository<string> repository = new ListRepository<string>();
var numItems = 1000;
for (var i = 0; i < numItems; ++i)
{
repository.Add($"OriginalItem_{i}");
}
//Act - Begin iterating on one thread
var iterationTask = Task.Run(() =>
{
var remainingIterations = 10000;
while (remainingIterations > 0)
{
foreach (var item in repository)
{
//keep iterating
}
--remainingIterations;
}
});
//Act - Remove on another thread
var addTask = Task.Run(() =>
{
for (var i = 0; i < numItems; ++i)
{
repository.Remove($"OriginalItem_{i}");
}
});
//Assert that this does not throw. Collections that aren't syncronized will throw an invalidoperatioexception if the list is modified while enumerating
Assert.DoesNotThrowAsync(async () =>
{
await Task.WhenAll(new Task[] { iterationTask, addTask });
});
}
} }
} }

View File

@ -16,9 +16,12 @@ namespace Microsoft.Plugin.Program.Storage
/// </summary> /// </summary>
internal class PackageRepository : ListRepository<UWP.Application>, IRepository<UWP.Application>, IProgramRepository internal class PackageRepository : ListRepository<UWP.Application>, IRepository<UWP.Application>, IProgramRepository
{ {
IPackageCatalog _packageCatalog; private IStorage<IList<UWP.Application>> _storage;
public PackageRepository(IPackageCatalog packageCatalog, IStorage<IList<UWP.Application>> storage) : base(storage)
private IPackageCatalog _packageCatalog;
public PackageRepository(IPackageCatalog packageCatalog, IStorage<IList<UWP.Application>> storage)
{ {
_storage = storage ?? throw new ArgumentNullException("storage", "StorageRepository requires an initialized storage interface");
_packageCatalog = packageCatalog ?? throw new ArgumentNullException("packageCatalog", "PackageRepository expects an interface to be able to subscribe to package events"); _packageCatalog = packageCatalog ?? throw new ArgumentNullException("packageCatalog", "PackageRepository expects an interface to be able to subscribe to package events");
_packageCatalog.PackageInstalling += OnPackageInstalling; _packageCatalog.PackageInstalling += OnPackageInstalling;
_packageCatalog.PackageUninstalling += OnPackageUninstalling; _packageCatalog.PackageUninstalling += OnPackageUninstalling;
@ -55,7 +58,7 @@ namespace Microsoft.Plugin.Program.Storage
{ {
//find apps associated with this package. //find apps associated with this package.
var uwp = new UWP(args.Package); var uwp = new UWP(args.Package);
var apps = _items.Where(a => a.Package.Equals(uwp)).ToArray(); var apps = Items.Where(a => a.Package.Equals(uwp)).ToArray();
foreach (var app in apps) foreach (var app in apps)
{ {
Remove(app); Remove(app);
@ -74,7 +77,7 @@ namespace Microsoft.Plugin.Program.Storage
public void Save() public void Save()
{ {
_storage.Save(_items); _storage.Save(Items);
} }
public void Load() public void Load()

View File

@ -1,11 +1,14 @@
using NLog.Filters; using NLog.Filters;
using System; using System;
using System.Collections; using System.Collections;
using System.Collections.Concurrent;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Text; using System.Text;
using System.Threading.Tasks; using System.Threading.Tasks;
using System.Windows.Controls.Primitives;
using Wox.Infrastructure; using Wox.Infrastructure;
using Wox.Infrastructure.Logger;
namespace Wox.Infrastructure.Storage namespace Wox.Infrastructure.Storage
{ {
@ -16,18 +19,19 @@ namespace Wox.Infrastructure.Storage
/// <typeparam name="T"></typeparam> /// <typeparam name="T"></typeparam>
public class ListRepository<T> : IRepository<T>, IEnumerable<T> public class ListRepository<T> : IRepository<T>, IEnumerable<T>
{ {
protected IList<T> _items = new List<T>(); public IList<T> Items { get { return _items.Values.ToList(); } }
protected IStorage<IList<T>> _storage;
public ListRepository(IStorage<IList<T>> storage) private ConcurrentDictionary<int, T> _items = new ConcurrentDictionary<int, T>();
public ListRepository()
{ {
_storage = storage ?? throw new ArgumentNullException("storage", "StorageRepository requires an initialized storage interface");
} }
public void Set(IList<T> items) public void Set(IList<T> items)
{ {
//enforce that internal representation //enforce that internal representation
_items = items.ToList<T>(); _items = new ConcurrentDictionary<int, T>(items.ToDictionary( i => i.GetHashCode()));
} }
public bool Any() public bool Any()
@ -37,27 +41,35 @@ namespace Wox.Infrastructure.Storage
public void Add(T insertedItem) public void Add(T insertedItem)
{ {
_items.Add(insertedItem); if (!_items.TryAdd(insertedItem.GetHashCode(), insertedItem))
{
Log.Error($"|ListRepository.Add| Item Already Exists <{insertedItem}>");
}
} }
public void Remove(T removedItem) public void Remove(T removedItem)
{ {
_items.Remove(removedItem);
if (!_items.TryRemove(removedItem.GetHashCode(), out _))
{
Log.Error($"|ListRepository.Remove| Item Not Found <{removedItem}>");
}
} }
public ParallelQuery<T> AsParallel() public ParallelQuery<T> AsParallel()
{ {
return _items.AsParallel(); return _items.Values.AsParallel();
} }
public bool Contains(T item) public bool Contains(T item)
{ {
return _items.Contains(item); return _items.ContainsKey(item.GetHashCode());
} }
public IEnumerator<T> GetEnumerator() public IEnumerator<T> GetEnumerator()
{ {
return _items.GetEnumerator(); return _items.Values.GetEnumerator();
} }
IEnumerator IEnumerable.GetEnumerator() IEnumerator IEnumerable.GetEnumerator()