diff --git a/docs/ignoring_triggers.md b/docs/ignoring_triggers.md index 9eeca7e..8cb4fe5 100644 --- a/docs/ignoring_triggers.md +++ b/docs/ignoring_triggers.md @@ -53,3 +53,66 @@ If you're ignoring triggers and handling database errors, there are two ways to 1. Wrap the outer transaction in `with pgtrigger.ignore.session():` so that the session is completed outside the transaction. 2. Wrap the inner `try/except` in `with transaction.atomic():` so that the errored part of the transaction is rolled back before the [pgtrigger.ignore][] context manager ends. + + +## Ignore other triggers within a trigger + +Provide an `ignore_others` list of trigger URIs you would like to ignore while executing +a certain trigger. See the example below for details: + +The `increment_comment_count` trigger will update the `comment_count` on a topic, instead of calculating +the count each time a topic is queried. Let's assume you are fixing a Justin Bieber Instagram +[bug](https://www.wired.com/2015/11/how-instagram-solved-its-justin-bieber-problem/). However we have +also protected the `comment_count` with a `pgtrigger.ReadOnly(name='read_only_comment_count')` trigger. + +In this case you would provide a `ignore_others=['tests.Topic:read_only_comment_count']` to the +`increment_comment_count` trigger. + +```python +class Topic(models.Model): + name = models.CharField(max_length=100) + comment_count = models.PositiveIntegerField(default=0) + + class Meta: + triggers = [ + pgtrigger.ReadOnly( + name='read_only_comment_count', + fields=['comment_count'] + ) + ] + + +class Comment(models.Model): + topic = models.ForeignKey(Topic, on_delete=models.CASCADE) + # Other fields + + class Meta: + triggers = [ + pgtrigger.Trigger( + func=pgtrigger.Func( + ''' + UPDATE "{db_table}" + SET "{comment_count}" = "{comment_count}" + 1 + WHERE + "{db_table}"."{topic_pk}" = NEW."{columns.topic}"; + {reset_ignore} + RETURN NEW; + ''', + db_table = Topic._meta.db_table, + comment_count = Topic._meta.get_field('comment_count').get_attname_column()[1], + topic_pk = Topic._meta.pk.get_attname_column()[1] + ), + ignore_others=['tests.Topic:read_only_comment_count'], + when=pgtrigger.Before, + operation=pgtrigger.Insert, + name='increment_comment_count' + ), + ] +``` + +!!! important + + Remember to use the `{reset_ignore}` placeholder in the trigger function before you return + from any branch. Without it the triggers you have ignored will persist throughout the session. + +It is mandatory to provide an instace of `pgtrigger.Func` to the `func` parameter. \ No newline at end of file diff --git a/pgtrigger/compiler.py b/pgtrigger/compiler.py index a7de1ad..f9cf864 100644 --- a/pgtrigger/compiler.py +++ b/pgtrigger/compiler.py @@ -65,6 +65,7 @@ def get_template(self): RETURN NEW; END IF; END IF; + {local_ignore} {func} END; $$ LANGUAGE plpgsql; @@ -118,6 +119,7 @@ def __init__( condition=_unset, execute=_unset, hash=None, + local_ignore="", ): """Initialize the SQL and store it in the `.data` attribute.""" self.kwargs = { diff --git a/pgtrigger/core.py b/pgtrigger/core.py index 0def798..8f315b3 100644 --- a/pgtrigger/core.py +++ b/pgtrigger/core.py @@ -27,6 +27,7 @@ else: raise AssertionError +from .version import __version__ as ver # Postgres only allows identifiers to be 63 chars max. Since "pgtrigger_" # is the prefix for trigger names, and since an additional "_" and @@ -510,10 +511,11 @@ class Func: possible to do inline SQL in the `Meta` of a model and reference its properties. """ - def __init__(self, func): + def __init__(self, func: str, **kwargs): self.func = func + self.kwargs = kwargs - def render(self, model: models.Model) -> str: + def render(self, model: type[models.Model], trigger: Trigger) -> str: """ Render the SQL of the function. @@ -523,9 +525,24 @@ def render(self, model: models.Model) -> str: Returns: The rendered SQL. """ - fields = utils.AttrDict({field.name: field for field in model._meta.fields}) - columns = utils.AttrDict({field.name: field.column for field in model._meta.fields}) - return self.func.format(meta=model._meta, fields=fields, columns=columns) + kwargs = { + "meta": model._meta, + "fields": utils.AttrDict({field.name: field for field in model._meta.fields}), + "columns": utils.AttrDict({field.name: field.column for field in model._meta.fields}), + "reset_ignore": ( + """ + IF _prev_ignore IS NOT NULL AND (_prev_ignore = '') IS NOT TRUE THEN + PERFORM set_config('pgtrigger.ignore', _prev_ignore, true); + ELSE + PERFORM set_config('pgtrigger.ignore', '', true); + END IF; + """ + if trigger.ignores_others + else "" + ), + } | self.kwargs + + return self.func.format(**kwargs) # Allows Trigger methods to be used as context managers, mostly for @@ -559,6 +576,7 @@ class Trigger: func: Func | str | None = None declare: list[tuple[str, str]] | None = None timing: Timing | None = None + ignore_others: list[str] | None = None def __init__( self, @@ -572,6 +590,7 @@ def __init__( func: Func | str | None = None, declare: List[Tuple[str, str]] | None = None, timing: Timing | None = None, + ignore_others: list[str] | None = None, ) -> None: self.name = name or self.name self.level = level or self.level @@ -582,6 +601,7 @@ def __init__( self.func = func or self.func self.declare = declare or self.declare self.timing = timing or self.timing + self.ignore_others = ignore_others or self.ignore_others if not self.level or not isinstance(self.level, Level): raise ValueError(f'Invalid "level" attribute: {self.level}') @@ -609,6 +629,20 @@ def __init__( self.validate_name() + if self.ignores_others: + if not isinstance(self.func, Func): + raise ValueError( + 'Invalid "func" attribute. Triggers that ignore others must provide ' + f"an instance of pgtrigger.Func(). Received {type(self.func)} instead" + ) + + if "{reset_ignore}" not in self.func.func: + raise ValueError( + f'Trigger "{self}" ignores other triggers, however, ' + "placeholder {reset_ignore} was not found in the function " + f"body. Please refer to: https://django-pgtrigger.readthedocs.io/en/{ver}/ignoring_triggers/#ignore-other-triggers-within-a-trigger" + ) + def __str__(self) -> str: # pragma: no cover return self.name @@ -627,7 +661,7 @@ def validate_name(self) -> None: " Only alphanumeric characters, hyphens, and underscores are allowed." ) - def get_pgid(self, model: models.Model) -> str: + def get_pgid(self, model: type[models.Model]) -> str: """The ID of the trigger and function object in postgres All objects are prefixed with "pgtrigger_" in order to be @@ -650,7 +684,7 @@ def get_pgid(self, model: models.Model) -> str: # and pruning tasks. return pgid.lower() - def get_condition(self, model: models.Model) -> Condition: + def get_condition(self, model: type[models.Model]) -> Condition: """Get the condition of the trigger. Args: @@ -661,7 +695,7 @@ def get_condition(self, model: models.Model) -> Condition: """ return self.condition - def get_declare(self, model: models.Model) -> List[Tuple[str, str]]: + def get_declare(self, model: type[models.Model]) -> List[Tuple[str, str]]: """ Gets the DECLARE part of the trigger function if any variables are used. @@ -673,9 +707,60 @@ def get_declare(self, model: models.Model) -> List[Tuple[str, str]]: A list of variable name / type tuples that will be shown in the DECLARE. For example [('row_data', 'JSONB')] """ - return self.declare or [] + declare = self.declare or [] + + if self.ignore_others is not None: + declare.append(("_prev_ignore", "text")) + declare.append(self.declare_local_ignore(self.ignore_others)) + + return declare + + def declare_local_ignore(self, ignore: list[str]) -> tuple[str, str]: + """Given a list of trigger URIs compile the value for `_local_ignore` + variable of the trigger function - def get_func(self, model: models.Model) -> Union[str, Func]: + Parameters + ---------- + ignore : list[str] + List of trigger URIs + + Returns + ------- + tuple[str, str] + `_local_ignore` variable declaration and initial value for the DECLARE block + """ + local_ignore = ( + "{" + + ",".join( + f"{model._meta.db_table}:{(pgid:=trigger.get_pgid(model))},{pgid}" + for model, trigger in registry.registered(*ignore) + ) + + "}" + ) + return ("_local_ignore", f"text[] = '{local_ignore}'") + + @property + def ignores_others(self) -> bool: + """True if the trigger is initialized with local trigger ignores""" + return self.ignore_others is not None + + def render_local_ignore(self): + if self.ignores_others: + return """ + BEGIN + SELECT CURRENT_SETTING('pgtrigger.ignore', true) INTO _prev_ignore; + EXCEPTION WHEN OTHERS THEN + END; + + IF _prev_ignore IS NOT NULL AND (_prev_ignore = '') IS NOT TRUE THEN + SELECT _local_ignore || _prev_ignore::text[] INTO _local_ignore; + END IF; + + PERFORM set_config('pgtrigger.ignore', _local_ignore::text, true); + """ + return "" + + def get_func(self, model: type[models.Model]) -> Union[str, Func]: """ Returns the trigger function that comes between the BEGIN and END clause. @@ -690,7 +775,7 @@ def get_func(self, model: models.Model) -> Union[str, Func]: raise ValueError("Must define func attribute or implement get_func") return self.func - def get_uri(self, model: models.Model) -> str: + def get_uri(self, model: type[models.Model]) -> str: """The URI for the trigger. Args: @@ -702,7 +787,7 @@ def get_uri(self, model: models.Model) -> str: return f"{model._meta.app_label}.{model._meta.object_name}:{self.name}" - def render_condition(self, model: models.Model) -> str: + def render_condition(self, model: type[models.Model]) -> str: """Renders the condition SQL in the trigger declaration. Args: @@ -721,7 +806,7 @@ def render_condition(self, model: models.Model) -> str: return resolved - def render_declare(self, model: models.Model) -> str: + def render_declare(self, model: type[models.Model]) -> str: """Renders the DECLARE of the trigger function, if any. Args: @@ -740,7 +825,7 @@ def render_declare(self, model: models.Model) -> str: return rendered_declare - def render_execute(self, model: models.Model) -> str: + def render_execute(self, model: type[models.Model]) -> str: """ Renders what should be executed by the trigger. This defaults to the trigger function. @@ -753,7 +838,7 @@ def render_execute(self, model: models.Model) -> str: """ return f"{self.get_pgid(model)}()" - def render_func(self, model: models.Model) -> str: + def render_func(self, model: type[models.Model]) -> str: """ Renders the func. @@ -766,11 +851,11 @@ def render_func(self, model: models.Model) -> str: func = self.get_func(model) if isinstance(func, Func): - return func.render(model) - else: - return func + return func.render(model, self) + + return func - def compile(self, model: models.Model) -> compiler.Trigger: + def compile(self, model: type[models.Model]) -> compiler.Trigger: """ Create a compiled representation of the trigger. useful for migrations. @@ -796,10 +881,11 @@ def compile(self, model: models.Model) -> compiler.Trigger: level=self.level, condition=self.render_condition(model), execute=self.render_execute(model), + local_ignore=self.render_local_ignore(), ), ) - def allow_migrate(self, model: models.Model, database: Union[str, None] = None) -> bool: + def allow_migrate(self, model: type[models.Model], database: Union[str, None] = None) -> bool: """True if the trigger for this model can be migrated. Defaults to using the router's allow_migrate. @@ -830,7 +916,7 @@ def format_sql(self, sql: str) -> str: def exec_sql( self, sql: str, - model: models.Model, + model: type[models.Model], database: Union[str, None] = None, fetchall: bool = False, ) -> Any: @@ -849,7 +935,7 @@ def exec_sql( return utils.exec_sql(str(sql), database=database, fetchall=fetchall) def get_installation_status( - self, model: models.Model, database: Union[str, None] = None + self, model: type[models.Model], database: Union[str, None] = None ) -> Tuple[str, Union[bool, None]]: """Returns the installation status of a trigger. @@ -922,7 +1008,7 @@ def unregister(self, *models: models.Model): return _cleanup_on_exit(lambda: self.register(*models)) - def install(self, model: models.Model, database: Union[str, None] = None): + def install(self, model: type[models.Model], database: Union[str, None] = None): """Installs the trigger for a model. Args: @@ -934,7 +1020,7 @@ def install(self, model: models.Model, database: Union[str, None] = None): self.exec_sql(install_sql, model, database=database) return _cleanup_on_exit(lambda: self.uninstall(model, database=database)) - def uninstall(self, model: models.Model, database: Union[str, None] = None): + def uninstall(self, model: type[models.Model], database: Union[str, None] = None): """Uninstalls the trigger for a model. Args: @@ -947,7 +1033,7 @@ def uninstall(self, model: models.Model, database: Union[str, None] = None): lambda: self.install(model, database=database) ) - def enable(self, model: models.Model, database: Union[str, None] = None): + def enable(self, model: type[models.Model], database: Union[str, None] = None): """Enables the trigger for a model. Args: @@ -960,7 +1046,7 @@ def enable(self, model: models.Model, database: Union[str, None] = None): lambda: self.disable(model, database=database) ) - def disable(self, model: models.Model, database: Union[str, None] = None): + def disable(self, model: type[models.Model], database: Union[str, None] = None): """Disables the trigger for a model. Args: diff --git a/pgtrigger/tests/migrations/0014_topic_comment.py b/pgtrigger/tests/migrations/0014_topic_comment.py new file mode 100644 index 0000000..e16d0cb --- /dev/null +++ b/pgtrigger/tests/migrations/0014_topic_comment.py @@ -0,0 +1,49 @@ +# Generated by Django 4.2.6 on 2024-12-07 10:47 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("tests", "0013_alter_testtrigger_m2m_field_changedcondition"), + ] + + operations = [ + migrations.CreateModel( + name="Topic", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("name", models.CharField(max_length=100)), + ("comment_count", models.PositiveIntegerField(default=0)), + ], + ), + migrations.CreateModel( + name="Comment", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ( + "topic", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, to="tests.topic" + ), + ), + ], + ), + ] diff --git a/pgtrigger/tests/models.py b/pgtrigger/tests/models.py index 7c86632..a719aad 100644 --- a/pgtrigger/tests/models.py +++ b/pgtrigger/tests/models.py @@ -259,3 +259,39 @@ class ChangedCondition(models.Model): fk_field = models.ForeignKey("auth.User", null=True, on_delete=models.CASCADE) char_pk_fk_field = models.ForeignKey(CharPk, null=True, on_delete=models.CASCADE) m2m_field = models.ManyToManyField(User, related_name="+") + + +class Topic(models.Model): + name = models.CharField(max_length=100) + comment_count = models.PositiveIntegerField(default=0) + + class Meta: + triggers = [pgtrigger.ReadOnly(name="read_only_comment_count", fields=["comment_count"])] + + +class Comment(models.Model): + topic = models.ForeignKey(Topic, on_delete=models.CASCADE) + # Other fields + + class Meta: + triggers = [ + pgtrigger.Trigger( + func=pgtrigger.Func( + """ + UPDATE "{db_table}" + SET "{comment_count}" = "{comment_count}" + 1 + WHERE + "{db_table}"."{topic_pk}" = NEW."{columns.topic}"; + {reset_ignore} + RETURN NEW; + """, + db_table=Topic._meta.db_table, + comment_count=Topic._meta.get_field("comment_count").get_attname_column()[1], + topic_pk=Topic._meta.pk.get_attname_column()[1], + ), + ignore_others=["tests.Topic:read_only_comment_count"], + when=pgtrigger.Before, + operation=pgtrigger.Insert, + name="increment_comment_count", + ), + ] diff --git a/pgtrigger/tests/test_core.py b/pgtrigger/tests/test_core.py index 8abb064..0385f28 100644 --- a/pgtrigger/tests/test_core.py +++ b/pgtrigger/tests/test_core.py @@ -722,3 +722,55 @@ def test_trigger_management(mocker): pgtrigger.prune() pgtrigger.prune() deletion_protected_model.delete() + + +def test_func_is_func_when_ignores_others(): + """Tests that Trigger class enforces the use of Func() + when using `ignore_others`. Because the user needs to use `{reset_ignore}` + placeholder, so that `pgtrigger.ignore()` works as expected""" + + with pytest.raises(ValueError, match='Invalid "func" attribute. Triggers that ignore others'): + pgtrigger.Trigger( + name="ignorant-trigger", + func="RETURN NEW", + ignore_others=["app.Model:trigger"], + when=pgtrigger.Before, + operation=pgtrigger.Insert, + ) + + with pytest.raises( + ValueError, match="placeholder {reset_ignore} was not found in the function body" + ): + pgtrigger.Trigger( + name="ignorant-trigger", + func=pgtrigger.Func("RETURN NEW"), + ignore_others=["app.Model:trigger"], + when=pgtrigger.Before, + operation=pgtrigger.Insert, + ) + + +@pytest.mark.django_db(transaction=True) +def test_ignore_others(): + """ + Test that igore others succesfully ignores the + `read_only_comment_count` trigger on `models.Topic` + + Test that the custom trigger on `models.Comment` updates the comment counter on `models.Topic` + + Test that `pgtrigger.ignore` works as expected + """ + test_topic = ddf.G(models.Topic) + ddf.G(models.Comment, topic=test_topic.pk) + + test_topic.refresh_from_db() + assert test_topic.comment_count == 1 + + with pgtrigger.ignore("tests.Topic:read_only_comment_count"): + ddf.G(models.Comment, topic=test_topic.pk) + test_topic.comment_count = 10 + test_topic.save() + + with utils.raises_trigger_error("Cannot update"): + test_topic.comment_count = 20 + test_topic.save()