SEBWIN-572: Improved stability of SEB Server connection by automatically updating OAuth2 token if it expires.

This commit is contained in:
Damian Büchel 2022-07-27 15:26:42 +02:00
parent 3277892ff2
commit 2875eb4c94
2 changed files with 126 additions and 57 deletions

View file

@ -29,6 +29,25 @@ namespace SafeExamBrowser.Server
this.logger = logger; this.logger = logger;
} }
internal bool IsTokenExpired(HttpContent content)
{
var isExpired = false;
try
{
var json = JsonConvert.DeserializeObject(Extract(content)) as JObject;
var error = json["error"].Value<string>();
isExpired = error?.Equals("invalid_token", StringComparison.OrdinalIgnoreCase) == true;
}
catch (Exception e)
{
logger.Error("Failed to parse token expiration content!", e);
}
return isExpired;
}
internal bool TryParseApi(HttpContent content, out ApiVersion1 api) internal bool TryParseApi(HttpContent content, out ApiVersion1 api)
{ {
var success = false; var success = false;
@ -88,7 +107,7 @@ namespace SafeExamBrowser.Server
internal bool TryParseConnectionToken(HttpResponseMessage response, out string connectionToken) internal bool TryParseConnectionToken(HttpResponseMessage response, out string connectionToken)
{ {
connectionToken = default(string); connectionToken = default;
try try
{ {
@ -108,7 +127,7 @@ namespace SafeExamBrowser.Server
logger.Error("Failed to parse connection token!", e); logger.Error("Failed to parse connection token!", e);
} }
return connectionToken != default(string); return connectionToken != default;
} }
internal bool TryParseExams(HttpContent content, out IList<Exam> exams) internal bool TryParseExams(HttpContent content, out IList<Exam> exams)
@ -141,8 +160,8 @@ namespace SafeExamBrowser.Server
internal bool TryParseInstruction(HttpContent content, out Attributes attributes, out string instruction, out string instructionConfirmation) internal bool TryParseInstruction(HttpContent content, out Attributes attributes, out string instruction, out string instructionConfirmation)
{ {
attributes = new Attributes(); attributes = new Attributes();
instruction = default(string); instruction = default;
instructionConfirmation = default(string); instructionConfirmation = default;
try try
{ {
@ -170,7 +189,25 @@ namespace SafeExamBrowser.Server
logger.Error("Failed to parse instruction!", e); logger.Error("Failed to parse instruction!", e);
} }
return instruction != default(string); return instruction != default;
}
internal bool TryParseOauth2Token(HttpContent content, out string oauth2Token)
{
oauth2Token = default;
try
{
var json = JsonConvert.DeserializeObject(Extract(content)) as JObject;
oauth2Token = json["access_token"].Value<string>();
}
catch (Exception e)
{
logger.Error("Failed to parse Oauth2 token!", e);
}
return oauth2Token != default;
} }
private Attributes ParseProctoringAttributes(JObject attributesJson, string instruction) private Attributes ParseProctoringAttributes(JObject attributesJson, string instruction)
@ -261,24 +298,6 @@ namespace SafeExamBrowser.Server
} }
} }
internal bool TryParseOauth2Token(HttpContent content, out string oauth2Token)
{
oauth2Token = default(string);
try
{
var json = JsonConvert.DeserializeObject(Extract(content)) as JObject;
oauth2Token = json["access_token"].Value<string>();
}
catch (Exception e)
{
logger.Error("Failed to parse Oauth2 token!", e);
}
return oauth2Token != default(string);
}
private string Extract(HttpContent content) private string Extract(HttpContent content)
{ {
var task = Task.Run(async () => var task = Task.Run(async () =>

View file

@ -10,6 +10,7 @@ using System;
using System.Collections.Concurrent; using System.Collections.Concurrent;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Net;
using System.Net.Http; using System.Net.Http;
using System.Net.Http.Headers; using System.Net.Http.Headers;
using System.Text; using System.Text;
@ -95,23 +96,7 @@ namespace SafeExamBrowser.Server
if (success && parser.TryParseApi(response.Content, out api)) if (success && parser.TryParseApi(response.Content, out api))
{ {
logger.Info("Successfully loaded server API."); logger.Info("Successfully loaded server API.");
success = TryRetrieveOAuth2Token(out message);
var secret = Convert.ToBase64String(Encoding.UTF8.GetBytes($"{settings.ClientName}:{settings.ClientSecret}"));
var authorization = ("Authorization", $"Basic {secret}");
var content = "grant_type=client_credentials&scope=read write";
var contentType = "application/x-www-form-urlencoded";
success = TryExecute(HttpMethod.Post, api.AccessTokenEndpoint, out response, content, contentType, authorization);
message = response.ToLogString();
if (success && parser.TryParseOauth2Token(response.Content, out oauth2Token))
{
logger.Info("Successfully retrieved OAuth2 token.");
}
else
{
logger.Error("Failed to retrieve OAuth2 token!");
}
} }
else else
{ {
@ -543,6 +528,28 @@ namespace SafeExamBrowser.Server
} }
} }
private bool TryRetrieveOAuth2Token(out string message)
{
var secret = Convert.ToBase64String(Encoding.UTF8.GetBytes($"{settings.ClientName}:{settings.ClientSecret}"));
var authorization = ("Authorization", $"Basic {secret}");
var content = "grant_type=client_credentials&scope=read write";
var contentType = "application/x-www-form-urlencoded";
var success = TryExecute(HttpMethod.Post, api.AccessTokenEndpoint, out var response, content, contentType, authorization);
message = response.ToLogString();
if (success && parser.TryParseOauth2Token(response.Content, out oauth2Token))
{
logger.Info("Successfully retrieved OAuth2 token.");
}
else
{
logger.Error("Failed to retrieve OAuth2 token!");
}
return success;
}
private bool TryExecute( private bool TryExecute(
HttpMethod method, HttpMethod method,
string url, string url,
@ -555,22 +562,7 @@ namespace SafeExamBrowser.Server
for (var attempt = 0; attempt < settings.RequestAttempts && (response == default || !response.IsSuccessStatusCode); attempt++) for (var attempt = 0; attempt < settings.RequestAttempts && (response == default || !response.IsSuccessStatusCode); attempt++)
{ {
var request = new HttpRequestMessage(method, url); var request = BuildRequest(method, url, content, contentType, headers);
if (content != default)
{
request.Content = new StringContent(content, Encoding.UTF8);
if (contentType != default)
{
request.Content.Headers.ContentType = MediaTypeHeaderValue.Parse(contentType);
}
}
foreach (var (name, value) in headers)
{
request.Headers.Add(name, value);
}
try try
{ {
@ -580,6 +572,16 @@ namespace SafeExamBrowser.Server
{ {
logger.Debug($"Completed request: {request.Method} '{request.RequestUri}' -> {response.ToLogString()}"); logger.Debug($"Completed request: {request.Method} '{request.RequestUri}' -> {response.ToLogString()}");
} }
if (response.StatusCode == HttpStatusCode.Unauthorized && parser.IsTokenExpired(response.Content))
{
logger.Info("OAuth2 token has expired, attempting to retrieve new one...");
if (TryRetrieveOAuth2Token(out var message))
{
headers = UpdateOAuth2Token(headers);
}
}
} }
catch (TaskCanceledException) catch (TaskCanceledException)
{ {
@ -594,5 +596,53 @@ namespace SafeExamBrowser.Server
return response != default && response.IsSuccessStatusCode; return response != default && response.IsSuccessStatusCode;
} }
private HttpRequestMessage BuildRequest(
HttpMethod method,
string url,
string content = default,
string contentType = default,
params (string name, string value)[] headers)
{
var request = new HttpRequestMessage(method, url);
if (content != default)
{
request.Content = new StringContent(content, Encoding.UTF8);
if (contentType != default)
{
request.Content.Headers.ContentType = MediaTypeHeaderValue.Parse(contentType);
}
}
request.Headers.Add("Accept", "application/json, */*");
foreach (var (name, value) in headers)
{
request.Headers.Add(name, value);
}
return request;
}
private (string name, string value)[] UpdateOAuth2Token((string name, string value)[] headers)
{
var result = new List<(string name, string value)>();
foreach (var header in headers)
{
if (header.name == "Authorization")
{
result.Add(("Authorization", $"Bearer {oauth2Token}"));
}
else
{
result.Add(header);
}
}
return result.ToArray();
}
} }
} }