class SecurityValidationError(AppException):
    """Error raised when incoming data fails a security validation check.

    Declares ``title`` and ``status_code = 400`` as class attributes.
    NOTE(review): presumably ``AppException`` (or the error handler built on
    it) reads these to build the HTTP error response -- confirm.

    Args:
        message: English description of the failure. The French and German
            texts handed to the base class are fixed; they do not change
            when a custom ``message`` is supplied.
    """

    status_code = 400
    title = "Security Validation Error"

    def __init__(self, message="Malformed or unsafe input detected."):
        # Translations are constant, so keep them together and forward them
        # to the base class as keyword arguments.
        translations = {
            "message_fr": "Entrée mal formée ou non sécurisée détectée.",
            "message_de": "Fehlerhafte oder unsichere Eingabe erkannt.",
        }
        super().__init__(message=message, **translations)
validate_ttl(ld_content, mandate_ttl): @@ -30,11 +31,13 @@ def validate_ttl(ld_content, mandate_ttl): def get_registration_dict(uri, rdf_graph): + # Upstream validation: Secure the URI before placing it in the SPARQL query string + safe_uri = validate_uri(uri) registration_query = ( "PREFIX discovery: " "SELECT DISTINCT ?created ?modified ?expires ?ttl " "WHERE {" - f" <{uri}> discovery:hasRegistrationInformation ?reg." + f" <{safe_uri}> discovery:hasRegistrationInformation ?reg." " OPTIONAL{?reg discovery:dateCreated ?created}" " OPTIONAL{?reg discovery:dateModified ?modified}" " OPTIONAL{?reg discovery:expires ?expires}" @@ -66,7 +69,9 @@ def get_registration_dict(uri, rdf_graph): def delete_registration_information(uri, rdf_graph): - rdf_graph.remove((URIRef(uri), TDD.hasRegistrationInformation, None)) + # Sanitize before processing + safe_uri = validate_uri(uri) + rdf_graph.remove((URIRef(safe_uri), TDD.hasRegistrationInformation, None)) rdf_graph.remove((None, TDD.dateCreated, None)) rdf_graph.remove((None, TDD.dateModified, None)) rdf_graph.remove((None, TDD.expires, None)) diff --git a/tdd/sparql.py b/tdd/sparql.py index a58fee3..e171605 100644 --- a/tdd/sparql.py +++ b/tdd/sparql.py @@ -18,7 +18,6 @@ import atexit from flask import Response - from .config import CONFIG from .errors import FusekiError @@ -223,15 +222,15 @@ def query( if route != "": sparqlendpoint = urljoin(f"{sparqlendpoint}/", route) + if request_type == "query": # Utilize the global HTTP client for connection pooling. - # Note: SPARQL injection mitigation must be handled upstream by explicit input validators. resp = http_client.post( sparqlendpoint, data={"query": querystring}, headers=headers, ) - if request_type == "update": + elif request_type == "update": if CONFIG["ENDPOINT_TYPE"] == "GRAPHDB": sparqlendpoint = urljoin(f"{sparqlendpoint}/", "statements") # Utilize the global HTTP client for update operations to maintain low latency. 
def delete_named_graph(named_graph):
    """Drop a named graph by sending a SPARQL update.

    Wraps ``named_graph`` in a ``DROP SILENT GRAPH <...>`` statement and
    submits it through ``query(..., request_type="update")``. The value
    returned by ``query`` is discarded, so this function returns ``None``.

    Args:
        named_graph: IRI of the graph to drop. It is interpolated into the
            statement as-is; no validation or escaping happens here.
            NOTE(review): callers are presumably expected to pass only
            trusted, internally sourced graph IRIs -- confirm at the call
            sites.
    """
    drop_statement = "DROP SILENT GRAPH <{}>".format(named_graph)
    query(drop_statement, request_type="update")
def get_td_description(id, content_type="application/td+json", context=None):
    """Return the stored description for ``id`` in the requested content type.

    The identifier is first passed through ``validate_uri``; only the
    validated value is handed to the lookup helpers.

    Args:
        id: Identifier (URI) of the Thing Description.
        content_type: Requested MIME type. Values that do not end in
            ``"json"`` are served straight from ``get_id_description``.
        context: Context used for framing. When falsy, it is looked up with
            ``get_context``.

    Returns:
        For non-JSON content types, whatever ``get_id_description`` returns.
        For JSON content types, the result of ``frame_td_nt_content``, or an
        empty string when that call raises ``ExpireTDError``.
    """
    safe_id = validate_uri(id)

    if content_type.endswith("json"):
        # JSON output is produced by fetching N-Triples and framing them.
        nt_content = get_id_description(safe_id, "application/n-triples", ONTOLOGY)
        framing_context = context if context else get_context(safe_id, ONTOLOGY)
        try:
            return frame_td_nt_content(safe_id, nt_content, framing_context)
        except ExpireTDError:
            return ""

    return get_id_description(safe_id, content_type, ONTOLOGY)
