A couple of weeks ago, I came across an interesting case where a mobile app was calling the Content Delivery API to retrieve data, but the calls failed with a 406 HTTP status code. After some investigation, we found that duplicate refresh tokens were being added to the system. Revoking the tokens fixed the issue temporarily, but we often had to revoke them again a couple of hours later.
While we did find a couple of concerns with how the mobile application logged in and how it used the refresh tokens, the more significant issue was that Optimizely allowed duplicate refresh tokens to be created.
Help Desk to the Rescue
A key finding during our implementation was that we needed a lock around the code that adds the refresh token to the system, to prevent duplicates. Because this was a load-balanced environment, the lock had to be distributed across servers. Fortunately, I found a solid NuGet package, Medallion.Threading, to handle the details of the distributed lock.
Breakdown of the changes
- Create a Custom Authorization Server Provider.
- Create a Distributed Refresh Token Provider.
- Create an extension method to wire up authentication with the Custom Authorization Server Provider and Distributed Refresh Token Provider.
- Call our new extension method in startup.cs
CustomAuthorizationServerProvider.cs
/// <summary>
/// OAuth authorization server provider for the Content API.
/// It validates that the caller is a registered client.
/// It issues tickets for the "password" grant using an ASP.NET Identity user manager.
/// It re-issues tickets for the "refresh_token" grant.
/// It copies the ticket properties into the token endpoint response.
/// </summary>
/// <typeparam name="TManager">User manager type resolved from the OWIN context.</typeparam>
/// <typeparam name="TUser">User type managed by <typeparamref name="TManager"/>.</typeparam>
/// <typeparam name="TKey">Type of the user's key.</typeparam>
internal class CustomAuthorizationServerProvider<TManager, TUser, TKey> : OAuthAuthorizationServerProvider
    where TManager : UserManager<TUser, TKey>
    where TUser : class, IUser<TKey>
    where TKey : IEquatable<TKey>
{
    /// <summary>
    /// Request headers whose values are replaced with a placeholder when request information is logged.
    /// This keeps credentials, such as the Basic-auth client secret, out of the log.
    /// </summary>
    private static readonly HashSet<string> SensitiveHeaders = new HashSet<string>(StringComparer.OrdinalIgnoreCase)
    {
        "Authorization",
        "Proxy-Authorization",
        "Cookie"
    };

    private static readonly ILogger Log = LogManager.GetLogger(typeof(CustomAuthorizationServerProvider<TManager, TUser, TKey>));

    private readonly ContentApiOAuthOptions _options;

    /// <summary>
    /// Creates a provider using the <see cref="ContentApiOAuthOptions"/> registered in the service locator.
    /// </summary>
    public CustomAuthorizationServerProvider() : this(ServiceLocator.Current.GetInstance<ContentApiOAuthOptions>())
    {
    }

    /// <summary>
    /// Creates a provider using the given options.
    /// </summary>
    /// <param name="options">Options whose <c>Clients</c> list defines the registered clients.</param>
    public CustomAuthorizationServerProvider(ContentApiOAuthOptions options) : base()
    {
        _options = options;
    }

    /// <summary>
    /// Called to validate that the origin of the request is a registered "client_id", and that the request
    /// carries the parameters needed for its grant type. On failure, <c>context.SetError</c> marks the
    /// request invalid and returns an error message to the client.
    /// </summary>
    /// <param name="context">The client authentication context.</param>
    /// <returns>A completed task.</returns>
    public override Task ValidateClientAuthentication(OAuthValidateClientAuthenticationContext context)
    {
        // Accept credentials from the Basic Authorization header first, then from the form body.
        // The secret itself is not checked here. The Try* helpers populate context.ClientId,
        // which the checks below read.
        string clientId = string.Empty;
        string clientSecret = string.Empty;
        if (!context.TryGetBasicCredentials(out clientId, out clientSecret))
        {
            context.TryGetFormCredentials(out clientId, out clientSecret);
        }

        if (context.ClientId == null)
        {
            context.Rejected();
            context.SetError(OAuthErrors.InvalidClientId, "Client ID must be sent");
            return Completed();
        }

        var grantType = context.Parameters[AuthorisationConstants.GrantType];
        if (grantType == null)
        {
            context.SetError(OAuthErrors.InvalidGrant, "grant_type must be sent");
            return Completed();
        }

        if (grantType.Equals(AuthorisationConstants.RefreshToken) && context.Parameters[AuthorisationConstants.RefreshToken] == null)
        {
            context.SetError(OAuthErrors.InvalidRefreshToken, "Refresh token must be sent");
            return Completed();
        }

        var client = _options.Clients.FirstOrDefault(x => x.ClientId == context.ClientId);
        if (client == null)
        {
            context.Rejected();
            context.SetError(OAuthErrors.InvalidClientId, string.Format("Client '{0}' is not registered in the system.", context.ClientId));
            return Completed();
        }

        // Stash the grant type, client id and allowed origin on the OWIN context.
        // The grant handlers and the refresh token provider read them later in the same request.
        context.OwinContext.Set(AuthorisationConstants.GrantType, grantType);
        context.OwinContext.Set(AuthorisationConstants.ClientId, client.ClientId);
        context.OwinContext.Set(AuthorisationConstants.AllowedOrigin, client.AccessControlAllowOrigin);

        // A browser Origin header must match the client's configured origin unless the client allows "*".
        var originHeader = context.Request.Headers["Origin"];
        if (!string.IsNullOrEmpty(originHeader) && client.AccessControlAllowOrigin != "*" && !originHeader.Equals(client.AccessControlAllowOrigin))
        {
            context.Rejected();
            context.SetError(OAuthErrors.InvalidOrigin, string.Format("Origin '{0}' is not allowed by Access-Control-Allow-Origin", originHeader));
            return Completed();
        }

        context.Validated();
        return Completed();
    }

    /// <summary>
    /// Called when a request to the Token endpoint arrives with a "grant_type" of "password".
    /// Authenticates the user through <typeparamref name="TManager"/> and, on success, validates a ticket
    /// that carries the client id and user name as properties.
    /// </summary>
    /// <param name="context">The resource owner credentials context.</param>
    /// <returns>A task that completes when the grant has been validated or rejected.</returns>
    public override async Task GrantResourceOwnerCredentials(OAuthGrantResourceOwnerCredentialsContext context)
    {
        Log.Information($"Authentication Request: {GetRequestInfo(context)}");

        var allowedOrigin = context.OwinContext.Get<string>(AuthorisationConstants.AllowedOrigin) ?? "*";
        var responseHeaders = context.OwinContext.Response.Headers;

        // Only add the CORS header if it does not exist yet. If the site also uses the CORS middleware,
        // adding it here unconditionally would duplicate the header and cause an error.
        if (!responseHeaders.ContainsKey(AuthorisationConstants.AccessControlAllowOrigin))
        {
            responseHeaders.Add(AuthorisationConstants.AccessControlAllowOrigin, new[] { allowedOrigin });
        }

        TManager userManager;
        try
        {
            userManager = context.OwinContext.GetUserManager<TManager>();
        }
        catch (Exception x)
        {
            Log.Error($"Failed to load UserManager of type {typeof(TManager).FullName} from the owin context.", x);
            RejectWithServerError(context);
            return;
        }

        TUser user;
        try
        {
            user = await userManager.FindAsync(context.UserName, context.Password);
        }
        catch (Exception x)
        {
            Log.Error("Error reading user information from UserManager.", x);
            RejectWithServerError(context);
            return;
        }

        if (user != null && !IsUserInactiveOrLockedOut(user))
        {
            var identity = await userManager.CreateIdentityAsync(user, context.Options.AuthenticationType);
            var props = new AuthenticationProperties(new Dictionary<string, string>
            {
                { AuthorisationConstants.ClientId, context.ClientId ?? string.Empty },
                { AuthorisationConstants.Username, context.UserName }
            });
            context.Validated(new AuthenticationTicket(identity, props));
        }
        else
        {
            Log.Warning($"Failed Authentication Request: {GetRequestInfo(context)}");
            RejectWithInvalidUserIdOrPassword(context);
        }
    }

    /// <summary>
    /// Checks whether the user account is inactive (not approved) or locked out.
    /// Users that do not implement <see cref="IUIUser"/> are never treated as inactive or locked out.
    /// </summary>
    /// <param name="user">The user to check.</param>
    /// <returns><c>true</c> if the user must not be granted a token.</returns>
    protected virtual bool IsUserInactiveOrLockedOut(TUser user)
    {
        var uiUser = user as IUIUser;
        return uiUser != null && (!uiUser.IsApproved || uiUser.IsLockedOut);
    }

    /// <summary>
    /// Called when a request to the Token endpoint arrives with a "grant_type" of "refresh_token".
    /// The refresh token is only honored for the client it was originally issued to.
    /// </summary>
    /// <param name="context">The refresh token grant context.</param>
    /// <returns>A completed task.</returns>
    public override Task GrantRefreshToken(OAuthGrantRefreshTokenContext context)
    {
        // A ticket without a recorded client id is treated as belonging to no client.
        // The request is then rejected below instead of failing with a KeyNotFoundException.
        string originalClient;
        context.Ticket.Properties.Dictionary.TryGetValue(AuthorisationConstants.ClientId, out originalClient);

        var currentClient = context.ClientId;
        if (originalClient != currentClient)
        {
            context.Rejected();
            context.SetError(OAuthErrors.InvalidClientId, $"Refresh token is not valid for client '{context.ClientId}'");
            return Completed();
        }

        // Issue a new ticket built from a copy of the original identity and the original properties.
        var newIdentity = new ClaimsIdentity(context.Ticket.Identity);
        var newTicket = new AuthenticationTicket(newIdentity, context.Ticket.Properties);
        context.Validated(newTicket);
        return Completed();
    }

    /// <summary>
    /// Copies every ticket property into the token endpoint response as an additional parameter.
    /// </summary>
    /// <param name="context">The token endpoint context.</param>
    /// <returns>A completed task.</returns>
    public override Task TokenEndpoint(OAuthTokenEndpointContext context)
    {
        foreach (KeyValuePair<string, string> property in context.Properties.Dictionary)
        {
            context.AdditionalResponseParameters.Add(property.Key, property.Value);
        }
        return Completed();
    }

    /// <summary>
    /// Returns an already-completed task for the synchronous overrides.
    /// </summary>
    protected Task Completed()
    {
        return Task.FromResult<object>(null);
    }

    /// <summary>
    /// Rejects the request with an "invalid credentials" error.
    /// </summary>
    /// <param name="context">The resource grant context.</param>
    protected void RejectWithInvalidUserIdOrPassword(OAuthGrantResourceOwnerCredentialsContext context)
    {
        context.SetError(OAuthErrors.InvalidCredentials, "Invalid username or password, or the user account is inactive/locked out");
    }

    /// <summary>
    /// Sets the context to rejected with a "server_error" error.
    /// </summary>
    /// <param name="context">The resource grant context.</param>
    protected void RejectWithServerError(OAuthGrantResourceOwnerCredentialsContext context)
    {
        context.SetError(OAuthErrors.ServerError);
    }

    /// <summary>
    /// Returns the request header values for logging.
    /// Credential-bearing headers (see <see cref="SensitiveHeaders"/>) are redacted.
    /// </summary>
    /// <param name="context">The context.</param>
    /// <returns>A comma separated list of key:value pairs, or an empty string when there is no request.</returns>
    protected string GetRequestInfo(OAuthGrantResourceOwnerCredentialsContext context)
    {
        if (context.Request == null)
        {
            return string.Empty;
        }

        return string.Join(", ",
            Enumerable.Select<KeyValuePair<string, string[]>, string>(context.Request.Headers,
                h => h.Key + ":" + (SensitiveHeaders.Contains(h.Key) ? "[redacted]" : string.Join(", ", h.Value))));
    }
}
DistributedRefreshTokenProvider.cs
The main change to the refresh token provider is in the CreateAsync method. It now acquires a distributed lock and wraps the code that adds the token in that lock.
/// <summary>
/// Refresh token provider that persists hashed refresh tokens through an <see cref="IRefreshTokenRepository"/>.
/// Refresh token creation is serialized across all servers with a SQL Server distributed lock.
/// </summary>
public class DistributedRefreshTokenProvider : IAuthenticationTokenProvider
{
    /// <summary>Name of the distributed lock taken by every server while it stores a refresh token.</summary>
    private const string RefreshTokenLockName = "OAuthRefreshToken";

