Refactor repository methods to async and update credential logic

This commit is contained in:
Aaron Po
2026-01-22 11:14:23 -05:00
parent fd544dbd34
commit 82db763951
11 changed files with 84 additions and 67 deletions

View File

@@ -164,7 +164,13 @@ CREATE TABLE UserCredential -- delete credentials when user account is deleted
Hash NVARCHAR(MAX) NOT NULL, Hash NVARCHAR(MAX) NOT NULL,
-- uses argon2 -- uses argon2
Timer ROWVERSION, IsRevoked BIT NOT NULL
CONSTRAINT DF_UserCredential_IsRevoked DEFAULT 0,
RevokedAt DATETIME NULL,
Timer ROWVERSION,
CONSTRAINT PK_UserCredential CONSTRAINT PK_UserCredential
PRIMARY KEY (UserCredentialID), PRIMARY KEY (UserCredentialID),
@@ -173,9 +179,6 @@ CREATE TABLE UserCredential -- delete credentials when user account is deleted
FOREIGN KEY (UserAccountID) FOREIGN KEY (UserAccountID)
REFERENCES UserAccount(UserAccountID) REFERENCES UserAccount(UserAccountID)
ON DELETE CASCADE, ON DELETE CASCADE,
CONSTRAINT AK_UserCredential_UserAccountID
UNIQUE (UserAccountID)
); );
CREATE NONCLUSTERED INDEX IX_UserCredential_UserAccount CREATE NONCLUSTERED INDEX IX_UserCredential_UserAccount

View File

@@ -1,6 +1,6 @@
CREATE OR ALTER PROCEDURE dbo.USP_AddUserCredential( CREATE OR ALTER PROCEDURE dbo.USP_AddUpdateUserCredential(
@UserAccountId uniqueidentifier, @UserAccountId UNIQUEIDENTIFIER,
@Hash nvarchar(max) @Hash NVARCHAR(MAX)
) )
AS AS
BEGIN BEGIN
@@ -16,14 +16,14 @@ BEGIN
) )
THROW 50001, 'UserAccountID does not exist.', 1; THROW 50001, 'UserAccountID does not exist.', 1;
IF EXISTS (
SELECT 1
FROM dbo.UserCredential
WHERE UserAccountID = @UserAccountId
)
THROW 50002, 'UserCredential for this UserAccountID already exists.', 1;
-- invalidate old credentials
UPDATE dbo.UserCredential
SET IsRevoked = 1,
RevokedAt = GETDATE()
WHERE UserAccountId = @UserAccountId
AND IsRevoked = 0;
INSERT INTO dbo.UserCredential INSERT INTO dbo.UserCredential
(UserAccountId, Hash) (UserAccountId, Hash)
VALUES VALUES

View File

@@ -0,0 +1,18 @@
CREATE OR ALTER PROCEDURE dbo.USP_GetUserCredentialByUserAccountId(
@UserAccountId UNIQUEIDENTIFIER
)
AS
BEGIN
SET NOCOUNT ON;
SET XACT_ABORT ON;
SELECT
UserCredentialId,
UserAccountId,
Hash,
IsRevoked,
CreatedAt,
RevokedAt
FROM dbo.UserCredential
WHERE UserAccountId = @UserAccountId AND IsRevoked = 0;
END;

View File

@@ -13,11 +13,11 @@ namespace DataAccessLayer.Repositories
return connection; return connection;
} }
public abstract Task Add(T entity); public abstract Task AddAsync(T entity);
public abstract Task<IEnumerable<T>> GetAll(int? limit, int? offset); public abstract Task<IEnumerable<T>> GetAllAsync(int? limit, int? offset);
public abstract Task<T?> GetById(Guid id); public abstract Task<T?> GetByIdAsync(Guid id);
public abstract Task Update(T entity); public abstract Task UpdateAsync(T entity);
public abstract Task Delete(Guid id); public abstract Task DeleteAsync(Guid id);
protected abstract T MapToEntity(SqlDataReader reader); protected abstract T MapToEntity(SqlDataReader reader);
} }

View File

@@ -4,12 +4,12 @@ namespace DataAccessLayer.Repositories.UserAccount
{ {
public interface IUserAccountRepository public interface IUserAccountRepository
{ {
Task Add(Entities.UserAccount userAccount); Task AddAsync(Entities.UserAccount userAccount);
Task<Entities.UserAccount?> GetById(Guid id); Task<Entities.UserAccount?> GetByIdAsync(Guid id);
Task<IEnumerable<Entities.UserAccount>> GetAll(int? limit, int? offset); Task<IEnumerable<Entities.UserAccount>> GetAllAsync(int? limit, int? offset);
Task Update(Entities.UserAccount userAccount); Task UpdateAsync(Entities.UserAccount userAccount);
Task Delete(Guid id); Task DeleteAsync(Guid id);
Task<Entities.UserAccount?> GetByUsername(string username); Task<Entities.UserAccount?> GetByUsernameAsync(string username);
Task<Entities.UserAccount?> GetByEmail(string email); Task<Entities.UserAccount?> GetByEmailAsync(string email);
} }
} }

