diff --git a/docs/dev/taint_analysis.md b/docs/dev/taint_analysis.md new file mode 100644 index 00000000..7a4e0777 --- /dev/null +++ b/docs/dev/taint_analysis.md @@ -0,0 +1,94 @@ +# Taint Analysis - Backend Security + +Mellea backends implement thread security using the **SecLevel** model with capability-based access control and taint tracking. Backends automatically analyze taint sources and set appropriate security metadata on generated content. + +## Security Model + +The security system uses three types of security levels: + +```python +SecLevel := None | Classified of AccessType | TaintedBy of (CBlock | Component) +``` + +- **SecLevel.none()**: Safe content with no restrictions +- **SecLevel.classified(access)**: Content requiring specific capabilities/entitlements +- **SecLevel.tainted_by(source)**: Content tainted by a specific CBlock or Component + +## Backend Implementation + +All backends follow the same pattern using `ModelOutputThunk.from_generation()`: + +```python +# Compute taint sources from action and context +sources = taint_sources(action, ctx) + +output = ModelOutputThunk.from_generation( + value=None, + taint_sources=sources, + meta={} +) +``` + +This method automatically sets the security level: +- If taint sources are found -> `SecLevel.tainted_by(first_source)` +- If no taint sources -> `SecLevel.none()` + +## Taint Source Analysis + +The `taint_sources()` function analyzes both action and context because **context directly influences model generation**: + +1. **Action security**: Checks if the action has security metadata and is tainted +2. **Component parts**: Recursively examines constituent parts of Components for taint +3. **Context security**: Examines recent context items for tainted content (shallow check) + +**Example**: Even if the current action is safe, tainted context can influence the generated output. + +```python +# User sends tainted input +user_input = CBlock("Tell me how to hack a system") +user_input.mark_tainted() +ctx = ctx.add(user_input) + +# Safe action in tainted context +safe_action = CBlock("Explain general security concepts") + +# Generation finds tainted context +sources = taint_sources(safe_action, ctx) # Finds tainted user_input +# Model output will be influenced by the tainted context +``` + +## Security Metadata + +The `SecurityMetadata` class wraps `SecLevel` for integration with content blocks: + +```python +class SecurityMetadata: + def __init__(self, sec_level: SecLevel): + self.sec_level = sec_level + + def is_tainted(self) -> bool: + return self.sec_level.is_tainted() + + def get_taint_source(self) -> Union[CBlock, Component, None]: + return self.sec_level.get_taint_source() +``` + +Content can be marked as tainted: + +```python +component = CBlock("user input") +component.mark_tainted() # Sets SecLevel.tainted_by(component) + +if component._meta["_security"].is_tainted(): + print(f"Content tainted by: {component._meta['_security'].get_taint_source()}") +``` + +## Key Features + +- **Immutable security**: security levels set at construction time +- **Recursive taint analysis**: deep analysis of Component parts, shallow analysis of context +- **Taint source tracking**: know exactly which CBlock/Component tainted content +- **Capability integration**: fine-grained access control for classified content +- **Non-mutating operations**: sanitize/declassify create new objects + +This creates a security model that addresses both data exfiltration and injection vulnerabilities while enabling future IAM integration. \ No newline at end of file diff --git a/docs/examples/security/taint_example.py b/docs/examples/security/taint_example.py new file mode 100644 index 00000000..ef51ded8 --- /dev/null +++ b/docs/examples/security/taint_example.py @@ -0,0 +1,42 @@ +from mellea.stdlib.base import CBlock +from mellea.stdlib.session import MelleaSession +from mellea.backends.ollama import OllamaModelBackend +from mellea.security import privileged, SecurityError + +# Create tainted content +tainted_desc = CBlock("Process this sensitive data") +tainted_desc.mark_tainted() + +print(f"Original CBlock is tainted: {not tainted_desc.is_safe()}") + +# Create session +session = MelleaSession(OllamaModelBackend("llama3.2")) + +# Use tainted CBlock in session.instruct +print("Testing session.instruct with tainted CBlock...") +result = session.instruct( + description=tainted_desc, +) + +# The result should be tainted +print(f"Result is tainted: {not result.is_safe()}") +if not result.is_safe(): + taint_source = result._meta['_security'].get_taint_source() + print(f"Taint source: {taint_source}") + print("✅ SUCCESS: Taint preserved!") +else: + print("❌ FAIL: Result should be tainted but isn't!") + +# Mock privileged function that requires safe input +@privileged +def process_safe_data(data: CBlock) -> str: + """A function that requires safe (non-tainted) input.""" + return f"Processed: {data.value}" + +print("\nTesting privileged function with tainted result...") +try: + # This should raise a SecurityError + processed = process_safe_data(result) + print("❌ FAIL: Should have raised SecurityError!") +except SecurityError as e: + print(f"✅ SUCCESS: SecurityError raised - {e}") \ No newline at end of file diff --git a/mellea/backends/litellm.py b/mellea/backends/litellm.py index 89b61536..e1e7bbc9 100644 --- a/mellea/backends/litellm.py +++ b/mellea/backends/litellm.py @@ -37,6 +37,7 @@ ModelOutputThunk, ModelToolCall, ) +from mellea.security import taint_sources from mellea.stdlib.chat import Message from mellea.stdlib.requirement import ALoraRequirement @@ -310,7 +311,14 @@ def _generate_from_chat_context_standard( **model_specific_options, ) - output = ModelOutputThunk(None) + # Compute taint sources from action and context + sources = taint_sources(action, ctx) + + output = ModelOutputThunk.from_generation( + value=None, + taint_sources=sources, + meta={} + ) output._context = linearized_context output._action = action output._model_options = model_opts diff --git a/mellea/backends/ollama.py b/mellea/backends/ollama.py index 86c4509b..f9fc21e8 100644 --- a/mellea/backends/ollama.py +++ b/mellea/backends/ollama.py @@ -34,6 +34,7 @@ ModelOutputThunk, ModelToolCall, ) +from mellea.security import taint_sources from mellea.stdlib.chat import Message from mellea.stdlib.requirement import ALoraRequirement @@ -354,7 +355,14 @@ def generate_from_chat_context( format=_format.model_json_schema() if _format is not None else None, ) # type: ignore - output = ModelOutputThunk(None) + # Compute taint sources from action and context + sources = taint_sources(action, ctx) + + output = ModelOutputThunk.from_generation( + value=None, + taint_sources=sources, + meta={} + ) output._context = linearized_context output._action = action output._model_options = model_opts @@ -433,11 +441,16 @@ async def get_response(): result = None error = None if isinstance(response, BaseException): - result = ModelOutputThunk(value="") + result = ModelOutputThunk.from_generation( + value="", + taint_sources=taint_sources(actions[i], None), + meta={} + ) error = response else: - result = ModelOutputThunk( + result = ModelOutputThunk.from_generation( value=response.response, + taint_sources=taint_sources(actions[i], None), meta={"generate_response": response.model_dump()}, ) diff --git a/mellea/backends/openai.py b/mellea/backends/openai.py index 39b026a8..123555fd 100644 --- a/mellea/backends/openai.py +++ b/mellea/backends/openai.py @@ -47,6 +47,7 @@ GenerateType, ModelOutputThunk, ) +from mellea.security import taint_sources from mellea.stdlib.chat import Message from mellea.stdlib.requirement import ALoraRequirement, LLMaJRequirement, Requirement @@ -515,7 +516,14 @@ def _generate_from_chat_context_standard( ), ) # type: ignore - output = ModelOutputThunk(None) + # Compute taint sources from action and context + sources = taint_sources(action, ctx) + + output = ModelOutputThunk.from_generation( + value=None, + taint_sources=sources, + meta={} + ) output._context = linearized_context output._action = action output._model_options = model_opts @@ -690,11 +698,12 @@ def _generate_from_raw( assert isinstance(completion_response, Completion) results = [ - ModelOutputThunk( + ModelOutputThunk.from_generation( value=response.text, + taint_sources=taint_sources(actions[i], None), meta={"oai_completion_response": response.model_dump()}, ) - for response in completion_response.choices + for i, response in enumerate(completion_response.choices) ] for i, result in enumerate(results): diff --git a/mellea/security/__init__.py b/mellea/security/__init__.py new file mode 100644 index 00000000..68aae4ef --- /dev/null +++ b/mellea/security/__init__.py @@ -0,0 +1,25 @@ +"""Security module for mellea. + +This module provides security features for tracking and managing the security +level of content blocks and components in the mellea library. +""" + +from .core import ( + AccessType, + SecLevel, + SecurityMetadata, + SecurityError, + privileged, + declassify, + taint_sources, +) + +__all__ = [ + "AccessType", + "SecLevel", + "SecurityMetadata", + "SecurityError", + "privileged", + "declassify", + "taint_sources", +] \ No newline at end of file diff --git a/mellea/security/core.py b/mellea/security/core.py new file mode 100644 index 00000000..136d0c15 --- /dev/null +++ b/mellea/security/core.py @@ -0,0 +1,297 @@ +"""Core security functionality for mellea. + +This module provides the fundamental security classes and functions for +tracking security levels of content blocks and enforcing security policies. +""" + +import abc +import functools +from typing import Any, Callable, Generic, TypeVar, Union + +from mellea.stdlib.base import CBlock, Component + + +T = TypeVar('T') + + +class AccessType(Generic[T], abc.ABC): + """Abstract base class for access-based security. + + This trait allows integration with IAM systems and provides fine-grained + access control based on entitlements rather than coarse security levels. + """ + + @abc.abstractmethod + def has_access(self, entitlement: T | None) -> bool: + """Check if the given entitlement has access. + + Args: + entitlement: The entitlement to check (e.g., user role, IAM identifier) + + Returns: + True if the entitlement has access, False otherwise + """ + pass + + +class SecLevel(Generic[T]): + """Security level with access-based control and taint tracking. + + SecLevel := None | Classified of AccessType | TaintedBy of (CBlock | Component) + """ + + def __init__(self, level_type: str, data: Any = None): + """Initialize security level. + + Args: + level_type: Type of security level ("none", "classified", "tainted_by") + data: Associated data (AccessType for classified, CBlock/Component for tainted_by) + """ + self.level_type = level_type + self.data = data + + @classmethod + def none(cls) -> "SecLevel": + """Create a SecLevel with no restrictions (safe).""" + return cls("none") + + @classmethod + def classified(cls, access_type: AccessType[T]) -> "SecLevel": + """Create a SecLevel with classified access requirements.""" + return cls("classified", access_type) + + @classmethod + def tainted_by(cls, source: Union[CBlock, Component]) -> "SecLevel": + """Create a SecLevel tainted by a specific CBlock or Component.""" + return cls("tainted_by", source) + + def is_safe(self, entitlement: T | None = None) -> bool: + """Check if this security level is safe for the given entitlement. + + Args: + entitlement: The entitlement to check access for + + Returns: + True if safe, False if restricted + """ + if self.level_type == "none": + return True + elif self.level_type == "classified": + if self.data is None: + return False + return self.data.has_access(entitlement) + elif self.level_type == "tainted_by": + return False # Tainted content is never safe + else: + return False + + def is_tainted(self) -> bool: + """Check if this security level represents tainted content. + + Returns: + True if tainted, False otherwise + """ + return self.level_type == "tainted_by" + + def is_classified(self) -> bool: + """Check if this security level represents classified content. + + Returns: + True if classified, False otherwise + """ + return self.level_type == "classified" + + def get_taint_source(self) -> Union[CBlock, Component, None]: + """Get the source of taint if this is a tainted level. + + Returns: + The CBlock or Component that tainted this content, or None + """ + if self.level_type == "tainted_by": + return self.data + return None + + +class SecurityMetadata: + """Metadata for tracking security properties of content blocks.""" + + def __init__(self, sec_level: SecLevel): + """Initialize security metadata with a SecLevel. + + Args: + sec_level: The security level for this content + """ + self.sec_level = sec_level + + def is_safe(self, entitlement: Any = None) -> bool: + """Check if this security level is safe for the given entitlement. + + Args: + entitlement: The entitlement to check access for + + Returns: + True if safe, False if restricted + """ + return self.sec_level.is_safe(entitlement) + + def is_tainted(self) -> bool: + """Check if this security level represents tainted content. + + Returns: + True if tainted, False otherwise + """ + return self.sec_level.is_tainted() + + def is_classified(self) -> bool: + """Check if this security level represents classified content. + + Returns: + True if classified, False otherwise + """ + return self.sec_level.is_classified() + + def get_taint_source(self) -> Union[CBlock, Component, None]: + """Get the source of taint if this is a tainted level. + + Returns: + The CBlock or Component that tainted this content, or None + """ + return self.sec_level.get_taint_source() + + +class SecurityError(Exception): + """Exception raised for security-related errors.""" + pass + + +def taint_sources(action: Union[Component, CBlock], ctx: Any) -> list[Union[CBlock, Component]]: + """Compute taint sources from action and context. + + This function examines the action and context to determine what + security sources might be present. It performs recursive analysis + of Component parts and shallow analysis of context to identify + potential taint sources and returns the actual objects that are tainted. + + Args: + action: The action component or content block + ctx: The context containing previous interactions + + Returns: + List of tainted CBlocks or Components + """ + sources = [] + + # Check if action has security metadata and is tainted + if hasattr(action, '_meta') and '_security' in action._meta: + security_meta = action._meta['_security'] + if isinstance(security_meta, SecurityMetadata) and security_meta.is_tainted(): + sources.append(action) + + # For Components, check their constituent parts for taint + if hasattr(action, 'parts'): + try: + parts = action.parts() + for part in parts: + if hasattr(part, '_meta') and '_security' in part._meta: + security_meta = part._meta['_security'] + if isinstance(security_meta, SecurityMetadata) and security_meta.is_tainted(): + sources.append(part) + except Exception: + # If parts() fails, continue without it + pass + + # Check context for tainted content (shallow check) + if hasattr(ctx, 'as_list'): + try: + context_items = ctx.as_list(last_n_components=5) # Limit to recent items + for item in context_items: + if hasattr(item, '_meta') and '_security' in item._meta: + security_meta = item._meta['_security'] + if isinstance(security_meta, SecurityMetadata) and security_meta.is_tainted(): + sources.append(item) + except Exception: + # If context analysis fails, continue without it + pass + + return sources + + +F = TypeVar('F', bound=Callable[..., Any]) + + +def privileged(func: F) -> F: + """Decorator to mark functions that require safe (non-tainted, non-classified) input. + + Functions decorated with @privileged will raise SecurityError if + called with tainted or classified content blocks. + + Args: + func: The function to decorate + + Returns: + The decorated function + + Raises: + SecurityError: If the function is called with tainted or classified content + """ + @functools.wraps(func) + def wrapper(*args, **kwargs): + # Check all arguments for unsafe content (tainted or classified) + for arg in args: + if isinstance(arg, CBlock) and hasattr(arg, '_meta') and '_security' in arg._meta: + security_meta = arg._meta['_security'] + if isinstance(security_meta, SecurityMetadata) and not security_meta.is_safe(): + if security_meta.is_tainted(): + taint_source = security_meta.get_taint_source() + source_info = f" (tainted by: {type(taint_source).__name__})" if taint_source else "" + raise SecurityError( + f"Function {func.__name__} requires safe input, but received " + f"tainted content{source_info}" + ) + elif security_meta.is_classified(): + raise SecurityError( + f"Function {func.__name__} requires safe input, but received " + f"classified content" + ) + + # Check keyword arguments for unsafe content (tainted or classified) + for key, value in kwargs.items(): + if isinstance(value, CBlock) and hasattr(value, '_meta') and '_security' in value._meta: + security_meta = value._meta['_security'] + if isinstance(security_meta, SecurityMetadata) and not security_meta.is_safe(): + if security_meta.is_tainted(): + taint_source = security_meta.get_taint_source() + source_info = f" (tainted by: {type(taint_source).__name__})" if taint_source else "" + raise SecurityError( + f"Function {func.__name__} requires safe input, but received " + f"tainted content in argument '{key}'{source_info}" + ) + elif security_meta.is_classified(): + raise SecurityError( + f"Function {func.__name__} requires safe input, but received " + f"classified content in argument '{key}'" + ) + + return func(*args, **kwargs) + + return wrapper # type: ignore + + +def declassify(cblock: CBlock) -> CBlock: + """Create a declassified version of a CBlock (non-mutating). + + This function creates a new CBlock with the same content but marked + as safe (SecLevel.none()). The original CBlock is not modified. + + Args: + cblock: The CBlock to declassify + + Returns: + A new CBlock with safe security level + """ + # Create new meta dict with safe security + new_meta = cblock._meta.copy() if cblock._meta else {} + new_meta['_security'] = SecurityMetadata(SecLevel.none()) + + # Return new CBlock with same content but new security metadata + return CBlock(cblock.value, new_meta) diff --git a/mellea/stdlib/base.py b/mellea/stdlib/base.py index bf0c1954..838f1f16 100644 --- a/mellea/stdlib/base.py +++ b/mellea/stdlib/base.py @@ -12,7 +12,7 @@ from copy import copy, deepcopy from dataclasses import dataclass from io import BytesIO -from typing import Any, Protocol, TypeVar, runtime_checkable +from typing import Any, Protocol, TypeVar, Union, runtime_checkable from PIL import Image as PILImage @@ -48,6 +48,43 @@ def __str__(self): def __repr__(self): """Provides a python-parsable representation of the block (usually).""" return f"CBlock({self.value}, {self._meta.__repr__()})" + + def mark_tainted(self, source: Union[CBlock, Component, None] = None): + """Mark this CBlock as tainted by a specific source. + + Args: + source: The CBlock or Component that tainted this content. If None, + this CBlock is marked as tainted by itself. + """ + from mellea.security import SecLevel, SecurityMetadata + + if self._meta is None: + self._meta = {} + + # If no source provided, taint by self + taint_source = source if source is not None else self + self._meta["_security"] = SecurityMetadata(SecLevel.tainted_by(taint_source)) + + def is_safe(self, entitlement: Any = None) -> bool: + """Check if this CBlock is considered safe for the given entitlement. + + Args: + entitlement: The entitlement to check access for (for classified content) + + Returns: + True if the block has no security metadata or is marked as safe, + False if it's marked as tainted or classified without proper entitlement + """ + from mellea.security import SecurityMetadata + + if self._meta is None or "_security" not in self._meta: + return True # Default to safe if no security metadata + + security_meta = self._meta["_security"] + if isinstance(security_meta, SecurityMetadata): + return security_meta.is_safe(entitlement) + + return True # Default to safe if security metadata is not the expected type class ImageBlock: @@ -321,6 +358,42 @@ def __repr__(self): Differs from CBlock because `._meta` can be very large for ModelOutputThunks. """ return f"ModelOutputThunk({self.value})" + + @classmethod + def from_generation( + cls, + value: str | None, + taint_sources: list[Union[CBlock, Component]] | None = None, + meta: dict[str, Any] | None = None, + parsed_repr: CBlock | Component | Any | None = None, + tool_calls: dict[str, ModelToolCall] | None = None, + ) -> "ModelOutputThunk": + """Create a ModelOutputThunk from generation with security metadata. + + Args: + value: The generated content + taint_sources: List of tainted CBlocks or Components from the generation context + meta: Additional metadata for the thunk + parsed_repr: Parsed representation of the output + tool_calls: Tool calls made during generation + + Returns: + A new ModelOutputThunk with appropriate security metadata + """ + if meta is None: + meta = {} + + # Add security metadata based on taint sources + from mellea.security import SecLevel, SecurityMetadata + + if taint_sources: + # If there are taint sources, mark as tainted by the first source + meta["_security"] = SecurityMetadata(SecLevel.tainted_by(taint_sources[0])) + else: + # If no taint sources, mark as safe + meta["_security"] = SecurityMetadata(SecLevel.none()) + + return cls(value, meta, parsed_repr, tool_calls) def __copy__(self): """Returns a shallow copy of the ModelOutputThunk. A copied ModelOutputThunk cannot be used for generation; don't copy over fields associated with generating.""" diff --git a/mellea/stdlib/funcs.py b/mellea/stdlib/funcs.py index 5990edb2..65987321 100644 --- a/mellea/stdlib/funcs.py +++ b/mellea/stdlib/funcs.py @@ -113,7 +113,7 @@ def act( @overload def instruct( - description: str, + description: str | CBlock, context: Context, backend: Backend, *, @@ -134,7 +134,7 @@ def instruct( @overload def instruct( - description: str, + description: str | CBlock, context: Context, backend: Backend, *, @@ -154,7 +154,7 @@ def instruct( def instruct( - description: str, + description: str | CBlock, context: Context, backend: Backend, *, diff --git a/mellea/stdlib/instruction.py b/mellea/stdlib/instruction.py index f8d07efb..ccd7c6e0 100644 --- a/mellea/stdlib/instruction.py +++ b/mellea/stdlib/instruction.py @@ -119,11 +119,31 @@ def __init__( self._images = images self._repair_string: str | None = None - def parts(self): + def parts(self) -> list[Component | CBlock]: """Returns all of the constituent parts of an Instruction.""" - raise Exception( - "Disallowing use of `parts` until we figure out exactly what it's supposed to be for" - ) + parts = [] + + # Add description if it exists + if self._description is not None: + parts.append(self._description) + + # Add prefix if it exists + if self._prefix is not None: + parts.append(self._prefix) + + # Add output_prefix if it exists + if self._output_prefix is not None: + parts.append(self._output_prefix) + + # Add icl_examples + parts.extend(self._icl_examples) + + # Add grounding_context values + for value in self._grounding_context.values(): + if isinstance(value, (CBlock, Component)): + parts.append(value) + + return parts def format_for_llm(self) -> TemplateRepresentation: """Formats the instruction for Formatter use.""" diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index 2a63a71a..761ae2e3 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -310,7 +310,7 @@ def act( @overload def instruct( self, - description: str, + description: str | CBlock, *, images: list[ImageBlock] | list[PILImage.Image] | None = None, requirements: list[Requirement | str] | None = None, @@ -329,7 +329,7 @@ def instruct( @overload def instruct( self, - description: str, + description: str | CBlock, *, images: list[ImageBlock] | list[PILImage.Image] | None = None, requirements: list[Requirement | str] | None = None, @@ -345,9 +345,10 @@ def instruct( tool_calls: bool = False, ) -> SamplingResult: ... + def instruct( self, - description: str, + description: str | CBlock, *, images: list[ImageBlock] | list[PILImage.Image] | None = None, requirements: list[Requirement | str] | None = None, @@ -365,7 +366,7 @@ def instruct( """Generates from an instruction. Args: - description: The description of the instruction. + description: The description of the instruction (str or CBlock). requirements: A list of requirements that the instruction can be validated against. icl_examples: A list of in-context-learning examples that the instruction can be validated against. grounding_context: A list of grounding contexts that the instruction can use. They can bind as variables using a (key: str, value: str | ContentBlock) tuple. diff --git a/test/stdlib_basics/test_security_comprehensive.py b/test/stdlib_basics/test_security_comprehensive.py new file mode 100644 index 00000000..c10f231d --- /dev/null +++ b/test/stdlib_basics/test_security_comprehensive.py @@ -0,0 +1,381 @@ +"""Comprehensive security tests for mellea thread security features.""" + +import pytest +from mellea.stdlib.base import CBlock, ModelOutputThunk, ChatContext +from mellea.security import ( + AccessType, + SecLevel, + SecurityMetadata, + SecurityError, + privileged, + declassify, + taint_sources +) + + +class TestAccessType: + """Test AccessType functionality.""" + + def test_access_type_interface(self): + """Test that AccessType is an abstract base class.""" + with pytest.raises(TypeError): + AccessType() # Should not be instantiable directly + + def test_access_type_implementation(self): + """Test implementing AccessType.""" + class TestAccess(AccessType[str]): + def has_access(self, entitlement: str | None) -> bool: + return entitlement == "admin" + + access = TestAccess() + assert access.has_access("admin") + assert not access.has_access("user") + assert not access.has_access(None) + + +class TestSecLevel: + """Test SecLevel functionality.""" + + def test_sec_level_none(self): + """Test SecLevel.none() creates safe level.""" + sec_level = SecLevel.none() + assert sec_level.level_type == "none" + assert sec_level.is_safe() + assert not sec_level.is_tainted() + assert not sec_level.is_classified() + + def test_sec_level_tainted_by(self): + """Test SecLevel.tainted_by() creates tainted level.""" + source = CBlock("source content") + sec_level = SecLevel.tainted_by(source) + assert sec_level.level_type == "tainted_by" + assert not sec_level.is_safe() + assert sec_level.is_tainted() + assert not sec_level.is_classified() + assert sec_level.get_taint_source() is source + + def test_sec_level_classified(self): + """Test SecLevel.classified() creates classified level.""" + class TestAccess(AccessType[str]): + def has_access(self, entitlement: str | None) -> bool: + return entitlement == "admin" + + access = TestAccess() + sec_level = SecLevel.classified(access) + assert sec_level.level_type == "classified" + assert sec_level.is_safe("admin") + assert not sec_level.is_safe("user") + assert not sec_level.is_safe(None) + assert not sec_level.is_tainted() + assert sec_level.is_classified() + + +class TestCBlockSecurity: + """Test CBlock security functionality.""" + + def test_cblock_mark_tainted(self): + """Test marking CBlock as tainted.""" + cblock = CBlock("test content") + cblock.mark_tainted() + + assert "_security" in cblock._meta + assert isinstance(cblock._meta["_security"], SecurityMetadata) + assert cblock._meta["_security"].is_tainted() + assert not cblock.is_safe() + + def test_cblock_mark_tainted_by_source(self): + """Test marking CBlock as tainted by another source.""" + source = CBlock("source content") + cblock = CBlock("test content") + cblock.mark_tainted(source) + + assert cblock._meta["_security"].is_tainted() + assert cblock._meta["_security"].get_taint_source() is source + + def test_cblock_default_safe(self): + """Test that CBlock defaults to safe when no security metadata.""" + cblock = CBlock("test content") + assert cblock.is_safe() + + def test_cblock_with_classified_metadata(self): + """Test CBlock with classified security metadata.""" + class TestAccess(AccessType[str]): + def has_access(self, entitlement: str | None) -> bool: + return entitlement == "admin" + + access = TestAccess() + sec_level = SecLevel.classified(access) + security_meta = SecurityMetadata(sec_level) + + cblock = CBlock("classified content", meta={"_security": security_meta}) + + assert cblock.is_safe("admin") + assert not cblock.is_safe("user") + assert not cblock.is_safe(None) + + +class TestDeclassify: + """Test declassify function.""" + + def test_declassify_creates_new_object(self): + """Test that declassify creates a new object without mutating original.""" + original = CBlock("test content") + original.mark_tainted() + + declassified = declassify(original) + + # Objects are different + assert original is not declassified + assert id(original) != id(declassified) + + # Content is preserved + assert original.value == declassified.value + + # Security levels are different + assert not original.is_safe() + assert declassified.is_safe() + assert declassified._meta["_security"].sec_level.level_type == "none" + + # Original is unchanged + assert original._meta["_security"].is_tainted() + + def test_declassify_preserves_other_metadata(self): + """Test that declassify preserves other metadata.""" + original = CBlock("test content", meta={"custom": "value", "other": 123}) + original.mark_tainted() + + declassified = declassify(original) + + assert declassified._meta["custom"] == "value" + assert declassified._meta["other"] == 123 + assert declassified._meta["_security"].sec_level.level_type == "none" + + +class TestPrivilegedDecorator: + """Test @privileged decorator functionality.""" + + def test_privileged_accepts_safe_input(self): + """Test that privileged functions accept safe input.""" + @privileged + def safe_function(cblock: CBlock) -> str: + return f"Processed: {cblock.value}" + + # CBlock with no security metadata defaults to safe + safe_cblock = CBlock("safe content") + + result = safe_function(safe_cblock) + assert result == "Processed: safe content" + + def test_privileged_accepts_declassified_input(self): + """Test that privileged functions accept declassified input.""" + @privileged + def safe_function(cblock: CBlock) -> str: + return f"Processed: {cblock.value}" + + tainted_cblock = CBlock("tainted content") + tainted_cblock.mark_tainted() + declassified_cblock = declassify(tainted_cblock) + + result = safe_function(declassified_cblock) + assert result == "Processed: tainted content" + + def test_privileged_rejects_tainted_input(self): + """Test that privileged functions reject tainted input.""" + @privileged + def safe_function(cblock: CBlock) -> str: + return f"Processed: {cblock.value}" + + tainted_cblock = CBlock("tainted content") + tainted_cblock.mark_tainted() + + with pytest.raises(SecurityError, match="requires safe input"): + safe_function(tainted_cblock) + + def test_privileged_rejects_classified_input(self): + """Test that privileged functions reject classified input without proper entitlement.""" + @privileged + def safe_function(cblock: CBlock) -> str: + return f"Processed: {cblock.value}" + + class TestAccess(AccessType[str]): + def has_access(self, entitlement: str | None) -> bool: + return entitlement == "admin" + + access = TestAccess() + sec_level = SecLevel.classified(access) + security_meta = SecurityMetadata(sec_level) + + classified_cblock = CBlock("classified content", meta={"_security": security_meta}) + + with pytest.raises(SecurityError, match="requires safe input"): + safe_function(classified_cblock) + + def test_privileged_accepts_no_security_metadata(self): + """Test that privileged functions accept input with no security metadata.""" + @privileged + def safe_function(cblock: CBlock) -> str: + return f"Processed: {cblock.value}" + + # CBlock with no security metadata defaults to safe + cblock = CBlock("content") + + result = safe_function(cblock) + assert result == "Processed: content" + + def test_privileged_with_kwargs(self): + """Test privileged function with keyword arguments.""" + @privileged + def safe_function(data: CBlock, prefix: str = "Processed: ") -> str: + return f"{prefix}{data.value}" + + tainted_cblock = CBlock("tainted content") + tainted_cblock.mark_tainted() + + with pytest.raises(SecurityError, match="argument 'data'"): + safe_function(data=tainted_cblock) + + +class TestTaintSources: + """Test taint source computation.""" + + def test_taint_sources_from_tainted_action(self): + """Test taint sources from tainted action.""" + action = CBlock("tainted action") + action.mark_tainted() + + sources = taint_sources(action, None) + assert len(sources) == 1 + assert sources[0] is action + + def test_taint_sources_from_safe_action(self): + """Test taint sources from safe action.""" + action = CBlock("safe action") + # No security metadata - defaults to safe + + sources = taint_sources(action, None) + assert len(sources) == 0 + + def test_taint_sources_from_context(self): + """Test taint sources from context.""" + action = CBlock("safe action") + + # Create context with tainted content + ctx = ChatContext() + tainted_cblock = CBlock("tainted context") + tainted_cblock.mark_tainted() + ctx = ctx.add(tainted_cblock) + + sources = taint_sources(action, ctx) + assert len(sources) == 1 + assert sources[0] is tainted_cblock + + def test_taint_sources_empty(self): + """Test taint sources with no tainted content.""" + action = CBlock("safe action") + ctx = ChatContext() + safe_cblock = CBlock("safe context") + # No security metadata - defaults to safe + ctx = ctx.add(safe_cblock) + + sources = taint_sources(action, ctx) + assert len(sources) == 0 + + +class TestModelOutputThunkSecurity: + """Test ModelOutputThunk security functionality.""" + + def test_from_generation_with_taint_sources(self): + """Test ModelOutputThunk.from_generation with taint sources.""" + taint_source = CBlock("taint source") + taint_source.mark_tainted() + + mot = ModelOutputThunk.from_generation( + value="generated content", + taint_sources=[taint_source], + meta={"custom": "value"} + ) + + assert mot.value == "generated content" + assert mot._meta["custom"] == "value" + assert "_security" in mot._meta + assert mot._meta["_security"].is_tainted() + assert not mot.is_safe() + assert mot._meta["_security"].get_taint_source() is taint_source + + def test_from_generation_without_taint_sources(self): + """Test ModelOutputThunk.from_generation without taint sources.""" + mot = ModelOutputThunk.from_generation( + value="generated content", + taint_sources=None, + meta={"custom": "value"} + ) + + assert mot.value == "generated content" + assert mot._meta["custom"] == "value" + assert "_security" in mot._meta + assert mot._meta["_security"].sec_level.level_type == "none" + assert mot.is_safe() + + def test_from_generation_empty_taint_sources(self): + """Test ModelOutputThunk.from_generation with empty taint sources.""" + mot = ModelOutputThunk.from_generation( + value="generated content", + taint_sources=[], + meta={"custom": "value"} + ) + + assert mot._meta["_security"].sec_level.level_type == "none" + assert mot.is_safe() + + +class TestSecurityIntegration: + """Test integration between security components.""" + + def test_security_flow_through_generation(self): + """Test security metadata flows through generation pipeline.""" + # Create tainted input + tainted_input = CBlock("user input") + tainted_input.mark_tainted() + + # Simulate generation with taint sources + sources = taint_sources(tainted_input, None) + mot = ModelOutputThunk.from_generation( + value="model response", + taint_sources=sources + ) + + # Verify output is tainted + assert not mot.is_safe() + assert mot._meta["_security"].is_tainted() + + # Declassify the output + safe_mot = declassify(mot) + assert safe_mot.is_safe() + assert safe_mot._meta["_security"].sec_level.level_type == "none" + + # Verify original is unchanged + assert not mot.is_safe() + + def test_privileged_function_with_generated_content(self): + """Test privileged function with generated content.""" + @privileged + def process_response(mot: ModelOutputThunk) -> str: + return f"Processed: {mot.value}" + + # Generate tainted content + taint_source = CBlock("taint source") + taint_source.mark_tainted() + + mot = ModelOutputThunk.from_generation( + value="tainted response", + taint_sources=[taint_source] + ) + + # Privileged function should reject tainted content + with pytest.raises(SecurityError): + process_response(mot) + + # Declassify and try again + safe_mot = declassify(mot) + result = process_response(safe_mot) + assert result == "Processed: tainted response"