    /// <summary>Name of the connection string for the database that hosts the distributed lock.</summary>
    private const string LockConnectionStringName = "EPiServerDB";

    private readonly ContentApiOAuthOptions _options;
    private readonly IRefreshTokenRepository _refreshTokenRepository;

    /// <summary>
    /// Initializes a new instance of the <see cref="DistributedRefreshTokenProvider"/> class.
    /// </summary>
    /// <param name="options">Options supplying the refresh token lifetime (<c>RefreshTokenExpireTimeSpan</c>).</param>
    /// <param name="refreshTokenRepository">Store used to create, add, find and remove refresh tokens.</param>
    public DistributedRefreshTokenProvider(ContentApiOAuthOptions options, IRefreshTokenRepository refreshTokenRepository)
    {
        _options = options;
        _refreshTokenRepository = refreshTokenRepository;
    }

    /// <summary>
    /// Creates a refresh token and stores its hash in the <see cref="IRefreshTokenRepository"/>.
    /// If the store accepts it, the plain token value is attached to <paramref name="context"/>.
    /// No token is created when the ticket carries no client id.
    /// </summary>
    /// <param name="context">AuthenticationTokenCreateContext to base the ticket creation on.</param>
    /// <returns>A task that completes once the token has been stored (or skipped).</returns>
    public async Task CreateAsync(AuthenticationTokenCreateContext context)
    {
        string clientId;
        if (!context.Ticket.Properties.Dictionary.TryGetValue(AuthorisationConstants.ClientId, out clientId)
            || string.IsNullOrEmpty(clientId))
        {
            return;
        }

        // Only the hash goes to the repository; the plain value is handed to the client via SetToken below.
        var refreshTokenValue = Guid.NewGuid().ToString("n");

        // A single timestamp keeps the stored lifetime exactly equal to RefreshTokenExpireTimeSpan.
        var issuedUtc = DateTime.UtcNow;
        var newToken = _refreshTokenRepository.CreateToken(
            GetHash(refreshTokenValue),
            clientId,
            context.Ticket.Identity.Name,
            issuedUtc,
            issuedUtc.Add(_options.RefreshTokenExpireTimeSpan));

        // Concurrent requests from the same identity, possibly on different servers, could otherwise store
        // two refresh tokens. The named SQL lock lets only one server at a time run the block below.
        // NOTE(review): the lock only serializes these writes. Duplicate prevention presumably depends on
        // how IRefreshTokenRepository.Add treats existing tokens; verify against the repository.
        var distributedLock = new SqlDistributedLock(
            RefreshTokenLockName,
            ConfigurationManager.ConnectionStrings[LockConnectionStringName].ConnectionString);

        // Wait asynchronously, so that a contended lock does not block a request thread.
        using (await distributedLock.AcquireAsync())
        {
            context.Ticket.Properties.IssuedUtc = newToken.IssuedUtc;
            context.Ticket.Properties.ExpiresUtc = newToken.ExpiresUtc;
            newToken.ProtectedTicket = context.SerializeTicket();

            // Only hand out the token when Add returns a non-empty id.
            if (_refreshTokenRepository.Add(newToken) != Guid.Empty)
            {
                context.SetToken(refreshTokenValue);
            }
        }
    }

