diff --git a/seed/tests/test_media.py b/seed/tests/test_media.py index 09eef6fb60..9427f8f89e 100644 --- a/seed/tests/test_media.py +++ b/seed/tests/test_media.py @@ -23,8 +23,10 @@ class TestMeasures(TestCase): def setUp(self): - self.user_a = User.objects.create(username="user_a") - self.user_b = User.objects.create(username="user_b") + self.user_details_a = {"username": "user_a@test.com", "password": "test_pass_a", "email": "user_a@test.com"} + self.user_a = User.objects.create_user(**self.user_details_a) + self.user_details_b = {"username": "user_b@test.com", "password": "test_pass_b", "email": "user_b@test.com"} + self.user_b = User.objects.create_user(**self.user_details_b) self.org_a = Organization.objects.create() self.root_a = AccessLevelInstance.objects.get(organization_id=self.org_a, depth=1) self.org_a_sub = Organization.objects.create() @@ -303,3 +305,68 @@ def test_fails_when_path_does_not_match(self): # test bad path with pytest.raises(ModelForFileNotFoundError): check_file_permission(self.user_a, "") + + def test_retrieve_file_successfully(self): + """Test that retrieve endpoint serves files correctly""" + # Setup - create an import file + import_record = ImportRecord.objects.create( + owner=self.user_a, last_modified_by=self.user_a, super_organization=self.org_a, access_level_instance=self.org_a.root + ) + ImportFile.objects.create( + import_record=import_record, uploaded_filename=os.path.basename(self.uploads_file), file=self.absolute_uploads_file + ) + + self.client.login(**self.user_details_a) + response = self.client.get(f"/api/v3/media/{self.uploads_file}") + + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b"Hello world") + self.assertIn("Content-Disposition", response) + self.assertIn("attachment", response["Content-Disposition"]) + + def test_retrieve_file_with_comma_in_filename(self): + """Test that files with commas in filename download correctly""" + # Setup - create a file with commas in the name + absolute_comma_file = get_upload_path("test,file,with,commas.txt") + comma_file = os.path.relpath(absolute_comma_file, settings.MEDIA_ROOT) + os.makedirs(os.path.dirname(absolute_comma_file), exist_ok=True) + with open(absolute_comma_file, "w", encoding=locale.getpreferredencoding(False)) as f: + f.write("File with commas") + + import_record = ImportRecord.objects.create( + owner=self.user_a, last_modified_by=self.user_a, super_organization=self.org_a, access_level_instance=self.org_a.root + ) + ImportFile.objects.create(import_record=import_record, uploaded_filename=os.path.basename(comma_file), file=absolute_comma_file) + + self.client.login(**self.user_details_a) + response = self.client.get(f"/api/v3/media/{comma_file}") + + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b"File with commas") + # Verify the download filename is sanitized (commas removed) + self.assertIn("Content-Disposition", response) + self.assertIn("testfilewithcommas.txt", response["Content-Disposition"]) + self.assertNotIn(",", response["Content-Disposition"]) + + def test_retrieve_file_without_permission_returns_404(self): + """Test that users without permission get 404""" + # Setup - create an import file for org_a + import_record = ImportRecord.objects.create( + owner=self.user_a, last_modified_by=self.user_a, super_organization=self.org_a, access_level_instance=self.org_a.root + ) + ImportFile.objects.create( + import_record=import_record, uploaded_filename=os.path.basename(self.uploads_file), file=self.absolute_uploads_file + ) + + # user_b tries to access org_a's file + self.client.login(**self.user_details_b) + response = self.client.get(f"/api/v3/media/{self.uploads_file}") + + self.assertEqual(response.status_code, 404) + + def test_retrieve_nonexistent_file_returns_404(self): + """Test that requesting non-existent file returns 404""" + self.client.login(**self.user_details_a) + response = self.client.get("/api/v3/media/uploads/nonexistent.txt") + + self.assertEqual(response.status_code, 404) diff --git a/seed/views/v3/media.py b/seed/views/v3/media.py index 207a92d852..13db56ac8a 100644 --- a/seed/views/v3/media.py +++ b/seed/views/v3/media.py @@ -4,12 +4,13 @@ """ import logging +import mimetypes import os -import re from django.conf import settings from django.http import HttpResponse from django.utils.decorators import method_decorator +from django.utils.text import get_valid_filename from rest_framework import generics from seed.models import Analysis, AnalysisOutputFile, BuildingFile, ImportFile, InventoryDocument, Organization @@ -114,20 +115,25 @@ def retrieve(self, request, filepath): logger.debug(f"Failed to locate organization for file: {e!s}") return HttpResponse(status=404) - if user_has_permission: - # Attempt to remove NamedTemporaryFile suffix - filename = os.path.basename(filepath) - name, ext = os.path.splitext(filename) - pattern = re.compile("(.*?)(_[a-zA-Z0-9]{7})$") - match = pattern.match(name) - if match: - filename = match.groups()[0] + ext - - response = HttpResponse() - if ext != ".html": - response["Content-Disposition"] = f"attachment; filename={filename}" - response["X-Accel-Redirect"] = f"/protected/{filepath}" - return response - else: - # 404 instead of 403 to avoid leaking information + if not user_has_permission: return HttpResponse(status=404) + + filename = os.path.basename(filepath) + ext = os.path.splitext(filename)[1] + absolute_filepath = os.path.join(settings.MEDIA_ROOT, filepath) + + if not os.path.exists(absolute_filepath): + return HttpResponse(status=404) + + # Serve file through Django + with open(absolute_filepath, "rb") as f: + file_data = f.read() + + content_type, _ = mimetypes.guess_type(filename) + response = HttpResponse(file_data, content_type=content_type or "application/octet-stream") + + if ext != ".html": + safe_download_name = get_valid_filename(filename) + response["Content-Disposition"] = f'attachment; filename="{safe_download_name}"' + + return response diff --git a/seed/views/v3/uploads.py b/seed/views/v3/uploads.py index b53697edcd..b5f215a168 100644 --- a/seed/views/v3/uploads.py +++ b/seed/views/v3/uploads.py @@ -14,6 +14,7 @@ from django.core.files.storage import FileSystemStorage from django.http import JsonResponse from django.utils.decorators import method_decorator +from django.utils.text import get_valid_filename from drf_yasg.utils import no_body, swagger_auto_schema from rest_framework import status, viewsets from rest_framework.decorators import action @@ -90,7 +91,8 @@ def create(self, request): the_file = request.data["qqfile"] else: the_file = request.data["file"] - filename = the_file.name + # Sanitize filename to remove problematic characters (commas, semicolons, etc.) + filename = get_valid_filename(the_file.name) path = get_upload_path(filename) # verify the directory exists