From f164086490d31ba94f9ee27b70f5dcdd00a7ae1f Mon Sep 17 00:00:00 2001 From: Patchzy <64382339+patchzyy@users.noreply.github.com> Date: Wed, 8 Apr 2026 18:42:35 +0200 Subject: [PATCH] Sanitize download filenames and prevent traversal --- WheelWizard/Helpers/DownloadHelper.cs | 35 +++++++++++++++++-- .../Services/Installation/ModInstallation.cs | 7 ++-- WheelWizard/Views/App.axaml.cs | 24 ++++++++----- 3 files changed, 51 insertions(+), 15 deletions(-) diff --git a/WheelWizard/Helpers/DownloadHelper.cs b/WheelWizard/Helpers/DownloadHelper.cs index ccfc874c..007800c7 100644 --- a/WheelWizard/Helpers/DownloadHelper.cs +++ b/WheelWizard/Helpers/DownloadHelper.cs @@ -78,8 +78,9 @@ public static class DownloadHelper // Check for filename in Content-Disposition or fallback to URL var contentDisposition = response.Content.Headers.ContentDisposition; - var fileName = contentDisposition?.FileName?.Trim('"') ?? Path.GetFileName(new Uri(url).AbsolutePath); - fileName = Path.ChangeExtension(fileName, Path.GetExtension(finalUrl)); + var fileName = + contentDisposition?.FileNameStar ?? contentDisposition?.FileName ?? Path.GetFileName(new Uri(url).AbsolutePath); + fileName = GetSafeDownloadFileName(fileName, finalUrl, url); // Add extension if missing in file path if (!Path.HasExtension(fileName)) @@ -93,6 +94,7 @@ public static class DownloadHelper // Update resolvedFilePath with resolved fileName resolvedFilePath = Path.Combine(directory, fileName); + EnsurePathStaysWithinDirectory(resolvedFilePath, directory); } var totalBytes = response.Content.Headers.ContentLength ?? -1; @@ -199,4 +201,33 @@ public static class DownloadHelper progressPopupWindow.SetCancellationTokenSource(null); } } + + private static string GetSafeDownloadFileName(string fileName, string finalUrl, string originalUrl) + { + var trimmedName = fileName.Trim().Trim('"'); + if (string.IsNullOrWhiteSpace(trimmedName)) + { + trimmedName = Path.GetFileName(new Uri(originalUrl).AbsolutePath); + } + + // Only allow a basename from remote input, never a relative or absolute path. + var safeFileName = Path.GetFileName(trimmedName.Replace('\\', '/')); + if (string.IsNullOrWhiteSpace(safeFileName)) + throw new InvalidOperationException("The server returned an invalid download filename."); + + var finalExtension = Path.GetExtension(finalUrl); + if (!string.IsNullOrWhiteSpace(finalExtension)) + safeFileName = Path.ChangeExtension(safeFileName, finalExtension); + + return safeFileName; + } + + private static void EnsurePathStaysWithinDirectory(string path, string directory) + { + var normalizedDirectory = Path.GetFullPath(directory + Path.DirectorySeparatorChar); + var normalizedPath = Path.GetFullPath(path); + var comparison = OperatingSystem.IsWindows() ? StringComparison.OrdinalIgnoreCase : StringComparison.Ordinal; + if (!normalizedPath.StartsWith(normalizedDirectory, comparison)) + throw new InvalidOperationException("The download path escaped the target directory."); + } } diff --git a/WheelWizard/Services/Installation/ModInstallation.cs b/WheelWizard/Services/Installation/ModInstallation.cs index 5662bf8a..9b2dcdae 100644 --- a/WheelWizard/Services/Installation/ModInstallation.cs +++ b/WheelWizard/Services/Installation/ModInstallation.cs @@ -109,6 +109,7 @@ public static bool ModExists(ObservableCollection mods, string modName) => public static void ProcessFile(string file, string destinationDirectory, ProgressWindow progressWindow) { var extension = Path.GetExtension(file).ToLowerInvariant(); + var normalizedDestinationDirectory = Path.GetFullPath(destinationDirectory + Path.DirectorySeparatorChar); if (!Directory.Exists(destinationDirectory)) Directory.CreateDirectory(destinationDirectory); @@ -144,12 +145,10 @@ public static void ProcessFile(string file, string destinationDirectory, Progres ); var entryDestinationPath = Path.Combine(destinationDirectory, sanitizedKey); + var normalizedEntryDestinationPath = Path.GetFullPath(entryDestinationPath); // Ensure the entry destination path is within the destination directory - if ( - !Path.GetFullPath(entryDestinationPath) - .StartsWith(Path.GetFullPath(destinationDirectory), StringComparison.OrdinalIgnoreCase) - ) + if (!normalizedEntryDestinationPath.StartsWith(normalizedDestinationDirectory, StringComparison.OrdinalIgnoreCase)) { throw new UnauthorizedAccessException("Entry is attempting to extract outside of the destination directory."); } diff --git a/WheelWizard/Views/App.axaml.cs b/WheelWizard/Views/App.axaml.cs index bd1c454a..f17de132 100644 --- a/WheelWizard/Views/App.axaml.cs +++ b/WheelWizard/Views/App.axaml.cs @@ -5,11 +5,11 @@ using Microsoft.Extensions.Logging; using WheelWizard.AutoUpdating; using WheelWizard.MiiRendering.Services; -using WheelWizard.Settings; -using WheelWizard.Services.Launcher; using WheelWizard.Services; +using WheelWizard.Services.Launcher; using WheelWizard.Services.LiveData; using WheelWizard.Services.UrlProtocol; +using WheelWizard.Settings; using WheelWizard.Views.Behaviors; using WheelWizard.Views.Popups.Generic; using WheelWizard.WheelWizardData; @@ -79,15 +79,19 @@ private static StartupLaunchTarget GetStartupLaunchTarget() for (var i = 1; i < args.Length; i++) { var argument = args[i]; - if (argument.Equals("--launch", StringComparison.OrdinalIgnoreCase) || argument.Equals("-l", StringComparison.OrdinalIgnoreCase)) + if ( + argument.Equals("--launch", StringComparison.OrdinalIgnoreCase) || argument.Equals("-l", StringComparison.OrdinalIgnoreCase) + ) { if (i + 1 >= args.Length) continue; var launchTarget = args[++i]; - if (launchTarget.Equals("rr", StringComparison.OrdinalIgnoreCase) || - launchTarget.Equals("retrorewind", StringComparison.OrdinalIgnoreCase) || - launchTarget.Equals("retro-rewind", StringComparison.OrdinalIgnoreCase)) + if ( + launchTarget.Equals("rr", StringComparison.OrdinalIgnoreCase) + || launchTarget.Equals("retrorewind", StringComparison.OrdinalIgnoreCase) + || launchTarget.Equals("retro-rewind", StringComparison.OrdinalIgnoreCase) + ) { return StartupLaunchTarget.RetroRewind; } @@ -99,9 +103,11 @@ private static StartupLaunchTarget GetStartupLaunchTarget() continue; var launchTargetFromEquals = argument["--launch=".Length..].Trim(); - if (launchTargetFromEquals.Equals("rr", StringComparison.OrdinalIgnoreCase) || - launchTargetFromEquals.Equals("retrorewind", StringComparison.OrdinalIgnoreCase) || - launchTargetFromEquals.Equals("retro-rewind", StringComparison.OrdinalIgnoreCase)) + if ( + launchTargetFromEquals.Equals("rr", StringComparison.OrdinalIgnoreCase) + || launchTargetFromEquals.Equals("retrorewind", StringComparison.OrdinalIgnoreCase) + || launchTargetFromEquals.Equals("retro-rewind", StringComparison.OrdinalIgnoreCase) + ) { return StartupLaunchTarget.RetroRewind; }