Use custom TaskScheduler for Framework.RunOnTick (#1597)

* Use custom TaskScheduler for Framework.RunOnTick

* TaskSchedulerWidget: add example
This commit is contained in:
srkizer 2024-03-14 08:36:38 +09:00 committed by GitHub
parent 666feede4c
commit a26bb58fdb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 353 additions and 171 deletions

View file

@ -1,3 +1,4 @@
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
@ -41,11 +42,13 @@ internal sealed class Framework : IDisposable, IServiceType, IFramework
[ServiceManager.ServiceDependency]
private readonly DalamudConfiguration configuration = Service<DalamudConfiguration>.Get();
private readonly object runOnNextTickTaskListSync = new();
private List<RunOnNextTickTaskBase> runOnNextTickTaskList = new();
private List<RunOnNextTickTaskBase> runOnNextTickTaskList2 = new();
private readonly CancellationTokenSource frameworkDestroy;
private readonly ThreadBoundTaskScheduler frameworkThreadTaskScheduler;
private Thread? frameworkUpdateThread;
private readonly ConcurrentDictionary<TaskCompletionSource, (ulong Expire, CancellationToken CancellationToken)>
tickDelayedTaskCompletionSources = new();
private ulong tickCounter;
[ServiceManager.ServiceConstructor]
private Framework(TargetSigScanner sigScanner, GameLifecycle lifecycle)
@ -56,6 +59,14 @@ internal sealed class Framework : IDisposable, IServiceType, IFramework
this.addressResolver = new FrameworkAddressResolver();
this.addressResolver.Setup(sigScanner);
this.frameworkDestroy = new();
this.frameworkThreadTaskScheduler = new();
this.FrameworkThreadTaskFactory = new(
this.frameworkDestroy.Token,
TaskCreationOptions.None,
TaskContinuationOptions.None,
this.frameworkThreadTaskScheduler);
this.updateHook = Hook<OnUpdateDetour>.FromAddress(this.addressResolver.TickAddress, this.HandleFrameworkUpdate);
this.destroyHook = Hook<OnRealDestroyDelegate>.FromAddress(this.addressResolver.DestroyAddress, this.HandleFrameworkDestroy);
@ -92,14 +103,17 @@ internal sealed class Framework : IDisposable, IServiceType, IFramework
/// <inheritdoc/>
public DateTime LastUpdateUTC { get; private set; } = DateTime.MinValue;
/// <inheritdoc/>
public TaskFactory FrameworkThreadTaskFactory { get; }
/// <inheritdoc/>
public TimeSpan UpdateDelta { get; private set; } = TimeSpan.Zero;
/// <inheritdoc/>
public bool IsInFrameworkUpdateThread => Thread.CurrentThread == this.frameworkUpdateThread;
public bool IsInFrameworkUpdateThread => this.frameworkThreadTaskScheduler.IsOnBoundThread;
/// <inheritdoc/>
public bool IsFrameworkUnloading { get; internal set; }
public bool IsFrameworkUnloading => this.frameworkDestroy.IsCancellationRequested;
/// <summary>
/// Gets the list of update sub-delegates that didn't get updated this frame.
@ -111,6 +125,19 @@ internal sealed class Framework : IDisposable, IServiceType, IFramework
/// </summary>
internal bool DispatchUpdateEvents { get; set; } = true;
/// <inheritdoc/>
public Task DelayTicks(long numTicks, CancellationToken cancellationToken = default)
{
if (this.frameworkDestroy.IsCancellationRequested)
return Task.FromCanceled(this.frameworkDestroy.Token);
if (numTicks <= 0)
return Task.CompletedTask;
var tcs = new TaskCompletionSource();
this.tickDelayedTaskCompletionSources[tcs] = (this.tickCounter + (ulong)numTicks, cancellationToken);
return tcs.Task;
}
/// <inheritdoc/>
public Task<T> RunOnFrameworkThread<T>(Func<T> func) =>
this.IsInFrameworkUpdateThread || this.IsFrameworkUnloading ? Task.FromResult(func()) : this.RunOnTick(func);
@ -157,20 +184,16 @@ internal sealed class Framework : IDisposable, IServiceType, IFramework
return Task.FromCanceled<T>(cts.Token);
}
var tcs = new TaskCompletionSource<T>();
lock (this.runOnNextTickTaskListSync)
{
this.runOnNextTickTaskList.Add(new RunOnNextTickTaskFunc<T>()
if (cancellationToken == default)
cancellationToken = this.FrameworkThreadTaskFactory.CancellationToken;
return this.FrameworkThreadTaskFactory.ContinueWhenAll(
new[]
{
RemainingTicks = delayTicks,
RunAfterTickCount = Environment.TickCount64 + (long)Math.Ceiling(delay.TotalMilliseconds),
CancellationToken = cancellationToken,
TaskCompletionSource = tcs,
Func = func,
});
}
return tcs.Task;
Task.Delay(delay, cancellationToken),
this.DelayTicks(delayTicks, cancellationToken),
},
_ => func(),
cancellationToken);
}
/// <inheritdoc/>
@ -186,20 +209,16 @@ internal sealed class Framework : IDisposable, IServiceType, IFramework
return Task.FromCanceled(cts.Token);
}
var tcs = new TaskCompletionSource();
lock (this.runOnNextTickTaskListSync)
{
this.runOnNextTickTaskList.Add(new RunOnNextTickTaskAction()
if (cancellationToken == default)
cancellationToken = this.FrameworkThreadTaskFactory.CancellationToken;
return this.FrameworkThreadTaskFactory.ContinueWhenAll(
new[]
{
RemainingTicks = delayTicks,
RunAfterTickCount = Environment.TickCount64 + (long)Math.Ceiling(delay.TotalMilliseconds),
CancellationToken = cancellationToken,
TaskCompletionSource = tcs,
Action = action,
});
}
return tcs.Task;
Task.Delay(delay, cancellationToken),
this.DelayTicks(delayTicks, cancellationToken),
},
_ => action(),
cancellationToken);
}
/// <inheritdoc/>
@ -215,20 +234,16 @@ internal sealed class Framework : IDisposable, IServiceType, IFramework
return Task.FromCanceled<T>(cts.Token);
}
var tcs = new TaskCompletionSource<Task<T>>();
lock (this.runOnNextTickTaskListSync)
{
this.runOnNextTickTaskList.Add(new RunOnNextTickTaskFunc<Task<T>>()
if (cancellationToken == default)
cancellationToken = this.FrameworkThreadTaskFactory.CancellationToken;
return this.FrameworkThreadTaskFactory.ContinueWhenAll(
new[]
{
RemainingTicks = delayTicks,
RunAfterTickCount = Environment.TickCount64 + (long)Math.Ceiling(delay.TotalMilliseconds),
CancellationToken = cancellationToken,
TaskCompletionSource = tcs,
Func = func,
});
}
return tcs.Task.ContinueWith(x => x.Result, cancellationToken).Unwrap();
Task.Delay(delay, cancellationToken),
this.DelayTicks(delayTicks, cancellationToken),
},
_ => func(),
cancellationToken).Unwrap();
}
/// <inheritdoc/>
@ -244,20 +259,16 @@ internal sealed class Framework : IDisposable, IServiceType, IFramework
return Task.FromCanceled(cts.Token);
}
var tcs = new TaskCompletionSource<Task>();
lock (this.runOnNextTickTaskListSync)
{
this.runOnNextTickTaskList.Add(new RunOnNextTickTaskFunc<Task>()
if (cancellationToken == default)
cancellationToken = this.FrameworkThreadTaskFactory.CancellationToken;
return this.FrameworkThreadTaskFactory.ContinueWhenAll(
new[]
{
RemainingTicks = delayTicks,
RunAfterTickCount = Environment.TickCount64 + (long)Math.Ceiling(delay.TotalMilliseconds),
CancellationToken = cancellationToken,
TaskCompletionSource = tcs,
Func = func,
});
}
return tcs.Task.ContinueWith(x => x.Result, cancellationToken).Unwrap();
Task.Delay(delay, cancellationToken),
this.DelayTicks(delayTicks, cancellationToken),
},
_ => func(),
cancellationToken).Unwrap();
}
/// <summary>
@ -333,23 +344,9 @@ internal sealed class Framework : IDisposable, IServiceType, IFramework
}
}
private void RunPendingTickTasks()
{
if (this.runOnNextTickTaskList.Count == 0 && this.runOnNextTickTaskList2.Count == 0)
return;
for (var i = 0; i < 2; i++)
{
lock (this.runOnNextTickTaskListSync)
(this.runOnNextTickTaskList, this.runOnNextTickTaskList2) = (this.runOnNextTickTaskList2, this.runOnNextTickTaskList);
this.runOnNextTickTaskList2.RemoveAll(x => x.Run());
}
}
private bool HandleFrameworkUpdate(IntPtr framework)
{
this.frameworkUpdateThread ??= Thread.CurrentThread;
this.frameworkThreadTaskScheduler.BoundThread ??= Thread.CurrentThread;
ThreadSafety.MarkMainThread();
@ -381,18 +378,30 @@ internal sealed class Framework : IDisposable, IServiceType, IFramework
this.LastUpdate = DateTime.Now;
this.LastUpdateUTC = DateTime.UtcNow;
this.tickCounter++;
foreach (var (k, (expiry, ct)) in this.tickDelayedTaskCompletionSources)
{
if (ct.IsCancellationRequested)
k.SetCanceled(ct);
else if (expiry <= this.tickCounter)
k.SetResult();
else
continue;
this.tickDelayedTaskCompletionSources.Remove(k, out _);
}
if (StatsEnabled)
{
StatsStopwatch.Restart();
this.RunPendingTickTasks();
this.frameworkThreadTaskScheduler.Run();
StatsStopwatch.Stop();
AddToStats(nameof(this.RunPendingTickTasks), StatsStopwatch.Elapsed.TotalMilliseconds);
AddToStats(nameof(this.frameworkThreadTaskScheduler), StatsStopwatch.Elapsed.TotalMilliseconds);
}
else
{
this.RunPendingTickTasks();
this.frameworkThreadTaskScheduler.Run();
}
if (StatsEnabled && this.Update != null)
@ -404,7 +413,7 @@ internal sealed class Framework : IDisposable, IServiceType, IFramework
// Cleanup handlers that are no longer being called
foreach (var key in this.NonUpdatedSubDelegates)
{
if (key == nameof(this.RunPendingTickTasks))
if (key == nameof(this.FrameworkThreadTaskFactory))
continue;
if (StatsHistory[key].Count > 0)
@ -431,8 +440,11 @@ internal sealed class Framework : IDisposable, IServiceType, IFramework
private bool HandleFrameworkDestroy(IntPtr framework)
{
this.IsFrameworkUnloading = true;
this.frameworkDestroy.Cancel();
this.DispatchUpdateEvents = false;
foreach (var k in this.tickDelayedTaskCompletionSources.Keys)
k.SetCanceled(this.frameworkDestroy.Token);
this.tickDelayedTaskCompletionSources.Clear();
// All the same, for now...
this.lifecycle.SetShuttingDown();
@ -440,95 +452,12 @@ internal sealed class Framework : IDisposable, IServiceType, IFramework
Log.Information("Framework::Destroy!");
Service<Dalamud>.Get().Unload();
this.RunPendingTickTasks();
this.frameworkThreadTaskScheduler.Run();
ServiceManager.WaitForServiceUnload();
Log.Information("Framework::Destroy OK!");
return this.destroyHook.OriginalDisposeSafe(framework);
}
private abstract class RunOnNextTickTaskBase
{
internal int RemainingTicks { get; set; }
internal long RunAfterTickCount { get; init; }
internal CancellationToken CancellationToken { get; init; }
internal bool Run()
{
if (this.CancellationToken.IsCancellationRequested)
{
this.CancelImpl();
return true;
}
if (this.RemainingTicks > 0)
this.RemainingTicks -= 1;
if (this.RemainingTicks > 0)
return false;
if (this.RunAfterTickCount > Environment.TickCount64)
return false;
this.RunImpl();
return true;
}
protected abstract void RunImpl();
protected abstract void CancelImpl();
}
private class RunOnNextTickTaskFunc<T> : RunOnNextTickTaskBase
{
internal TaskCompletionSource<T> TaskCompletionSource { get; init; }
internal Func<T> Func { get; init; }
protected override void RunImpl()
{
try
{
this.TaskCompletionSource.SetResult(this.Func());
}
catch (Exception ex)
{
this.TaskCompletionSource.SetException(ex);
}
}
protected override void CancelImpl()
{
this.TaskCompletionSource.SetCanceled();
}
}
private class RunOnNextTickTaskAction : RunOnNextTickTaskBase
{
internal TaskCompletionSource TaskCompletionSource { get; init; }
internal Action Action { get; init; }
protected override void RunImpl()
{
try
{
this.Action();
this.TaskCompletionSource.SetResult();
}
catch (Exception ex)
{
this.TaskCompletionSource.SetException(ex);
}
}
protected override void CancelImpl()
{
this.TaskCompletionSource.SetCanceled();
}
}
}
/// <summary>
@ -562,6 +491,9 @@ internal class FrameworkPluginScoped : IDisposable, IServiceType, IFramework
/// <inheritdoc/>
public DateTime LastUpdateUTC => this.frameworkService.LastUpdateUTC;
/// <inheritdoc/>
public TaskFactory FrameworkThreadTaskFactory => this.frameworkService.FrameworkThreadTaskFactory;
/// <inheritdoc/>
public TimeSpan UpdateDelta => this.frameworkService.UpdateDelta;
@ -579,6 +511,10 @@ internal class FrameworkPluginScoped : IDisposable, IServiceType, IFramework
this.Update = null;
}
/// <inheritdoc/>
public Task DelayTicks(long numTicks, CancellationToken cancellationToken = default) =>
this.frameworkService.DelayTicks(numTicks, cancellationToken);
/// <inheritdoc/>
public Task<T> RunOnFrameworkThread<T>(Func<T> func)
=> this.frameworkService.RunOnFrameworkThread(func);