    /// <summary>
    /// Looks up the refresh token supplied in <paramref name="context"/> by its hash.
    /// If found, its ticket is deserialized into the context.
    /// The stored token is removed only when it belongs to the client making the request.
    /// </summary>
    /// <param name="context">AuthenticationTokenReceiveContext to locate the ticket.</param>
    /// <returns>A completed task.</returns>
    public Task ReceiveAsync(AuthenticationTokenReceiveContext context)
    {
        // Fall back to "*" when no origin was recorded for the client, matching the password grant,
        // instead of emitting a null header value.
        var allowedOrigin = context.OwinContext.Get<string>(AuthorisationConstants.AllowedOrigin) ?? "*";
        var clientId = context.OwinContext.Get<string>(AuthorisationConstants.ClientId);
        var responseHeaders = context.OwinContext.Response.Headers;

        // Only add the CORS header if it does not exist yet. If the site also uses the CORS middleware,
        // adding it here unconditionally would duplicate the header and cause an error.
        if (!responseHeaders.ContainsKey(AuthorisationConstants.AccessControlAllowOrigin))
        {
            responseHeaders.Add(AuthorisationConstants.AccessControlAllowOrigin, new[] { allowedOrigin });
        }

        var refreshToken = _refreshTokenRepository.FindByValue(GetHash(context.Token));
        if (refreshToken != null)
        {
            context.DeserializeTicket(refreshToken.ProtectedTicket);
            if (refreshToken.ClientId == clientId)
            {
                _refreshTokenRepository.Remove(refreshToken);
            }
        }

        return Task.FromResult<object>(null);
    }