View File

@@ -7,7 +7,7 @@ namespace DataAccessLayer.Repositories.UserAccount
public class UserAccountRepository(ISqlConnectionFactory connectionFactory) public class UserAccountRepository(ISqlConnectionFactory connectionFactory)
: Repository<Entities.UserAccount>(connectionFactory), IUserAccountRepository : Repository<Entities.UserAccount>(connectionFactory), IUserAccountRepository
{ {
public override async Task Add(Entities.UserAccount userAccount) public override async Task AddAsync(Entities.UserAccount userAccount)
{ {
await using var connection = await CreateConnection(); await using var connection = await CreateConnection();
await using var command = new SqlCommand("usp_CreateUserAccount", connection); await using var command = new SqlCommand("usp_CreateUserAccount", connection);
@@ -23,7 +23,7 @@ namespace DataAccessLayer.Repositories.UserAccount
await command.ExecuteNonQueryAsync(); await command.ExecuteNonQueryAsync();
} }
public override async Task<Entities.UserAccount?> GetById(Guid id) public override async Task<Entities.UserAccount?> GetByIdAsync(Guid id)
{ {
await using var connection = await CreateConnection(); await using var connection = await CreateConnection();
await using var command = new SqlCommand("usp_GetUserAccountById", connection) await using var command = new SqlCommand("usp_GetUserAccountById", connection)
@@ -37,7 +37,7 @@ namespace DataAccessLayer.Repositories.UserAccount
return await reader.ReadAsync() ? MapToEntity(reader) : null; return await reader.ReadAsync() ? MapToEntity(reader) : null;
} }
public override async Task<IEnumerable<Entities.UserAccount>> GetAll(int? limit, int? offset) public override async Task<IEnumerable<Entities.UserAccount>> GetAllAsync(int? limit, int? offset)
{ {
await using var connection = await CreateConnection(); await using var connection = await CreateConnection();
await using var command = new SqlCommand("usp_GetAllUserAccounts", connection); await using var command = new SqlCommand("usp_GetAllUserAccounts", connection);
@@ -60,7 +60,7 @@ namespace DataAccessLayer.Repositories.UserAccount
return users; return users;
} }
public override async Task Update(Entities.UserAccount userAccount) public override async Task UpdateAsync(Entities.UserAccount userAccount)
{ {
await using var connection = await CreateConnection(); await using var connection = await CreateConnection();
await using var command = new SqlCommand("usp_UpdateUserAccount", connection); await using var command = new SqlCommand("usp_UpdateUserAccount", connection);
@@ -76,7 +76,7 @@ namespace DataAccessLayer.Repositories.UserAccount
await command.ExecuteNonQueryAsync(); await command.ExecuteNonQueryAsync();
} }
public override async Task Delete(Guid id) public override async Task DeleteAsync(Guid id)
{ {
await using var connection = await CreateConnection(); await using var connection = await CreateConnection();
await using var command = new SqlCommand("usp_DeleteUserAccount", connection); await using var command = new SqlCommand("usp_DeleteUserAccount", connection);
@@ -86,7 +86,7 @@ namespace DataAccessLayer.Repositories.UserAccount
await command.ExecuteNonQueryAsync(); await command.ExecuteNonQueryAsync();
} }
public async Task<Entities.UserAccount?> GetByUsername(string username) public async Task<Entities.UserAccount?> GetByUsernameAsync(string username)
{ {
await using var connection = await CreateConnection(); await using var connection = await CreateConnection();
await using var command = new SqlCommand("usp_GetUserAccountByUsername", connection); await using var command = new SqlCommand("usp_GetUserAccountByUsername", connection);
@@ -98,7 +98,7 @@ namespace DataAccessLayer.Repositories.UserAccount
return await reader.ReadAsync() ? MapToEntity(reader) : null; return await reader.ReadAsync() ? MapToEntity(reader) : null;
} }
public async Task<Entities.UserAccount?> GetByEmail(string email) public async Task<Entities.UserAccount?> GetByEmailAsync(string email)
{ {
await using var connection = await CreateConnection(); await using var connection = await CreateConnection();
await using var command = new SqlCommand("usp_GetUserAccountByEmail", connection); await using var command = new SqlCommand("usp_GetUserAccountByEmail", connection);

View File

@@ -1,11 +1,7 @@
namespace DataAccessLayer.Repositories.UserCredential; using DataAccessLayer.Entities;
public interface IUserCredentialRepository public interface IUserCredentialRepository
{ {
Task Add(Entities.UserCredential credential); Task RotateCredentialAsync(Guid userAccountId, UserCredential credential);
Task<Entities.UserCredential?> GetById(Guid userCredentialId); Task<UserCredential?> GetActiveCredentialByUserAccountIdAsync(Guid userAccountId);
Task<Entities.UserCredential?> GetByUserAccountId(Guid userAccountId); }
Task<IEnumerable<Entities.UserCredential>> GetAll(int? limit, int? offset);
Task Update(Entities.UserCredential credential);
Task Delete(Guid userCredentialId);
}

View File

@@ -25,8 +25,8 @@ namespace DALTests
}; };
// Act // Act
await _repository.Add(userAccount); await _repository.AddAsync(userAccount);
var retrievedUser = await _repository.GetById(userAccount.UserAccountId); var retrievedUser = await _repository.GetByIdAsync(userAccount.UserAccountId);
// Assert // Assert
Assert.NotNull(retrievedUser); Assert.NotNull(retrievedUser);
@@ -48,10 +48,10 @@ namespace DALTests
CreatedAt = DateTime.UtcNow, CreatedAt = DateTime.UtcNow,
DateOfBirth = new DateTime(1985, 5, 15), DateOfBirth = new DateTime(1985, 5, 15),
}; };
await _repository.Add(userAccount); await _repository.AddAsync(userAccount);
// Act // Act
var retrievedUser = await _repository.GetById(userId); var retrievedUser = await _repository.GetByIdAsync(userId);
// Assert // Assert
Assert.NotNull(retrievedUser); Assert.NotNull(retrievedUser);
@@ -72,12 +72,12 @@ namespace DALTests
CreatedAt = DateTime.UtcNow, CreatedAt = DateTime.UtcNow,
DateOfBirth = new DateTime(1992, 3, 10), DateOfBirth = new DateTime(1992, 3, 10),
}; };
await _repository.Add(userAccount); await _repository.AddAsync(userAccount);
// Act // Act
userAccount.FirstName = "Updated"; userAccount.FirstName = "Updated";
await _repository.Update(userAccount); await _repository.UpdateAsync(userAccount);
var updatedUser = await _repository.GetById(userAccount.UserAccountId); var updatedUser = await _repository.GetByIdAsync(userAccount.UserAccountId);
// Assert // Assert
Assert.NotNull(updatedUser); Assert.NotNull(updatedUser);
@@ -98,11 +98,11 @@ namespace DALTests
CreatedAt = DateTime.UtcNow, CreatedAt = DateTime.UtcNow,
DateOfBirth = new DateTime(1995, 7, 20), DateOfBirth = new DateTime(1995, 7, 20),
}; };
await _repository.Add(userAccount); await _repository.AddAsync(userAccount);
// Act // Act
await _repository.Delete(userAccount.UserAccountId); await _repository.DeleteAsync(userAccount.UserAccountId);
var deletedUser = await _repository.GetById(userAccount.UserAccountId); var deletedUser = await _repository.GetByIdAsync(userAccount.UserAccountId);
// Assert // Assert
Assert.Null(deletedUser); Assert.Null(deletedUser);
@@ -132,11 +132,11 @@ namespace DALTests
CreatedAt = DateTime.UtcNow, CreatedAt = DateTime.UtcNow,
DateOfBirth = new DateTime(1992, 2, 2), DateOfBirth = new DateTime(1992, 2, 2),
}; };
await _repository.Add(user1); await _repository.AddAsync(user1);
await _repository.Add(user2); await _repository.AddAsync(user2);
// Act // Act
var allUsers = await _repository.GetAll(null, null); var allUsers = await _repository.GetAllAsync(null, null);
// Assert // Assert
Assert.NotNull(allUsers); Assert.NotNull(allUsers);
@@ -183,11 +183,11 @@ namespace DALTests
foreach (var user in users) foreach (var user in users)
{ {
await _repository.Add(user); await _repository.AddAsync(user);
} }
// Act // Act
var page = (await _repository.GetAll(2, 0)).ToList(); var page = (await _repository.GetAllAsync(2, 0)).ToList();
// Assert // Assert
Assert.Equal(2, page.Count); Assert.Equal(2, page.Count);
@@ -197,10 +197,10 @@ namespace DALTests
public async Task GetAll_WithPagination_ShouldValidateArguments() public async Task GetAll_WithPagination_ShouldValidateArguments()
{ {
await Assert.ThrowsAsync<ArgumentOutOfRangeException>(async () => await Assert.ThrowsAsync<ArgumentOutOfRangeException>(async () =>
(await _repository.GetAll(0, 0)).ToList() (await _repository.GetAllAsync(0, 0)).ToList()
); );
await Assert.ThrowsAsync<ArgumentOutOfRangeException>(async () => await Assert.ThrowsAsync<ArgumentOutOfRangeException>(async () =>
(await _repository.GetAll(1, -1)).ToList() (await _repository.GetAllAsync(1, -1)).ToList()
); );
} }
} }
@@ -209,7 +209,7 @@ namespace DALTests
{ {
private readonly Dictionary<Guid, UserAccount> _store = new(); private readonly Dictionary<Guid, UserAccount> _store = new();
public Task Add(UserAccount userAccount) public Task AddAsync(UserAccount userAccount)
{ {
if (userAccount.UserAccountId == Guid.Empty) if (userAccount.UserAccountId == Guid.Empty)
{ {
@@ -219,13 +219,13 @@ namespace DALTests
return Task.CompletedTask; return Task.CompletedTask;
} }
public Task<UserAccount?> GetById(Guid id) public Task<UserAccount?> GetByIdAsync(Guid id)
{ {
_store.TryGetValue(id, out var user); _store.TryGetValue(id, out var user);
return Task.FromResult(user is null ? null : Clone(user)); return Task.FromResult(user is null ? null : Clone(user));
} }
public Task<IEnumerable<UserAccount>> GetAll(int? limit, int? offset) public Task<IEnumerable<UserAccount>> GetAllAsync(int? limit, int? offset)
{ {
if (limit.HasValue && limit.Value <= 0) throw new ArgumentOutOfRangeException(nameof(limit)); if (limit.HasValue && limit.Value <= 0) throw new ArgumentOutOfRangeException(nameof(limit));
if (offset.HasValue && offset.Value < 0) throw new ArgumentOutOfRangeException(nameof(offset)); if (offset.HasValue && offset.Value < 0) throw new ArgumentOutOfRangeException(nameof(offset));
@@ -240,26 +240,26 @@ namespace DALTests
return Task.FromResult<IEnumerable<UserAccount>>(query.ToList()); return Task.FromResult<IEnumerable<UserAccount>>(query.ToList());
} }
public Task Update(UserAccount userAccount) public Task UpdateAsync(UserAccount userAccount)
{ {
if (!_store.ContainsKey(userAccount.UserAccountId)) return Task.CompletedTask; if (!_store.ContainsKey(userAccount.UserAccountId)) return Task.CompletedTask;
_store[userAccount.UserAccountId] = Clone(userAccount); _store[userAccount.UserAccountId] = Clone(userAccount);
return Task.CompletedTask; return Task.CompletedTask;
} }
public Task Delete(Guid id) public Task DeleteAsync(Guid id)
{ {
_store.Remove(id); _store.Remove(id);
return Task.CompletedTask; return Task.CompletedTask;
} }
public Task<UserAccount?> GetByUsername(string username) public Task<UserAccount?> GetByUsernameAsync(string username)
{ {
var user = _store.Values.FirstOrDefault(u => u.Username == username); var user = _store.Values.FirstOrDefault(u => u.Username == username);
return Task.FromResult(user is null ? null : Clone(user)); return Task.FromResult(user is null ? null : Clone(user));
} }
public Task<UserAccount?> GetByEmail(string email) public Task<UserAccount?> GetByEmailAsync(string email)
{ {
var user = _store.Values.FirstOrDefault(u => u.Email == email); var user = _store.Values.FirstOrDefault(u => u.Email == email);
return Task.FromResult(user is null ? null : Clone(user)); return Task.FromResult(user is null ? null : Clone(user));

View File

@@ -8,12 +8,12 @@ namespace BusinessLayer.Services
{ {
public async Task<IEnumerable<UserAccount>> GetAllAsync(int? limit = null, int? offset = null) public async Task<IEnumerable<UserAccount>> GetAllAsync(int? limit = null, int? offset = null)
{ {
return await repository.GetAll(limit, offset); return await repository.GetAllAsync(limit, offset);
} }
public async Task<UserAccount?> GetByIdAsync(Guid id) public async Task<UserAccount?> GetByIdAsync(Guid id)
{ {
return await repository.GetById(id); return await repository.GetByIdAsync(id);
} }
} }
} }