from args or the payload ID + uri = validate_uri(uri if uri is not None else td_content["id"]) td_content = use_custom_context(td_content) created_date = get_already_existing_td(uri) @@ -260,6 +267,23 @@ def put_td_json_in_sparql(td_content, uri=None, delete_if_exists=True): def delete_graphs(ids): + """ + Delete multiple graphs by their IDs. + + Args: + ids: List of graph IDs to delete + + Note: + This function is called with IDs from internal database queries + (e.g., expired TDs from clear_expired_td()). These IDs are trusted + internal values, not user input, so no external validation is needed. + + Applying validate_uri() here would be incorrect because: + 1. These URIs already passed validation when originally stored + 2. Legitimate stored URIs might contain characters outside the strict + allowlist (e.g., certain URN formats) + 3. Validation should only occur at the trust boundary (user input) + """ graph_ids_str = ", ".join([f"<{graph_id}>" for graph_id in ids]) delete_td_query = DELETE_GRAPHS.format(graph_ids_str=graph_ids_str) resp = query(delete_td_query, request_type="update") @@ -322,18 +346,43 @@ def get_total_number(): def get_paginated_tds(limit, offset, sort_by, sort_order): - all_tds = [] + """ + Get a paginated list of Thing Descriptions. + + Args: + limit (int): Maximum number of TDs to return (pre-validated at controller layer) + offset (int): Offset for pagination (pre-validated at controller layer) + sort_by (str): Field to sort by (pre-validated at controller layer) + sort_order (str): Sort direction "ASC" or "DESC" (pre-validated at controller layer) + + Returns: + List[dict]: List of Thing Description dictionaries in the order specified by SPARQL query + + Note: + All parameters are assumed to be pre-validated and type-converted at the + controller layer (__init__.py). No redundant validation is performed here. + + Thread Safety: + Uses ThreadPoolExecutor for concurrent TD retrieval. 
Results are collected + in the main thread in the original task submission order to preserve the + SPARQL ORDER BY sequence. + """ tasks = [] def send_request(id, context): - td = get_td_description(id, context=context) - all_tds.append(td) + """ + Fetch a single TD description. + + Returns the TD instead of appending to a shared list for thread safety. + """ + return get_td_description(id, context=context) contexts = get_all_contexts() if sort_by is not None and sort_by not in ORDERBY: raise OrderbyError(sort_by) + # No redundant validation - parameters already validated in __init__.py resp = query( GET_URI_BY_ONTOLOGY.format( limit=limit, @@ -366,6 +415,10 @@ def send_request(id, context): contexts[result["graph"]["value"]], ) ) + # Wait for all tasks to complete in submission order to preserve SPARQL ORDER BY + all_tds = [] + for task in tasks: + all_tds.append(task.result()) return all_tds diff --git a/tdd/tests/test_validators.py b/tdd/tests/test_validators.py new file mode 100644 index 0000000..1c3ed40 --- /dev/null +++ b/tdd/tests/test_validators.py @@ -0,0 +1,495 @@ +"""****************************************************************************** +* Copyright (c) 2018 Contributors to the Eclipse Foundation +* +* See the NOTICE file(s) distributed with this work for additional +* information regarding copyright ownership. +* +* This program and the accompanying materials are made available under the +* terms of the Eclipse Public License v. 2.0 which is available at +* http://www.eclipse.org/legal/epl-2.0, or the W3C Software Notice and +* Document License (2015-05-13) which is available at +* https://www.w3.org/Consortium/Legal/2015/copyright-software-and-document. +* +* SPDX-License-Identifier: EPL-2.0 OR W3C-20150513 +******************************************************************************** + +Unit tests for security validators module. 
+ +These tests ensure that the validation layer correctly blocks SPARQL injection +attempts while allowing legitimate URIs and parameters to pass through. +""" + +import pytest +from tdd.validators import validate_uri, validate_sort_order, validate_uris +from tdd.errors import SecurityValidationError + + +class TestValidateUri: + """Test suite for URI validation against SPARQL injection.""" + + def test_valid_http_uris(self): + """Test that valid HTTP/HTTPS URIs pass validation.""" + valid_uris = [ + "https://example.com/td/1", + "http://localhost:3030/things", + "https://www.w3.org/2019/wot/td", + "http://example.com:8080/path/to/resource", + ] + for uri in valid_uris: + assert validate_uri(uri) == uri + + def test_valid_urn_uris(self): + """Test that valid URN URIs pass validation.""" + valid_urns = [ + "urn:uuid:12345678-1234-5678-1234-567812345678", + "urn:dev:ops:my-thing-1234", + "urn:example:animal:ferret:nose", + ] + for urn in valid_urns: + assert validate_uri(urn) == urn + + def test_valid_percent_encoded_uris(self): + """Test that percent-encoded URIs pass validation.""" + valid_encoded = [ + "http://example.com/path%20with%20spaces", + "http://example.com/query?name=John%20Doe", + "urn:uuid:test%2Fslash", + ] + for uri in valid_encoded: + assert validate_uri(uri) == uri + + def test_uri_with_query_parameters(self): + """Test that URIs with query parameters pass validation.""" + uri = "http://example.com/path?query=value&foo=bar&baz=123" + assert validate_uri(uri) == uri + + def test_uri_with_fragment(self): + """Test that URIs with fragments pass validation.""" + uri = "http://example.com/path#section" + assert validate_uri(uri) == uri + + def test_uri_with_special_allowed_chars(self): + """Test that URIs with RFC 3986 allowed special characters pass.""" + uri = "http://example.com/path!$&'()*+,;=test" + assert validate_uri(uri) == uri + + def test_reject_uri_with_angle_brackets(self): + """Test that URIs containing angle brackets are rejected (SPARQL 
injection risk).""" + malicious_uris = [ + "http://example.com/", + ] + + for dangerous_input in dangerous_uris: + try: + validate_uri(dangerous_input) + pytest.fail( + f"Should have raised SecurityValidationError for: {dangerous_input}" + ) + except SecurityValidationError as e: + # Critical: verify the dangerous input is NOT in the error message + assert dangerous_input not in e.message, ( + f"SECURITY VULNERABILITY: Error message leaked user input. " + f"Message '{e.message}' contains '{dangerous_input}'" + ) + # Verify it's the expected generic message + assert e.message == "Malformed or unsafe URI detected." + + +class TestLogSecurity: + """Test suite to verify that logs do not leak sensitive user input.""" + + def test_uri_validation_logs_do_not_contain_raw_input(self, caplog): + """ + Test that log entries include fingerprint metadata, never raw malicious input. + + This prevents: + 1. Log injection attacks (e.g., newlines corrupting log structure) + 2. Information leakage through log files + """ + dangerous_uris = [ + "http://example.com/\nINJECTED_LOG_ENTRY", + "urn:test> } ; DROP GRAPH ", + "http://test.com/", + ] + + for dangerous_uri in dangerous_uris: + caplog.clear() + + try: + validate_uri(dangerous_uri) + except SecurityValidationError: + pass # Expected + + # Verify log was created + assert len(caplog.records) == 1 + log_message = caplog.records[0].message + + # Critical: raw dangerous input should NOT be in the log + assert dangerous_uri not in log_message, ( + f"SECURITY ISSUE: Log contains raw malicious input. " + f"Log: '{log_message}' contains '{dangerous_uri}'" + ) + + # Verify log contains safe metadata only + assert "fingerprint=" in log_message + assert "length=" in log_message + + def test_sort_order_validation_logs_do_not_contain_raw_input(self, caplog): + """ + Test that sort_order validation logs use fingerprint metadata and don't leak raw input. 
+ """ + dangerous_inputs = [ + "ASC\n; DROP GRAPH ", + "DESC; DELETE WHERE { ?s ?p ?o }", + "UNION\r\nINJECTED_LOG", + ] + + for dangerous_input in dangerous_inputs: + caplog.clear() + + try: + validate_sort_order(dangerous_input) + except SecurityValidationError: + pass # Expected + + # Verify log was created + assert len(caplog.records) == 1 + log_message = caplog.records[0].message + + # Critical: raw dangerous input should NOT be in the log + assert dangerous_input not in log_message, ( + f"SECURITY ISSUE: Log contains raw malicious input. " + f"Log: '{log_message}' contains '{dangerous_input}'" + ) + + # Verify log contains safe metadata only + assert "fingerprint=" in log_message + assert "length=" in log_message + + def test_log_truncation_prevents_flooding(self, caplog): + """ + Test that extremely long malicious URIs are logged without raw content. + + This prevents log flooding attacks where attackers send very long + inputs to fill up disk space or make logs unreadable. + """ + # Create a very long malicious URI (1000 characters) + long_malicious_uri = "http://example.com/" + "A" * 1000 + "" + + caplog.clear() + + try: + validate_uri(long_malicious_uri) + except SecurityValidationError: + pass # Expected + + assert len(caplog.records) == 1 + log_message = caplog.records[0].message + + # Verify the full malicious URI is NOT in the log + assert long_malicious_uri not in log_message + + # The log should contain fixed-size safe metadata instead of snippets + assert "fingerprint=" in log_message + assert "length=" in log_message + + def test_non_string_type_logged_safely(self, caplog): + """ + Test that non-string types are logged as type names, not repr of content. + + This prevents potential issues with logging complex objects. 
# Module-level logger used for security audit messages.
logger = logging.getLogger(__name__)