    /// <summary>Synchronous creation is not implemented; use <see cref="CreateAsync"/>.</summary>
    public void Create(AuthenticationTokenCreateContext context)
    {
        throw new NotImplementedException();
    }

    /// <summary>Synchronous lookup is not implemented; use <see cref="ReceiveAsync"/>.</summary>
    public void Receive(AuthenticationTokenReceiveContext context)
    {
        throw new NotImplementedException();
    }

    /// <summary>
    /// Returns the Base64-encoded SHA-256 hash of the UTF-8 bytes of <paramref name="input"/>.
    /// </summary>
    private static string GetHash(string input)
    {
        // Dispose the hash algorithm; the previous implementation leaked one per call.
        using (var sha256 = SHA256.Create())
        {
            byte[] byteHash = sha256.ComputeHash(System.Text.Encoding.UTF8.GetBytes(input));
            return Convert.ToBase64String(byteHash);
        }
    }
}
AuthorizationAppBuilderExtension.cs
This extension method registers the OAuth authorization server with the Custom Authorization Server Provider and the Distributed Refresh Token Provider, and then enables bearer token authentication.
/// <summary>
/// OWIN startup helpers that register the Content API OAuth token endpoint.
/// The endpoint uses the custom authorization server provider and the distributed refresh token provider.
/// </summary>
public static class AuthorizationAppBuilderExtension
{
    /// <summary>
    /// Adds the OAuth authorization server and OAuth bearer authentication middleware to the pipeline.
    /// </summary>
    /// <typeparam name="TManager">ASP.NET Identity user manager used to authenticate users.</typeparam>
    /// <typeparam name="TUser">The user type handled by <typeparamref name="TManager"/>.</typeparam>
    /// <param name="app">The OWIN application builder.</param>
    /// <param name="oAuthOptions">
    /// Supplies the SSL requirement, the token endpoint path, the optional access token lifetime and the
    /// options handed to the refresh token provider.
    /// </param>
    /// <returns>The same <paramref name="app"/>, for chaining.</returns>
    public static IAppBuilder UseCustomIdentityOAuthAuthorization<TManager, TUser>(this IAppBuilder app, ContentApiOAuthOptions oAuthOptions)
        where TManager : UserManager<TUser, string>
        where TUser : IdentityUser, IUIUser, new()
    {
        var serverOptions = new OAuthAuthorizationServerOptions();
        serverOptions.AllowInsecureHttp = !oAuthOptions.RequireSsl;
        serverOptions.TokenEndpointPath = new PathString(oAuthOptions.TokenEndpointPath);

        // NOTE(review): the parameterless constructor resolves ContentApiOAuthOptions from the ServiceLocator,
        // so client validation does not use the oAuthOptions passed to this method. Confirm this is intended.
        serverOptions.Provider = new CustomAuthorizationServerProvider<TManager, TUser, string>();

        var refreshTokenRepository = ServiceLocator.Current.GetInstance<IRefreshTokenRepository>();
        serverOptions.RefreshTokenProvider = new DistributedRefreshTokenProvider(oAuthOptions, refreshTokenRepository);

        // Override the server's access token lifetime only when one is configured.
        var accessTokenLifetime = oAuthOptions.AccessTokenExpireTimeSpan;
        if (accessTokenLifetime.HasValue)
        {
            serverOptions.AccessTokenExpireTimeSpan = accessTokenLifetime.Value;
        }

        // Token generation at the token endpoint, followed by bearer token authentication.
        app.UseOAuthAuthorizationServer(serverOptions);
        app.UseOAuthBearerAuthentication(new OAuthBearerAuthenticationOptions());

        return app;
    }
}
Update Startup
Update startup to call our new extension method.
// RequireSsl = false allows plain HTTP on the token endpoint (AllowInsecureHttp) and is meant for local development.
// NOTE(review): as written, this applies in every environment; consider reading it from configuration.
var customOAuthOptions = new ContentApiOAuthOptions
{
    RequireSsl = false
};
app.UseCustomIdentityOAuthAuthorization<ApplicationUserManager<SiteUser>, SiteUser>(customOAuthOptions);
And that is it. After pushing this to production we no longer had any issues with duplicate tokens.