From a26bb58fdbb79032b26b85488a04d49d0e40b8d0 Mon Sep 17 00:00:00 2001 From: srkizer Date: Thu, 14 Mar 2024 08:36:38 +0900 Subject: [PATCH] Use custom TaskScheduler for Framework.RunOnTick (#1597) * Use custom TaskScheduler for Framework.RunOnTick * TaskSchedulerWidget: add example --- Dalamud/Game/Framework.cs | 262 +++++++----------- .../Data/Widgets/TaskSchedulerWidget.cs | 157 ++++++++++- Dalamud/Plugin/Services/IFramework.cs | 15 + Dalamud/Utility/ThreadBoundTaskScheduler.cs | 90 ++++++ 4 files changed, 353 insertions(+), 171 deletions(-) create mode 100644 Dalamud/Utility/ThreadBoundTaskScheduler.cs diff --git a/Dalamud/Game/Framework.cs b/Dalamud/Game/Framework.cs index ce34f2c06..6520ca5c8 100644 --- a/Dalamud/Game/Framework.cs +++ b/Dalamud/Game/Framework.cs @@ -1,3 +1,4 @@ +using System.Collections.Concurrent; using System.Collections.Generic; using System.Diagnostics; using System.Linq; @@ -41,11 +42,13 @@ internal sealed class Framework : IDisposable, IServiceType, IFramework [ServiceManager.ServiceDependency] private readonly DalamudConfiguration configuration = Service.Get(); - private readonly object runOnNextTickTaskListSync = new(); - private List runOnNextTickTaskList = new(); - private List runOnNextTickTaskList2 = new(); + private readonly CancellationTokenSource frameworkDestroy; + private readonly ThreadBoundTaskScheduler frameworkThreadTaskScheduler; - private Thread? frameworkUpdateThread; + private readonly ConcurrentDictionary + tickDelayedTaskCompletionSources = new(); + + private ulong tickCounter; [ServiceManager.ServiceConstructor] private Framework(TargetSigScanner sigScanner, GameLifecycle lifecycle) @@ -56,6 +59,14 @@ internal sealed class Framework : IDisposable, IServiceType, IFramework this.addressResolver = new FrameworkAddressResolver(); this.addressResolver.Setup(sigScanner); + this.frameworkDestroy = new(); + this.frameworkThreadTaskScheduler = new(); + this.FrameworkThreadTaskFactory = new( + this.frameworkDestroy.Token, + TaskCreationOptions.None, + TaskContinuationOptions.None, + this.frameworkThreadTaskScheduler); + this.updateHook = Hook.FromAddress(this.addressResolver.TickAddress, this.HandleFrameworkUpdate); this.destroyHook = Hook.FromAddress(this.addressResolver.DestroyAddress, this.HandleFrameworkDestroy); @@ -92,14 +103,17 @@ internal sealed class Framework : IDisposable, IServiceType, IFramework /// public DateTime LastUpdateUTC { get; private set; } = DateTime.MinValue; + /// + public TaskFactory FrameworkThreadTaskFactory { get; } + /// public TimeSpan UpdateDelta { get; private set; } = TimeSpan.Zero; /// - public bool IsInFrameworkUpdateThread => Thread.CurrentThread == this.frameworkUpdateThread; + public bool IsInFrameworkUpdateThread => this.frameworkThreadTaskScheduler.IsOnBoundThread; /// - public bool IsFrameworkUnloading { get; internal set; } + public bool IsFrameworkUnloading => this.frameworkDestroy.IsCancellationRequested; /// /// Gets the list of update sub-delegates that didn't get updated this frame. @@ -111,6 +125,19 @@ internal sealed class Framework : IDisposable, IServiceType, IFramework /// internal bool DispatchUpdateEvents { get; set; } = true; + /// + public Task DelayTicks(long numTicks, CancellationToken cancellationToken = default) + { + if (this.frameworkDestroy.IsCancellationRequested) + return Task.FromCanceled(this.frameworkDestroy.Token); + if (numTicks <= 0) + return Task.CompletedTask; + + var tcs = new TaskCompletionSource(); + this.tickDelayedTaskCompletionSources[tcs] = (this.tickCounter + (ulong)numTicks, cancellationToken); + return tcs.Task; + } + /// public Task RunOnFrameworkThread(Func func) => this.IsInFrameworkUpdateThread || this.IsFrameworkUnloading ? Task.FromResult(func()) : this.RunOnTick(func); @@ -157,20 +184,16 @@ internal sealed class Framework : IDisposable, IServiceType, IFramework return Task.FromCanceled(cts.Token); } - var tcs = new TaskCompletionSource(); - lock (this.runOnNextTickTaskListSync) - { - this.runOnNextTickTaskList.Add(new RunOnNextTickTaskFunc() + if (cancellationToken == default) + cancellationToken = this.FrameworkThreadTaskFactory.CancellationToken; + return this.FrameworkThreadTaskFactory.ContinueWhenAll( + new[] { - RemainingTicks = delayTicks, - RunAfterTickCount = Environment.TickCount64 + (long)Math.Ceiling(delay.TotalMilliseconds), - CancellationToken = cancellationToken, - TaskCompletionSource = tcs, - Func = func, - }); - } - - return tcs.Task; + Task.Delay(delay, cancellationToken), + this.DelayTicks(delayTicks, cancellationToken), + }, + _ => func(), + cancellationToken); } /// @@ -186,20 +209,16 @@ internal sealed class Framework : IDisposable, IServiceType, IFramework return Task.FromCanceled(cts.Token); } - var tcs = new TaskCompletionSource(); - lock (this.runOnNextTickTaskListSync) - { - this.runOnNextTickTaskList.Add(new RunOnNextTickTaskAction() + if (cancellationToken == default) + cancellationToken = this.FrameworkThreadTaskFactory.CancellationToken; + return this.FrameworkThreadTaskFactory.ContinueWhenAll( + new[] { - RemainingTicks = delayTicks, - RunAfterTickCount = Environment.TickCount64 + (long)Math.Ceiling(delay.TotalMilliseconds), - CancellationToken = cancellationToken, - TaskCompletionSource = tcs, - Action = action, - }); - } - - return tcs.Task; + Task.Delay(delay, cancellationToken), + this.DelayTicks(delayTicks, cancellationToken), + }, + _ => action(), + cancellationToken); } /// @@ -215,20 +234,16 @@ internal sealed class Framework : IDisposable, IServiceType, IFramework return Task.FromCanceled(cts.Token); } - var tcs = new TaskCompletionSource>(); - lock (this.runOnNextTickTaskListSync) - { - this.runOnNextTickTaskList.Add(new RunOnNextTickTaskFunc>() + if (cancellationToken == default) + cancellationToken = this.FrameworkThreadTaskFactory.CancellationToken; + return this.FrameworkThreadTaskFactory.ContinueWhenAll( + new[] { - RemainingTicks = delayTicks, - RunAfterTickCount = Environment.TickCount64 + (long)Math.Ceiling(delay.TotalMilliseconds), - CancellationToken = cancellationToken, - TaskCompletionSource = tcs, - Func = func, - }); - } - - return tcs.Task.ContinueWith(x => x.Result, cancellationToken).Unwrap(); + Task.Delay(delay, cancellationToken), + this.DelayTicks(delayTicks, cancellationToken), + }, + _ => func(), + cancellationToken).Unwrap(); } /// @@ -244,20 +259,16 @@ internal sealed class Framework : IDisposable, IServiceType, IFramework return Task.FromCanceled(cts.Token); } - var tcs = new TaskCompletionSource(); - lock (this.runOnNextTickTaskListSync) - { - this.runOnNextTickTaskList.Add(new RunOnNextTickTaskFunc() + if (cancellationToken == default) + cancellationToken = this.FrameworkThreadTaskFactory.CancellationToken; + return this.FrameworkThreadTaskFactory.ContinueWhenAll( + new[] { - RemainingTicks = delayTicks, - RunAfterTickCount = Environment.TickCount64 + (long)Math.Ceiling(delay.TotalMilliseconds), - CancellationToken = cancellationToken, - TaskCompletionSource = tcs, - Func = func, - }); - } - - return tcs.Task.ContinueWith(x => x.Result, cancellationToken).Unwrap(); + Task.Delay(delay, cancellationToken), + this.DelayTicks(delayTicks, cancellationToken), + }, + _ => func(), + cancellationToken).Unwrap(); } /// @@ -333,23 +344,9 @@ internal sealed class Framework : IDisposable, IServiceType, IFramework } } - private void RunPendingTickTasks() - { - if (this.runOnNextTickTaskList.Count == 0 && this.runOnNextTickTaskList2.Count == 0) - return; - - for (var i = 0; i < 2; i++) - { - lock (this.runOnNextTickTaskListSync) - (this.runOnNextTickTaskList, this.runOnNextTickTaskList2) = (this.runOnNextTickTaskList2, this.runOnNextTickTaskList); - - this.runOnNextTickTaskList2.RemoveAll(x => x.Run()); - } - } - private bool HandleFrameworkUpdate(IntPtr framework) { - this.frameworkUpdateThread ??= Thread.CurrentThread; + this.frameworkThreadTaskScheduler.BoundThread ??= Thread.CurrentThread; ThreadSafety.MarkMainThread(); @@ -381,18 +378,30 @@ internal sealed class Framework : IDisposable, IServiceType, IFramework this.LastUpdate = DateTime.Now; this.LastUpdateUTC = DateTime.UtcNow; + this.tickCounter++; + foreach (var (k, (expiry, ct)) in this.tickDelayedTaskCompletionSources) + { + if (ct.IsCancellationRequested) + k.SetCanceled(ct); + else if (expiry <= this.tickCounter) + k.SetResult(); + else + continue; + + this.tickDelayedTaskCompletionSources.Remove(k, out _); + } if (StatsEnabled) { StatsStopwatch.Restart(); - this.RunPendingTickTasks(); + this.frameworkThreadTaskScheduler.Run(); StatsStopwatch.Stop(); - AddToStats(nameof(this.RunPendingTickTasks), StatsStopwatch.Elapsed.TotalMilliseconds); + AddToStats(nameof(this.frameworkThreadTaskScheduler), StatsStopwatch.Elapsed.TotalMilliseconds); } else { - this.RunPendingTickTasks(); + this.frameworkThreadTaskScheduler.Run(); } if (StatsEnabled && this.Update != null) @@ -404,7 +413,7 @@ internal sealed class Framework : IDisposable, IServiceType, IFramework // Cleanup handlers that are no longer being called foreach (var key in this.NonUpdatedSubDelegates) { - if (key == nameof(this.RunPendingTickTasks)) + if (key == nameof(this.FrameworkThreadTaskFactory)) continue; if (StatsHistory[key].Count > 0) @@ -431,8 +440,11 @@ internal sealed class Framework : IDisposable, IServiceType, IFramework private bool HandleFrameworkDestroy(IntPtr framework) { - this.IsFrameworkUnloading = true; + this.frameworkDestroy.Cancel(); this.DispatchUpdateEvents = false; + foreach (var k in this.tickDelayedTaskCompletionSources.Keys) + k.SetCanceled(this.frameworkDestroy.Token); + this.tickDelayedTaskCompletionSources.Clear(); // All the same, for now... this.lifecycle.SetShuttingDown(); @@ -440,95 +452,12 @@ internal sealed class Framework : IDisposable, IServiceType, IFramework Log.Information("Framework::Destroy!"); Service.Get().Unload(); - this.RunPendingTickTasks(); + this.frameworkThreadTaskScheduler.Run(); ServiceManager.WaitForServiceUnload(); Log.Information("Framework::Destroy OK!"); return this.destroyHook.OriginalDisposeSafe(framework); } - - private abstract class RunOnNextTickTaskBase - { - internal int RemainingTicks { get; set; } - - internal long RunAfterTickCount { get; init; } - - internal CancellationToken CancellationToken { get; init; } - - internal bool Run() - { - if (this.CancellationToken.IsCancellationRequested) - { - this.CancelImpl(); - return true; - } - - if (this.RemainingTicks > 0) - this.RemainingTicks -= 1; - if (this.RemainingTicks > 0) - return false; - - if (this.RunAfterTickCount > Environment.TickCount64) - return false; - - this.RunImpl(); - - return true; - } - - protected abstract void RunImpl(); - - protected abstract void CancelImpl(); - } - - private class RunOnNextTickTaskFunc : RunOnNextTickTaskBase - { - internal TaskCompletionSource TaskCompletionSource { get; init; } - - internal Func Func { get; init; } - - protected override void RunImpl() - { - try - { - this.TaskCompletionSource.SetResult(this.Func()); - } - catch (Exception ex) - { - this.TaskCompletionSource.SetException(ex); - } - } - - protected override void CancelImpl() - { - this.TaskCompletionSource.SetCanceled(); - } - } - - private class RunOnNextTickTaskAction : RunOnNextTickTaskBase - { - internal TaskCompletionSource TaskCompletionSource { get; init; } - - internal Action Action { get; init; } - - protected override void RunImpl() - { - try - { - this.Action(); - this.TaskCompletionSource.SetResult(); - } - catch (Exception ex) - { - this.TaskCompletionSource.SetException(ex); - } - } - - protected override void CancelImpl() - { - this.TaskCompletionSource.SetCanceled(); - } - } } /// @@ -561,7 +490,10 @@ internal class FrameworkPluginScoped : IDisposable, IServiceType, IFramework /// public DateTime LastUpdateUTC => this.frameworkService.LastUpdateUTC; - + + /// + public TaskFactory FrameworkThreadTaskFactory => this.frameworkService.FrameworkThreadTaskFactory; + /// public TimeSpan UpdateDelta => this.frameworkService.UpdateDelta; @@ -579,6 +511,10 @@ internal class FrameworkPluginScoped : IDisposable, IServiceType, IFramework this.Update = null; } + /// + public Task DelayTicks(long numTicks, CancellationToken cancellationToken = default) => + this.frameworkService.DelayTicks(numTicks, cancellationToken); + /// public Task RunOnFrameworkThread(Func func) => this.frameworkService.RunOnFrameworkThread(func); diff --git a/Dalamud/Interface/Internal/Windows/Data/Widgets/TaskSchedulerWidget.cs b/Dalamud/Interface/Internal/Windows/Data/Widgets/TaskSchedulerWidget.cs index d1ac51ad5..c6d8c4e8b 100644 --- a/Dalamud/Interface/Internal/Windows/Data/Widgets/TaskSchedulerWidget.cs +++ b/Dalamud/Interface/Internal/Windows/Data/Widgets/TaskSchedulerWidget.cs @@ -1,13 +1,22 @@ // ReSharper disable MethodSupportsCancellation // Using alternative method of cancelling tasks by throwing exceptions. +using System.IO; +using System.Linq; +using System.Net.Http; using System.Reflection; +using System.Text; using System.Threading; using System.Threading.Tasks; using Dalamud.Game; using Dalamud.Interface.Colors; +using Dalamud.Interface.Components; +using Dalamud.Interface.ImGuiFileDialog; using Dalamud.Interface.Utility; +using Dalamud.Interface.Utility.Raii; using Dalamud.Logging.Internal; +using Dalamud.Utility; + using ImGuiNET; using Serilog; @@ -18,6 +27,12 @@ namespace Dalamud.Interface.Internal.Windows.Data.Widgets; /// internal class TaskSchedulerWidget : IDataWindowWidget { + private readonly FileDialogManager fileDialogManager = new(); + private readonly byte[] urlBytes = new byte[2048]; + private readonly byte[] localPathBytes = new byte[2048]; + + private Task? downloadTask = null; + private (long Downloaded, long Total, float Percentage) downloadState; private CancellationTokenSource taskSchedulerCancelSource = new(); /// @@ -33,11 +48,16 @@ internal class TaskSchedulerWidget : IDataWindowWidget public void Load() { this.Ready = true; + Encoding.UTF8.GetBytes( + "https://geo.mirror.pkgbuild.com/iso/2024.01.01/archlinux-2024.01.01-x86_64.iso", + this.urlBytes); } /// public void Draw() { + var framework = Service.Get(); + if (ImGui.Button("Clear list")) { TaskTracker.Clear(); @@ -84,8 +104,7 @@ internal class TaskSchedulerWidget : IDataWindowWidget { Thread.Sleep(200); - string a = null; - a.Contains("dalamud"); // Intentional null exception. + _ = ((string)null)!.Contains("dalamud"); // Intentional null exception. }); } @@ -94,36 +113,156 @@ internal class TaskSchedulerWidget : IDataWindowWidget if (ImGui.Button("ASAP")) { - Task.Run(async () => await Service.Get().RunOnTick(() => { }, cancellationToken: this.taskSchedulerCancelSource.Token)); + _ = framework.RunOnTick(() => Log.Information("Framework.Update - ASAP"), cancellationToken: this.taskSchedulerCancelSource.Token); } ImGui.SameLine(); if (ImGui.Button("In 1s")) { - Task.Run(async () => await Service.Get().RunOnTick(() => { }, cancellationToken: this.taskSchedulerCancelSource.Token, delay: TimeSpan.FromSeconds(1))); + _ = framework.RunOnTick(() => Log.Information("Framework.Update - In 1s"), cancellationToken: this.taskSchedulerCancelSource.Token, delay: TimeSpan.FromSeconds(1)); } ImGui.SameLine(); if (ImGui.Button("In 60f")) { - Task.Run(async () => await Service.Get().RunOnTick(() => { }, cancellationToken: this.taskSchedulerCancelSource.Token, delayTicks: 60)); + _ = framework.RunOnTick(() => Log.Information("Framework.Update - In 60f"), cancellationToken: this.taskSchedulerCancelSource.Token, delayTicks: 60); + } + + ImGui.SameLine(); + + if (ImGui.Button("In 1s+120f")) + { + _ = framework.RunOnTick(() => Log.Information("Framework.Update - In 1s+120f"), cancellationToken: this.taskSchedulerCancelSource.Token, delay: TimeSpan.FromSeconds(1), delayTicks: 120); + } + + ImGui.SameLine(); + + if (ImGui.Button("In 2s+60f")) + { + _ = framework.RunOnTick(() => Log.Information("Framework.Update - In 2s+60f"), cancellationToken: this.taskSchedulerCancelSource.Token, delay: TimeSpan.FromSeconds(2), delayTicks: 60); + } + + ImGui.SameLine(); + + if (ImGui.Button("Every 60 frames")) + { + _ = framework.RunOnTick( + async () => + { + for (var i = 0L; ; i++) + { + Log.Information($"Loop #{i}; MainThread={ThreadSafety.IsMainThread}"); + await framework.DelayTicks(60, this.taskSchedulerCancelSource.Token); + } + }, + cancellationToken: this.taskSchedulerCancelSource.Token); } ImGui.SameLine(); if (ImGui.Button("Error in 1s")) { - Task.Run(async () => await Service.Get().RunOnTick(() => throw new Exception("Test Exception"), cancellationToken: this.taskSchedulerCancelSource.Token, delay: TimeSpan.FromSeconds(1))); + _ = framework.RunOnTick(() => throw new Exception("Test Exception"), cancellationToken: this.taskSchedulerCancelSource.Token, delay: TimeSpan.FromSeconds(1)); } ImGui.SameLine(); if (ImGui.Button("As long as it's in Framework Thread")) { - Task.Run(async () => await Service.Get().RunOnFrameworkThread(() => { Log.Information("Task dispatched from non-framework.update thread"); })); - Service.Get().RunOnFrameworkThread(() => { Log.Information("Task dispatched from framework.update thread"); }).Wait(); + Task.Run(async () => await framework.RunOnFrameworkThread(() => { Log.Information("Task dispatched from non-framework.update thread"); })); + framework.RunOnFrameworkThread(() => { Log.Information("Task dispatched from framework.update thread"); }).Wait(); + } + + if (ImGui.CollapsingHeader("Download")) + { + ImGui.InputText("URL", this.urlBytes, (uint)this.urlBytes.Length); + ImGui.InputText("Local Path", this.localPathBytes, (uint)this.localPathBytes.Length); + ImGui.SameLine(); + + if (ImGuiComponents.IconButton("##localpathpicker", FontAwesomeIcon.File)) + { + var defaultFileName = Encoding.UTF8.GetString(this.urlBytes).Split('\0', 2)[0].Split('/').Last(); + this.fileDialogManager.SaveFileDialog( + "Choose a local path", + "*", + defaultFileName, + string.Empty, + (accept, newPath) => + { + if (accept) + { + this.localPathBytes.AsSpan().Clear(); + Encoding.UTF8.GetBytes(newPath, this.localPathBytes.AsSpan()); + } + }); + } + + ImGui.TextUnformatted($"{this.downloadState.Downloaded:##,###}/{this.downloadState.Total:##,###} ({this.downloadState.Percentage:0.00}%)"); + + using var disabled = + ImRaii.Disabled(this.downloadTask?.IsCompleted is false || this.localPathBytes[0] == 0); + ImGui.AlignTextToFramePadding(); + ImGui.TextUnformatted("Download"); + ImGui.SameLine(); + var downloadUsingGlobalScheduler = ImGui.Button("using default scheduler"); + ImGui.SameLine(); + var downloadUsingFramework = ImGui.Button("using Framework.Update"); + if (downloadUsingGlobalScheduler || downloadUsingFramework) + { + var url = Encoding.UTF8.GetString(this.urlBytes).Split('\0', 2)[0]; + var localPath = Encoding.UTF8.GetString(this.localPathBytes).Split('\0', 2)[0]; + var ct = this.taskSchedulerCancelSource.Token; + this.downloadState = default; + var factory = downloadUsingGlobalScheduler + ? Task.Factory + : framework.FrameworkThreadTaskFactory; + this.downloadState = default; + this.downloadTask = factory.StartNew( + async () => + { + try + { + await using var to = File.Create(localPath); + using var client = new HttpClient(); + using var conn = await client.GetAsync(url, HttpCompletionOption.ResponseHeadersRead, ct); + this.downloadState.Total = conn.Content.Headers.ContentLength ?? -1L; + await using var from = conn.Content.ReadAsStream(ct); + var buffer = new byte[8192]; + while (true) + { + if (downloadUsingFramework) + ThreadSafety.AssertMainThread(); + if (downloadUsingGlobalScheduler) + ThreadSafety.AssertNotMainThread(); + var len = await from.ReadAsync(buffer, ct); + if (len == 0) + break; + await to.WriteAsync(buffer.AsMemory(0, len), ct); + this.downloadState.Downloaded += len; + if (this.downloadState.Total >= 0) + { + this.downloadState.Percentage = + (100f * this.downloadState.Downloaded) / this.downloadState.Total; + } + } + } + catch (Exception e) + { + Log.Error(e, "Failed to download {from} to {to}.", url, localPath); + try + { + File.Delete(localPath); + } + catch + { + // ignore + } + } + }, + cancellationToken: ct).Unwrap(); + } } if (ImGui.Button("Drown in tasks")) @@ -244,6 +383,8 @@ internal class TaskSchedulerWidget : IDataWindowWidget ImGui.PopStyleColor(1); } + + this.fileDialogManager.Draw(); } private async Task TestTaskInTaskDelay(CancellationToken token) diff --git a/Dalamud/Plugin/Services/IFramework.cs b/Dalamud/Plugin/Services/IFramework.cs index ca33c5867..a93abd252 100644 --- a/Dalamud/Plugin/Services/IFramework.cs +++ b/Dalamud/Plugin/Services/IFramework.cs @@ -29,6 +29,11 @@ public interface IFramework /// public DateTime LastUpdateUTC { get; } + /// + /// Gets a that runs tasks during Framework Update event. + /// + public TaskFactory FrameworkThreadTaskFactory { get; } + /// /// Gets the delta between the last Framework Update and the currently executing one. /// @@ -44,6 +49,14 @@ public interface IFramework /// public bool IsFrameworkUnloading { get; } + /// + /// Returns a task that completes after the given number of ticks. + /// + /// Number of ticks to delay. + /// The cancellation token. + /// A new that gets resolved after specified number of ticks happen. + public Task DelayTicks(long numTicks, CancellationToken cancellationToken = default); + /// /// Run given function right away if this function has been called from game's Framework.Update thread, or otherwise run on next Framework.Update call. /// @@ -65,6 +78,7 @@ public interface IFramework /// Return type. /// Function to call. /// Task representing the pending or already completed function. + [Obsolete($"Use {nameof(RunOnTick)} instead.")] public Task RunOnFrameworkThread(Func> func); /// @@ -72,6 +86,7 @@ public interface IFramework /// /// Function to call. /// Task representing the pending or already completed function. + [Obsolete($"Use {nameof(RunOnTick)} instead.")] public Task RunOnFrameworkThread(Func func); /// diff --git a/Dalamud/Utility/ThreadBoundTaskScheduler.cs b/Dalamud/Utility/ThreadBoundTaskScheduler.cs new file mode 100644 index 000000000..4b6de29ff --- /dev/null +++ b/Dalamud/Utility/ThreadBoundTaskScheduler.cs @@ -0,0 +1,90 @@ +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace Dalamud.Utility; + +/// +/// A task scheduler that runs tasks on a specific thread. +/// +internal class ThreadBoundTaskScheduler : TaskScheduler +{ + private const byte Scheduled = 0; + private const byte Running = 1; + + private readonly ConcurrentDictionary scheduledTasks = new(); + + /// + /// Initializes a new instance of the class. + /// + /// The thread to bind this task scheduelr to. + public ThreadBoundTaskScheduler(Thread? boundThread = null) + { + this.BoundThread = boundThread; + } + + /// + /// Gets or sets the thread this task scheduler is bound to. + /// + public Thread? BoundThread { get; set; } + + /// + /// Gets a value indicating whether we're on the bound thread. + /// + public bool IsOnBoundThread => Thread.CurrentThread == this.BoundThread; + + /// + /// Runs queued tasks. + /// + public void Run() + { + foreach (var task in this.scheduledTasks.Keys) + { + if (!this.scheduledTasks.TryUpdate(task, Running, Scheduled)) + continue; + + _ = this.TryExecuteTask(task); + } + } + + /// + protected override IEnumerable GetScheduledTasks() + { + return this.scheduledTasks.Keys; + } + + /// + protected override void QueueTask(Task task) + { + this.scheduledTasks[task] = Scheduled; + } + + /// + protected override bool TryDequeue(Task task) + { + if (!this.scheduledTasks.TryRemove(task, out _)) + return false; + return true; + } + + /// + protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued) + { + if (!this.IsOnBoundThread) + return false; + + if (taskWasPreviouslyQueued && !this.scheduledTasks.TryUpdate(task, Running, Scheduled)) + return false; + + _ = this.TryExecuteTask(task); + return true; + } + + private new bool TryExecuteTask(Task task) + { + var r = base.TryExecuteTask(task); + this.scheduledTasks.Remove(task, out _); + return r; + } +}