# Character allowlist for URIs: unreserved and reserved characters from
# RFC 3986 plus '%' for percent-encoding. Structural SPARQL characters
# ('<', '>', '{', '}', '^', '`', '|', '\\', whitespace) are not in the class,
# so a validated value cannot close the <...> wrapper it is placed in.
# Anchored with \A ... \Z rather than ^ ... $: in Python "$" also matches just
# before a trailing newline, which would let "http://x/\n" through.
URI_REGEX = re.compile(r"\A[a-zA-Z0-9\-._~:/?#\[\]@!$&'()*+,;=%]+\Z")

# Sort directions accepted by validate_sort_order (compared after upper-casing).
_ALLOWED_SORT_ORDERS = ("ASC", "DESC")


def _input_fingerprint(value: str) -> str:
    """Return the first 12 hex digits of the SHA-256 digest of ``value``.

    Used in log messages so that rejected input can be correlated across
    entries without writing the input itself to the log. Characters that
    cannot be encoded as UTF-8 are replaced instead of raising.
    """
    return hashlib.sha256(value.encode("utf-8", "replace")).hexdigest()[:12]


def validate_uri(uri: str) -> str:
    """Check ``uri`` against the character allowlist and return it unchanged.

    Args:
        uri: Candidate URI.

    Returns:
        ``uri`` itself when it is a non-empty string made up only of
        characters permitted by ``URI_REGEX``.

    Raises:
        SecurityValidationError: If ``uri`` is not a string, is empty, or
            contains any character outside the allowlist (including a
            trailing newline). The message is generic and never echoes the
            input.
    """
    # fullmatch() requires the whole string to be consumed by the pattern.
    if isinstance(uri, str) and URI_REGEX.fullmatch(uri):
        return uri

    # Log metadata only (fingerprint/length or type name), never the raw value.
    if isinstance(uri, str):
        logger.warning(
            "SECURITY ALERT: Malformed or unsafe URI blocked. fingerprint=%s length=%d",
            _input_fingerprint(uri),
            len(uri),
        )
    else:
        logger.warning(
            "SECURITY ALERT: Malformed or unsafe URI blocked. type=%s",
            type(uri).__name__,
        )
    raise SecurityValidationError("Malformed or unsafe URI detected.")


