Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 69 additions & 2 deletions seed/tests/test_media.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@

class TestMeasures(TestCase):
def setUp(self):
self.user_a = User.objects.create(username="user_a")
self.user_b = User.objects.create(username="user_b")
self.user_details_a = {"username": "user_a@test.com", "password": "test_pass_a", "email": "user_a@test.com"}
self.user_a = User.objects.create_user(**self.user_details_a)
self.user_details_b = {"username": "user_b@test.com", "password": "test_pass_b", "email": "user_b@test.com"}
self.user_b = User.objects.create_user(**self.user_details_b)
self.org_a = Organization.objects.create()
self.root_a = AccessLevelInstance.objects.get(organization_id=self.org_a, depth=1)
self.org_a_sub = Organization.objects.create()
Expand Down Expand Up @@ -303,3 +305,68 @@ def test_fails_when_path_does_not_match(self):
# test bad path
with pytest.raises(ModelForFileNotFoundError):
check_file_permission(self.user_a, "")

def test_retrieve_file_successfully(self):
"""Test that retrieve endpoint serves files correctly"""
# Setup - create an import file
import_record = ImportRecord.objects.create(
owner=self.user_a, last_modified_by=self.user_a, super_organization=self.org_a, access_level_instance=self.org_a.root
)
ImportFile.objects.create(
import_record=import_record, uploaded_filename=os.path.basename(self.uploads_file), file=self.absolute_uploads_file
)

self.client.login(**self.user_details_a)
response = self.client.get(f"/api/v3/media/{self.uploads_file}")

self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b"Hello world")
self.assertIn("Content-Disposition", response)
self.assertIn("attachment", response["Content-Disposition"])

def test_retrieve_file_with_comma_in_filename(self):
"""Test that files with commas in filename download correctly"""
# Setup - create a file with commas in the name
absolute_comma_file = get_upload_path("test,file,with,commas.txt")
comma_file = os.path.relpath(absolute_comma_file, settings.MEDIA_ROOT)
os.makedirs(os.path.dirname(absolute_comma_file), exist_ok=True)
with open(absolute_comma_file, "w", encoding=locale.getpreferredencoding(False)) as f:
f.write("File with commas")

import_record = ImportRecord.objects.create(
owner=self.user_a, last_modified_by=self.user_a, super_organization=self.org_a, access_level_instance=self.org_a.root
)
ImportFile.objects.create(import_record=import_record, uploaded_filename=os.path.basename(comma_file), file=absolute_comma_file)

self.client.login(**self.user_details_a)
response = self.client.get(f"/api/v3/media/{comma_file}")

self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b"File with commas")
# Verify the download filename is sanitized (commas removed)
self.assertIn("Content-Disposition", response)
self.assertIn("testfilewithcommas.txt", response["Content-Disposition"])
self.assertNotIn(",", response["Content-Disposition"])

def test_retrieve_file_without_permission_returns_404(self):
"""Test that users without permission get 404"""
# Setup - create an import file for org_a
import_record = ImportRecord.objects.create(
owner=self.user_a, last_modified_by=self.user_a, super_organization=self.org_a, access_level_instance=self.org_a.root
)
ImportFile.objects.create(
import_record=import_record, uploaded_filename=os.path.basename(self.uploads_file), file=self.absolute_uploads_file
)

# user_b tries to access org_a's file
self.client.login(**self.user_details_b)
response = self.client.get(f"/api/v3/media/{self.uploads_file}")

self.assertEqual(response.status_code, 404)

def test_retrieve_nonexistent_file_returns_404(self):
"""Test that requesting non-existent file returns 404"""
self.client.login(**self.user_details_a)
response = self.client.get("/api/v3/media/uploads/nonexistent.txt")

self.assertEqual(response.status_code, 404)
40 changes: 23 additions & 17 deletions seed/views/v3/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
"""

import logging
import mimetypes
import os
import re

from django.conf import settings
from django.http import HttpResponse
from django.utils.decorators import method_decorator
from django.utils.text import get_valid_filename
from rest_framework import generics

from seed.models import Analysis, AnalysisOutputFile, BuildingFile, ImportFile, InventoryDocument, Organization
Expand Down Expand Up @@ -114,20 +115,25 @@ def retrieve(self, request, filepath):
logger.debug(f"Failed to locate organization for file: {e!s}")
return HttpResponse(status=404)

if user_has_permission:
# Attempt to remove NamedTemporaryFile suffix
filename = os.path.basename(filepath)
name, ext = os.path.splitext(filename)
pattern = re.compile("(.*?)(_[a-zA-Z0-9]{7})$")
match = pattern.match(name)
if match:
filename = match.groups()[0] + ext

response = HttpResponse()
if ext != ".html":
response["Content-Disposition"] = f"attachment; filename={filename}"
response["X-Accel-Redirect"] = f"/protected/{filepath}"
return response
else:
# 404 instead of 403 to avoid leaking information
if not user_has_permission:
return HttpResponse(status=404)

filename = os.path.basename(filepath)
ext = os.path.splitext(filename)[1]
absolute_filepath = os.path.join(settings.MEDIA_ROOT, filepath)

if not os.path.exists(absolute_filepath):
return HttpResponse(status=404)

# Serve file through Django
with open(absolute_filepath, "rb") as f:
file_data = f.read()

content_type, _ = mimetypes.guess_type(filename)
response = HttpResponse(file_data, content_type=content_type or "application/octet-stream")

if ext != ".html":
safe_download_name = get_valid_filename(filename)
response["Content-Disposition"] = f'attachment; filename="{safe_download_name}"'

return response
4 changes: 3 additions & 1 deletion seed/views/v3/uploads.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from django.core.files.storage import FileSystemStorage
from django.http import JsonResponse
from django.utils.decorators import method_decorator
from django.utils.text import get_valid_filename
from drf_yasg.utils import no_body, swagger_auto_schema
from rest_framework import status, viewsets
from rest_framework.decorators import action
Expand Down Expand Up @@ -90,7 +91,8 @@ def create(self, request):
the_file = request.data["qqfile"]
else:
the_file = request.data["file"]
filename = the_file.name
# Sanitize filename to remove problematic characters (commas, semicolons, etc.)
filename = get_valid_filename(the_file.name)
path = get_upload_path(filename)

# verify the directory exists
Expand Down
Loading