View file

@ -1,13 +1,22 @@
// ReSharper disable MethodSupportsCancellation // Using alternative method of cancelling tasks by throwing exceptions.
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Reflection;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Dalamud.Game;
using Dalamud.Interface.Colors;
using Dalamud.Interface.Components;
using Dalamud.Interface.ImGuiFileDialog;
using Dalamud.Interface.Utility;
using Dalamud.Interface.Utility.Raii;
using Dalamud.Logging.Internal;
using Dalamud.Utility;
using ImGuiNET;
using Serilog;
@ -18,6 +27,12 @@ namespace Dalamud.Interface.Internal.Windows.Data.Widgets;
/// </summary>
internal class TaskSchedulerWidget : IDataWindowWidget
{
private readonly FileDialogManager fileDialogManager = new();
private readonly byte[] urlBytes = new byte[2048];
private readonly byte[] localPathBytes = new byte[2048];
private Task? downloadTask = null;
private (long Downloaded, long Total, float Percentage) downloadState;
private CancellationTokenSource taskSchedulerCancelSource = new();
/// <inheritdoc/>
@ -33,11 +48,16 @@ internal class TaskSchedulerWidget : IDataWindowWidget
public void Load()
{
this.Ready = true;
Encoding.UTF8.GetBytes(
"https://geo.mirror.pkgbuild.com/iso/2024.01.01/archlinux-2024.01.01-x86_64.iso",
this.urlBytes);
}
/// <inheritdoc/>
public void Draw()
{
var framework = Service<Framework>.Get();
if (ImGui.Button("Clear list"))
{
TaskTracker.Clear();
@ -84,8 +104,7 @@ internal class TaskSchedulerWidget : IDataWindowWidget
{
Thread.Sleep(200);
string a = null;
a.Contains("dalamud"); // Intentional null exception.
_ = ((string)null)!.Contains("dalamud"); // Intentional null exception.
});
}
@ -94,36 +113,156 @@ internal class TaskSchedulerWidget : IDataWindowWidget
if (ImGui.Button("ASAP"))
{
Task.Run(async () => await Service<Framework>.Get().RunOnTick(() => { }, cancellationToken: this.taskSchedulerCancelSource.Token));
_ = framework.RunOnTick(() => Log.Information("Framework.Update - ASAP"), cancellationToken: this.taskSchedulerCancelSource.Token);
}
ImGui.SameLine();
if (ImGui.Button("In 1s"))
{
Task.Run(async () => await Service<Framework>.Get().RunOnTick(() => { }, cancellationToken: this.taskSchedulerCancelSource.Token, delay: TimeSpan.FromSeconds(1)));
_ = framework.RunOnTick(() => Log.Information("Framework.Update - In 1s"), cancellationToken: this.taskSchedulerCancelSource.Token, delay: TimeSpan.FromSeconds(1));
}
ImGui.SameLine();
if (ImGui.Button("In 60f"))
{
Task.Run(async () => await Service<Framework>.Get().RunOnTick(() => { }, cancellationToken: this.taskSchedulerCancelSource.Token, delayTicks: 60));
_ = framework.RunOnTick(() => Log.Information("Framework.Update - In 60f"), cancellationToken: this.taskSchedulerCancelSource.Token, delayTicks: 60);
}
ImGui.SameLine();
if (ImGui.Button("In 1s+120f"))
{
_ = framework.RunOnTick(() => Log.Information("Framework.Update - In 1s+120f"), cancellationToken: this.taskSchedulerCancelSource.Token, delay: TimeSpan.FromSeconds(1), delayTicks: 120);
}
ImGui.SameLine();
if (ImGui.Button("In 2s+60f"))
{
_ = framework.RunOnTick(() => Log.Information("Framework.Update - In 2s+60f"), cancellationToken: this.taskSchedulerCancelSource.Token, delay: TimeSpan.FromSeconds(2), delayTicks: 60);
}
ImGui.SameLine();
if (ImGui.Button("Every 60 frames"))
{
_ = framework.RunOnTick(
async () =>
{
for (var i = 0L; ; i++)
{
Log.Information($"Loop #{i}; MainThread={ThreadSafety.IsMainThread}");
await framework.DelayTicks(60, this.taskSchedulerCancelSource.Token);
}
},
cancellationToken: this.taskSchedulerCancelSource.Token);
}
ImGui.SameLine();
if (ImGui.Button("Error in 1s"))
{
Task.Run(async () => await Service<Framework>.Get().RunOnTick(() => throw new Exception("Test Exception"), cancellationToken: this.taskSchedulerCancelSource.Token, delay: TimeSpan.FromSeconds(1)));
_ = framework.RunOnTick(() => throw new Exception("Test Exception"), cancellationToken: this.taskSchedulerCancelSource.Token, delay: TimeSpan.FromSeconds(1));
}
ImGui.SameLine();
if (ImGui.Button("As long as it's in Framework Thread"))
{
Task.Run(async () => await Service<Framework>.Get().RunOnFrameworkThread(() => { Log.Information("Task dispatched from non-framework.update thread"); }));
Service<Framework>.Get().RunOnFrameworkThread(() => { Log.Information("Task dispatched from framework.update thread"); }).Wait();
Task.Run(async () => await framework.RunOnFrameworkThread(() => { Log.Information("Task dispatched from non-framework.update thread"); }));
framework.RunOnFrameworkThread(() => { Log.Information("Task dispatched from framework.update thread"); }).Wait();
}
if (ImGui.CollapsingHeader("Download"))
{
ImGui.InputText("URL", this.urlBytes, (uint)this.urlBytes.Length);
ImGui.InputText("Local Path", this.localPathBytes, (uint)this.localPathBytes.Length);
ImGui.SameLine();
if (ImGuiComponents.IconButton("##localpathpicker", FontAwesomeIcon.File))
{
var defaultFileName = Encoding.UTF8.GetString(this.urlBytes).Split('\0', 2)[0].Split('/').Last();
this.fileDialogManager.SaveFileDialog(
"Choose a local path",
"*",
defaultFileName,
string.Empty,
(accept, newPath) =>
{
if (accept)
{
this.localPathBytes.AsSpan().Clear();
Encoding.UTF8.GetBytes(newPath, this.localPathBytes.AsSpan());
}
});
}
ImGui.TextUnformatted($"{this.downloadState.Downloaded:##,###}/{this.downloadState.Total:##,###} ({this.downloadState.Percentage:0.00}%)");
using var disabled =
ImRaii.Disabled(this.downloadTask?.IsCompleted is false || this.localPathBytes[0] == 0);
ImGui.AlignTextToFramePadding();
ImGui.TextUnformatted("Download");
ImGui.SameLine();
var downloadUsingGlobalScheduler = ImGui.Button("using default scheduler");
ImGui.SameLine();
var downloadUsingFramework = ImGui.Button("using Framework.Update");
if (downloadUsingGlobalScheduler || downloadUsingFramework)
{
var url = Encoding.UTF8.GetString(this.urlBytes).Split('\0', 2)[0];
var localPath = Encoding.UTF8.GetString(this.localPathBytes).Split('\0', 2)[0];
var ct = this.taskSchedulerCancelSource.Token;
this.downloadState = default;
var factory = downloadUsingGlobalScheduler
? Task.Factory
: framework.FrameworkThreadTaskFactory;
this.downloadState = default;
this.downloadTask = factory.StartNew(
async () =>
{
try
{
await using var to = File.Create(localPath);
using var client = new HttpClient();
using var conn = await client.GetAsync(url, HttpCompletionOption.ResponseHeadersRead, ct);
this.downloadState.Total = conn.Content.Headers.ContentLength ?? -1L;
await using var from = conn.Content.ReadAsStream(ct);
var buffer = new byte[8192];
while (true)
{
if (downloadUsingFramework)
ThreadSafety.AssertMainThread();
if (downloadUsingGlobalScheduler)
ThreadSafety.AssertNotMainThread();
var len = await from.ReadAsync(buffer, ct);
if (len == 0)
break;
await to.WriteAsync(buffer.AsMemory(0, len), ct);
this.downloadState.Downloaded += len;
if (this.downloadState.Total >= 0)
{
this.downloadState.Percentage =
(100f * this.downloadState.Downloaded) / this.downloadState.Total;
}
}
}
catch (Exception e)
{
Log.Error(e, "Failed to download {from} to {to}.", url, localPath);
try
{
File.Delete(localPath);
}
catch
{
// ignore
}
}
},
cancellationToken: ct).Unwrap();
}
}
if (ImGui.Button("Drown in tasks"))
@ -244,6 +383,8 @@ internal class TaskSchedulerWidget : IDataWindowWidget
ImGui.PopStyleColor(1);
}
this.fileDialogManager.Draw();
}
private async Task TestTaskInTaskDelay(CancellationToken token)

