use dbcontext factory?

This commit is contained in:
Stanley Dimant 2024-10-29 14:14:11 +01:00
parent 74633e1337
commit 19d819045c
3 changed files with 38 additions and 30 deletions

View file

@ -26,7 +26,7 @@ public class OAuthController : AuthControllerBase
private static readonly ConcurrentDictionary<string, string> _cookieOAuthResponse = [];
public OAuthController(ILogger<OAuthController> logger,
IHttpContextAccessor accessor, MareDbContext mareDbContext,
IHttpContextAccessor accessor, IDbContextFactory<MareDbContext> mareDbContext,
SecretKeyAuthenticatorService secretKeyAuthenticatorService,
IConfigurationService<AuthServiceConfiguration> configuration,
IRedisDatabase redisDb, GeoIPService geoIPProvider)
@ -135,7 +135,9 @@ public class OAuthController : AuthControllerBase
if (discordUserId == 0)
return BadRequest("Failed to get Discord ID from login token");
var mareUser = await MareDbContext.LodeStoneAuth.Include(u => u.User).SingleOrDefaultAsync(u => u.DiscordId == discordUserId);
using var dbContext = await MareDbContextFactory.CreateDbContextAsync();
var mareUser = await dbContext.LodeStoneAuth.Include(u => u.User).SingleOrDefaultAsync(u => u.DiscordId == discordUserId);
if (mareUser == null)
return BadRequest("Could not find a Mare user associated to this Discord account.");
@ -213,11 +215,12 @@ public class OAuthController : AuthControllerBase
public async Task<Dictionary<string, string>> GetAvailableUIDs()
{
string primaryUid = HttpContext.User.Claims.Single(c => string.Equals(c.Type, MareClaimTypes.Uid, StringComparison.Ordinal))!.Value;
using var dbContext = await MareDbContextFactory.CreateDbContextAsync();
var mareUser = await MareDbContext.Auth.AsNoTracking().Include(u => u.User).FirstOrDefaultAsync(f => f.UserUID == primaryUid).ConfigureAwait(false);
var mareUser = await dbContext.Auth.AsNoTracking().Include(u => u.User).FirstOrDefaultAsync(f => f.UserUID == primaryUid).ConfigureAwait(false);
if (mareUser == null || mareUser.User == null) return [];
var uid = mareUser.User.UID;
var allUids = await MareDbContext.Auth.AsNoTracking().Include(u => u.User).Where(a => a.UserUID == uid || a.PrimaryUserUID == uid).ToListAsync().ConfigureAwait(false);
var allUids = await dbContext.Auth.AsNoTracking().Include(u => u.User).Where(a => a.UserUID == uid || a.PrimaryUserUID == uid).ToListAsync().ConfigureAwait(false);
var result = allUids.OrderBy(u => u.UserUID == uid ? 0 : 1).ThenBy(u => u.UserUID).Select(u => (u.UserUID, u.User.Alias)).ToDictionary();
return result;
}
@ -226,10 +229,12 @@ public class OAuthController : AuthControllerBase
[HttpPost(MareAuth.OAuth_CreateOAuth)]
public async Task<IActionResult> CreateTokenWithOAuth(string uid, string charaIdent)
{
return await AuthenticateOAuthInternal(uid, charaIdent);
using var dbContext = await MareDbContextFactory.CreateDbContextAsync();
return await AuthenticateOAuthInternal(dbContext, uid, charaIdent);
}
private async Task<IActionResult> AuthenticateOAuthInternal(string requestedUid, string charaIdent)
private async Task<IActionResult> AuthenticateOAuthInternal(MareDbContext dbContext, string requestedUid, string charaIdent)
{
try
{
@ -241,7 +246,7 @@ public class OAuthController : AuthControllerBase
var authResult = await SecretKeyAuthenticatorService.AuthorizeOauthAsync(ip, primaryUid, requestedUid);
return await GenericAuthResponse(charaIdent, authResult);
return await GenericAuthResponse(dbContext, charaIdent, authResult);
}
catch (Exception ex)
{