def validate_uris(uris: List[str]) -> List[str]:
    """Validate every element of ``uris`` with ``validate_uri``.

    Args:
        uris: List of candidate URIs.

    Returns:
        A new list with the same elements, in the same order.

    Raises:
        SecurityValidationError: If ``uris`` is not a ``list`` or any element
            fails ``validate_uri``.
    """
    if not isinstance(uris, list):
        logger.warning(
            "SECURITY ALERT: Expected a list of URIs, received different type."
        )
        raise SecurityValidationError("Expected a list of URIs.")
    return [validate_uri(u) for u in uris]


def validate_sort_order(sort_order: Optional[str]) -> Optional[str]:
    """Validate and normalize a sort direction against a strict allowlist.

    Surrounding whitespace is stripped and the value is upper-cased before
    it is compared with ``"ASC"`` and ``"DESC"``.

    Args:
        sort_order: Requested direction, e.g. ``"asc"`` or ``"DESC"``.

    Returns:
        ``"ASC"`` or ``"DESC"``, or ``None`` when the input is ``None``,
        empty, or whitespace only.

    Raises:
        SecurityValidationError: If the value is not a string or is anything
            other than an allowed direction.

    Examples:
        >>> validate_sort_order("asc")
        'ASC'
        >>> validate_sort_order(" DESC ")
        'DESC'
        >>> validate_sort_order(None) is None
        True
        >>> validate_sort_order("") is None
        True
    """
    if not sort_order:
        return None

    # Reject non-strings explicitly instead of failing with AttributeError on
    # .strip(); mirrors how validate_uri treats non-string input.
    if not isinstance(sort_order, str):
        logger.warning(
            "SECURITY ALERT: Invalid sort order blocked. type=%s",
            type(sort_order).__name__,
        )
        raise SecurityValidationError("Invalid sort order.")

    normalized_order = sort_order.strip().upper()

    # Whitespace-only input is treated the same as "not provided".
    if not normalized_order:
        return None

    if normalized_order not in _ALLOWED_SORT_ORDERS:
        # Fingerprint the original (un-normalized) value; never log it raw.
        logger.warning(
            "SECURITY ALERT: Invalid sort order blocked. fingerprint=%s length=%d",
            _input_fingerprint(sort_order),
            len(sort_order),
        )
        raise SecurityValidationError("Invalid sort order.")
    return normalized_order