View file

@ -29,6 +29,11 @@ public interface IFramework
/// </summary>
public DateTime LastUpdateUTC { get; }
/// <summary>
/// Gets a <see cref="TaskFactory"/> that runs tasks during Framework Update event.
/// </summary>
public TaskFactory FrameworkThreadTaskFactory { get; }
/// <summary>
/// Gets the delta between the last Framework Update and the currently executing one.
/// </summary>
@ -44,6 +49,14 @@ public interface IFramework
/// </summary>
public bool IsFrameworkUnloading { get; }
/// <summary>
/// Returns a task that completes after the given number of ticks.
/// </summary>
/// <param name="numTicks">Number of ticks to delay.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A new <see cref="Task"/> that gets resolved after specified number of ticks happen.</returns>
public Task DelayTicks(long numTicks, CancellationToken cancellationToken = default);
/// <summary>
/// Run given function right away if this function has been called from game's Framework.Update thread, or otherwise run on next Framework.Update call.
/// </summary>
@ -65,6 +78,7 @@ public interface IFramework
/// <typeparam name="T">Return type.</typeparam>
/// <param name="func">Function to call.</param>
/// <returns>Task representing the pending or already completed function.</returns>
[Obsolete($"Use {nameof(RunOnTick)} instead.")]
public Task<T> RunOnFrameworkThread<T>(Func<Task<T>> func);
/// <summary>
@ -72,6 +86,7 @@ public interface IFramework
/// </summary>
/// <param name="func">Function to call.</param>
/// <returns>Task representing the pending or already completed function.</returns>
[Obsolete($"Use {nameof(RunOnTick)} instead.")]
public Task RunOnFrameworkThread(Func<Task> func);
/// <summary>

View file

@ -0,0 +1,90 @@
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
namespace Dalamud.Utility;
/// <summary>
/// A task scheduler that runs tasks on a specific thread.
/// </summary>
internal class ThreadBoundTaskScheduler : TaskScheduler
{
private const byte Scheduled = 0;
private const byte Running = 1;
private readonly ConcurrentDictionary<Task, byte> scheduledTasks = new();
/// <summary>
/// Initializes a new instance of the <see cref="ThreadBoundTaskScheduler"/> class.
/// </summary>
/// <param name="boundThread">The thread to bind this task scheduelr to.</param>
public ThreadBoundTaskScheduler(Thread? boundThread = null)
{
this.BoundThread = boundThread;
}
/// <summary>
/// Gets or sets the thread this task scheduler is bound to.
/// </summary>
public Thread? BoundThread { get; set; }
/// <summary>
/// Gets a value indicating whether we're on the bound thread.
/// </summary>
public bool IsOnBoundThread => Thread.CurrentThread == this.BoundThread;
/// <summary>
/// Runs queued tasks.
/// </summary>
public void Run()
{
foreach (var task in this.scheduledTasks.Keys)
{
if (!this.scheduledTasks.TryUpdate(task, Running, Scheduled))
continue;
_ = this.TryExecuteTask(task);
}
}
/// <inheritdoc/>
protected override IEnumerable<Task> GetScheduledTasks()
{
return this.scheduledTasks.Keys;
}
/// <inheritdoc/>
protected override void QueueTask(Task task)
{
this.scheduledTasks[task] = Scheduled;
}
/// <inheritdoc/>
protected override bool TryDequeue(Task task)
{
if (!this.scheduledTasks.TryRemove(task, out _))
return false;
return true;
}
/// <inheritdoc/>
protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued)
{
if (!this.IsOnBoundThread)
return false;
if (taskWasPreviouslyQueued && !this.scheduledTasks.TryUpdate(task, Running, Scheduled))
return false;
_ = this.TryExecuteTask(task);
return true;
}
private new bool TryExecuteTask(Task task)
{
var r = base.TryExecuteTask(task);
this.scheduledTasks.Remove(task, out _);
return r;
}
}