From e8228f1158b26de755b2748c182e8dd09f486d2b Mon Sep 17 00:00:00 2001 From: Ilir Kokollari Date: Sat, 7 Dec 2024 17:53:52 +0000 Subject: [PATCH 1/2] Trigger ignore_others parameter addition --- docs/ignoring_triggers.md | 63 ++++++++ pgtrigger/compiler.py | 2 + pgtrigger/core.py | 135 ++++++++++++++---- .../tests/migrations/0014_topic_comment.py | 50 +++++++ pgtrigger/tests/models.py | 42 ++++++ pgtrigger/tests/test_core.py | 52 ++++++- 6 files changed, 317 insertions(+), 27 deletions(-) create mode 100644 pgtrigger/tests/migrations/0014_topic_comment.py diff --git a/docs/ignoring_triggers.md b/docs/ignoring_triggers.md index 9eeca7e..8cb4fe5 100644 --- a/docs/ignoring_triggers.md +++ b/docs/ignoring_triggers.md @@ -53,3 +53,66 @@ If you're ignoring triggers and handling database errors, there are two ways to 1. Wrap the outer transaction in `with pgtrigger.ignore.session():` so that the session is completed outside the transaction. 2. Wrap the inner `try/except` in `with transaction.atomic():` so that the errored part of the transaction is rolled back before the [pgtrigger.ignore][] context manager ends. + + +## Ignore other triggers within a trigger + +Provide an `ignore_others` list of trigger URIs you would like to ignore while executing +a certain trigger. See the example below for details: + +The `increment_comment_count` trigger will update the `comment_count` on a topic, instead of calculating +the count each time a topic is queried. Let's assume you are fixing a Justin Bieber Instagram +[bug](https://www.wired.com/2015/11/how-instagram-solved-its-justin-bieber-problem/). However we have +also protected the `comment_count` with a `pgtrigger.ReadOnly(name='read_only_comment_count')` trigger. + +In this case you would provide a `ignore_others=['tests.Topic:read_only_comment_count']` to the +`increment_comment_count` trigger. + +```python +class Topic(models.Model): + name = models.CharField(max_length=100) + comment_count = models.PositiveIntegerField(default=0) + + class Meta: + triggers = [ + pgtrigger.ReadOnly( + name='read_only_comment_count', + fields=['comment_count'] + ) + ] + + +class Comment(models.Model): + topic = models.ForeignKey(Topic, on_delete=models.CASCADE) + # Other fields + + class Meta: + triggers = [ + pgtrigger.Trigger( + func=pgtrigger.Func( + ''' + UPDATE "{db_table}" + SET "{comment_count}" = "{comment_count}" + 1 + WHERE + "{db_table}"."{topic_pk}" = NEW."{columns.topic}"; + {reset_ignore} + RETURN NEW; + ''', + db_table = Topic._meta.db_table, + comment_count = Topic._meta.get_field('comment_count').get_attname_column()[1], + topic_pk = Topic._meta.pk.get_attname_column()[1] + ), + ignore_others=['tests.Topic:read_only_comment_count'], + when=pgtrigger.Before, + operation=pgtrigger.Insert, + name='increment_comment_count' + ), + ] +``` + +!!! important + + Remember to use the `{reset_ignore}` placeholder in the trigger function before you return + from any branch. Without it the triggers you have ignored will persist throughout the session. + +It is mandatory to provide an instace of `pgtrigger.Func` to the `func` parameter. \ No newline at end of file diff --git a/pgtrigger/compiler.py b/pgtrigger/compiler.py index a7de1ad..f393719 100644 --- a/pgtrigger/compiler.py +++ b/pgtrigger/compiler.py @@ -65,6 +65,7 @@ def get_template(self): RETURN NEW; END IF; END IF; + {local_ignore} {func} END; $$ LANGUAGE plpgsql; @@ -118,6 +119,7 @@ def __init__( condition=_unset, execute=_unset, hash=None, + local_ignore='', ): """Initialize the SQL and store it in the `.data` attribute.""" self.kwargs = { diff --git a/pgtrigger/core.py b/pgtrigger/core.py index 0def798..2301534 100644 --- a/pgtrigger/core.py +++ b/pgtrigger/core.py @@ -27,6 +27,7 @@ else: raise AssertionError +from .version import __version__ as ver # Postgres only allows identifiers to be 63 chars max. Since "pgtrigger_" # is the prefix for trigger names, and since an additional "_" and @@ -510,10 +511,11 @@ class Func: possible to do inline SQL in the `Meta` of a model and reference its properties. """ - def __init__(self, func): + def __init__(self, func: str, **kwargs): self.func = func + self.kwargs = kwargs - def render(self, model: models.Model) -> str: + def render(self, model: type[models.Model], trigger: Trigger) -> str: """ Render the SQL of the function. @@ -523,9 +525,23 @@ def render(self, model: models.Model) -> str: Returns: The rendered SQL. """ - fields = utils.AttrDict({field.name: field for field in model._meta.fields}) - columns = utils.AttrDict({field.name: field.column for field in model._meta.fields}) - return self.func.format(meta=model._meta, fields=fields, columns=columns) + kwargs = dict( + meta=model._meta, + fields=utils.AttrDict({field.name: field for field in model._meta.fields}), + columns=utils.AttrDict({field.name: field.column for field in model._meta.fields}), + reset_ignore=( + ''' + IF _prev_ignore IS NOT NULL AND (_prev_ignore = '') IS NOT TRUE THEN + PERFORM set_config('pgtrigger.ignore', _prev_ignore, true); + ELSE + PERFORM set_config('pgtrigger.ignore', '', true); + END IF; + ''' + if trigger.ignores_others else '' + ) + ) | self.kwargs + + return self.func.format(**kwargs) # Allows Trigger methods to be used as context managers, mostly for @@ -559,6 +575,7 @@ class Trigger: func: Func | str | None = None declare: list[tuple[str, str]] | None = None timing: Timing | None = None + ignore_others: list[str] | None = None def __init__( self, @@ -572,6 +589,7 @@ def __init__( func: Func | str | None = None, declare: List[Tuple[str, str]] | None = None, timing: Timing | None = None, + ignore_others: list[str] | None = None, ) -> None: self.name = name or self.name self.level = level or self.level @@ -582,6 +600,7 @@ def __init__( self.func = func or self.func self.declare = declare or self.declare self.timing = timing or self.timing + self.ignore_others = ignore_others or self.ignore_others if not self.level or not isinstance(self.level, Level): raise ValueError(f'Invalid "level" attribute: {self.level}') @@ -609,6 +628,20 @@ def __init__( self.validate_name() + if self.ignores_others: + if not isinstance(self.func, Func): + raise ValueError( + 'Invalid "func" attribute. Triggers that ignore others must provide ' + f'an instance of pgtrigger.Func(). Received {type(self.func)} instead' + ) + + if not '{reset_ignore}' in self.func.func: + raise ValueError(f'Trigger "{self}" ignores other triggers, however, ' + 'placeholder {reset_ignore} was not found in the function ' + f'body. Please refer to: https://django-pgtrigger.readthedocs.io/en/{ver}/ignoring_triggers/#ignore-other-triggers-within-a-trigger') + + + def __str__(self) -> str: # pragma: no cover return self.name @@ -627,7 +660,7 @@ def validate_name(self) -> None: " Only alphanumeric characters, hyphens, and underscores are allowed." ) - def get_pgid(self, model: models.Model) -> str: + def get_pgid(self, model: type[models.Model]) -> str: """The ID of the trigger and function object in postgres All objects are prefixed with "pgtrigger_" in order to be @@ -650,7 +683,7 @@ def get_pgid(self, model: models.Model) -> str: # and pruning tasks. return pgid.lower() - def get_condition(self, model: models.Model) -> Condition: + def get_condition(self, model: type[models.Model]) -> Condition: """Get the condition of the trigger. Args: @@ -661,7 +694,7 @@ def get_condition(self, model: models.Model) -> Condition: """ return self.condition - def get_declare(self, model: models.Model) -> List[Tuple[str, str]]: + def get_declare(self, model: type[models.Model]) -> List[Tuple[str, str]]: """ Gets the DECLARE part of the trigger function if any variables are used. @@ -673,9 +706,58 @@ def get_declare(self, model: models.Model) -> List[Tuple[str, str]]: A list of variable name / type tuples that will be shown in the DECLARE. For example [('row_data', 'JSONB')] """ - return self.declare or [] + declare = self.declare or [] + + if self.ignore_others is not None: + declare.append(('_prev_ignore', 'text')) + declare.append(self.declare_local_ignore(self.ignore_others)) + + return declare + + def declare_local_ignore(self, ignore: list[str]) -> tuple[str, str]: + """Given a list of trigger URIs compile the value for `_local_ignore` + variable of the trigger function + + Parameters + ---------- + ignore : list[str] + List of trigger URIs + + Returns + ------- + tuple[str, str] + `_local_ignore` variable declaration and initial value for the DECLARE block + """ + local_ignore = '{' + ','.join( + f'{model._meta.db_table}:{(pgid:=trigger.get_pgid(model))},{pgid}' + for model, trigger in registry.registered(*ignore) + ) + '}' + return ('_local_ignore', f"text[] = '{local_ignore}'") + + @property + def ignores_others(self) -> bool: + """True if the trigger is initialized with local trigger ignores""" + return self.ignore_others is not None + + def render_local_ignore(self): + if self.ignores_others: + return ( + ''' + BEGIN + SELECT CURRENT_SETTING('pgtrigger.ignore', true) INTO _prev_ignore; + EXCEPTION WHEN OTHERS THEN + END; + + IF _prev_ignore IS NOT NULL AND (_prev_ignore = '') IS NOT TRUE THEN + SELECT _local_ignore || _prev_ignore::text[] INTO _local_ignore; + END IF; + + PERFORM set_config('pgtrigger.ignore', _local_ignore::text, true); + ''' + ) + return '' - def get_func(self, model: models.Model) -> Union[str, Func]: + def get_func(self, model: type[models.Model]) -> Union[str, Func]: """ Returns the trigger function that comes between the BEGIN and END clause. @@ -690,7 +772,7 @@ def get_func(self, model: models.Model) -> Union[str, Func]: raise ValueError("Must define func attribute or implement get_func") return self.func - def get_uri(self, model: models.Model) -> str: + def get_uri(self, model: type[models.Model]) -> str: """The URI for the trigger. Args: @@ -702,7 +784,7 @@ def get_uri(self, model: models.Model) -> str: return f"{model._meta.app_label}.{model._meta.object_name}:{self.name}" - def render_condition(self, model: models.Model) -> str: + def render_condition(self, model: type[models.Model]) -> str: """Renders the condition SQL in the trigger declaration. Args: @@ -721,7 +803,7 @@ def render_condition(self, model: models.Model) -> str: return resolved - def render_declare(self, model: models.Model) -> str: + def render_declare(self, model: type[models.Model]) -> str: """Renders the DECLARE of the trigger function, if any. Args: @@ -740,7 +822,7 @@ def render_declare(self, model: models.Model) -> str: return rendered_declare - def render_execute(self, model: models.Model) -> str: + def render_execute(self, model: type[models.Model]) -> str: """ Renders what should be executed by the trigger. This defaults to the trigger function. @@ -753,7 +835,7 @@ def render_execute(self, model: models.Model) -> str: """ return f"{self.get_pgid(model)}()" - def render_func(self, model: models.Model) -> str: + def render_func(self, model: type[models.Model]) -> str: """ Renders the func. @@ -766,11 +848,11 @@ def render_func(self, model: models.Model) -> str: func = self.get_func(model) if isinstance(func, Func): - return func.render(model) - else: - return func + return func.render(model, self) + + return func - def compile(self, model: models.Model) -> compiler.Trigger: + def compile(self, model: type[models.Model]) -> compiler.Trigger: """ Create a compiled representation of the trigger. useful for migrations. @@ -796,10 +878,11 @@ def compile(self, model: models.Model) -> compiler.Trigger: level=self.level, condition=self.render_condition(model), execute=self.render_execute(model), + local_ignore=self.render_local_ignore(), ), ) - def allow_migrate(self, model: models.Model, database: Union[str, None] = None) -> bool: + def allow_migrate(self, model: type[models.Model], database: Union[str, None] = None) -> bool: """True if the trigger for this model can be migrated. Defaults to using the router's allow_migrate. @@ -830,7 +913,7 @@ def format_sql(self, sql: str) -> str: def exec_sql( self, sql: str, - model: models.Model, + model: type[models.Model], database: Union[str, None] = None, fetchall: bool = False, ) -> Any: @@ -849,7 +932,7 @@ def exec_sql( return utils.exec_sql(str(sql), database=database, fetchall=fetchall) def get_installation_status( - self, model: models.Model, database: Union[str, None] = None + self, model: type[models.Model], database: Union[str, None] = None ) -> Tuple[str, Union[bool, None]]: """Returns the installation status of a trigger. @@ -922,7 +1005,7 @@ def unregister(self, *models: models.Model): return _cleanup_on_exit(lambda: self.register(*models)) - def install(self, model: models.Model, database: Union[str, None] = None): + def install(self, model: type[models.Model], database: Union[str, None] = None): """Installs the trigger for a model. Args: @@ -934,7 +1017,7 @@ def install(self, model: models.Model, database: Union[str, None] = None): self.exec_sql(install_sql, model, database=database) return _cleanup_on_exit(lambda: self.uninstall(model, database=database)) - def uninstall(self, model: models.Model, database: Union[str, None] = None): + def uninstall(self, model: type[models.Model], database: Union[str, None] = None): """Uninstalls the trigger for a model. Args: @@ -947,7 +1030,7 @@ def uninstall(self, model: models.Model, database: Union[str, None] = None): lambda: self.install(model, database=database) ) - def enable(self, model: models.Model, database: Union[str, None] = None): + def enable(self, model: type[models.Model], database: Union[str, None] = None): """Enables the trigger for a model. Args: @@ -960,7 +1043,7 @@ def enable(self, model: models.Model, database: Union[str, None] = None): lambda: self.disable(model, database=database) ) - def disable(self, model: models.Model, database: Union[str, None] = None): + def disable(self, model: type[models.Model], database: Union[str, None] = None): """Disables the trigger for a model. Args: diff --git a/pgtrigger/tests/migrations/0014_topic_comment.py b/pgtrigger/tests/migrations/0014_topic_comment.py new file mode 100644 index 0000000..88e8d98 --- /dev/null +++ b/pgtrigger/tests/migrations/0014_topic_comment.py @@ -0,0 +1,50 @@ +# Generated by Django 4.2.6 on 2024-12-07 10:47 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ("tests", "0013_alter_testtrigger_m2m_field_changedcondition"), + ] + + operations = [ + migrations.CreateModel( + name="Topic", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("name", models.CharField(max_length=100)), + ("comment_count", models.PositiveIntegerField(default=0)), + ], + ), + migrations.CreateModel( + name="Comment", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ( + "topic", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, to="tests.topic" + ), + ), + ], + ), + ] diff --git a/pgtrigger/tests/models.py b/pgtrigger/tests/models.py index 7c86632..3f3fc7e 100644 --- a/pgtrigger/tests/models.py +++ b/pgtrigger/tests/models.py @@ -259,3 +259,45 @@ class ChangedCondition(models.Model): fk_field = models.ForeignKey("auth.User", null=True, on_delete=models.CASCADE) char_pk_fk_field = models.ForeignKey(CharPk, null=True, on_delete=models.CASCADE) m2m_field = models.ManyToManyField(User, related_name="+") + + +class Topic(models.Model): + name = models.CharField(max_length=100) + comment_count = models.PositiveIntegerField(default=0) + + class Meta: + triggers = [ + pgtrigger.ReadOnly( + name='read_only_comment_count', + fields=['comment_count'] + ) + ] + + +class Comment(models.Model): + topic = models.ForeignKey(Topic, on_delete=models.CASCADE) + # Other fields + + class Meta: + triggers = [ + pgtrigger.Trigger( + func=pgtrigger.Func( + ''' + UPDATE "{db_table}" + SET "{comment_count}" = "{comment_count}" + 1 + WHERE + "{db_table}"."{topic_pk}" = NEW."{columns.topic}"; + {reset_ignore} + RETURN NEW; + ''', + db_table = Topic._meta.db_table, + comment_count = Topic._meta.get_field('comment_count').get_attname_column()[1], + topic_pk = Topic._meta.pk.get_attname_column()[1] + ), + ignore_others=['tests.Topic:read_only_comment_count'], + when=pgtrigger.Before, + operation=pgtrigger.Insert, + name='increment_comment_count' + ), + ] + diff --git a/pgtrigger/tests/test_core.py b/pgtrigger/tests/test_core.py index 8abb064..4c4aabb 100644 --- a/pgtrigger/tests/test_core.py +++ b/pgtrigger/tests/test_core.py @@ -10,7 +10,6 @@ from pgtrigger import core from pgtrigger.tests import models, utils - def test_func(): """Tests using custom Func object""" trigger = pgtrigger.Trigger( @@ -722,3 +721,54 @@ def test_trigger_management(mocker): pgtrigger.prune() pgtrigger.prune() deletion_protected_model.delete() + + +def test_func_is_func_when_ignores_others(): + """Tests that Trigger class enforces the use of Func() + when using `ignore_others`. Because the user needs to use `{reset_ignore}` + placeholder, so that `pgtrigger.ignore()` works as expected""" + + with pytest.raises(ValueError, + match='Invalid "func" attribute. Triggers that ignore others'): + pgtrigger.Trigger(name='ignorant-trigger', + func='RETURN NEW', + ignore_others=['app.Model:trigger'], + when=pgtrigger.Before, + operation=pgtrigger.Insert) + + with pytest.raises(ValueError, + match='placeholder {reset_ignore} was not found in the function body'): + pgtrigger.Trigger(name='ignorant-trigger', + func=pgtrigger.Func('RETURN NEW'), + ignore_others=['app.Model:trigger'], + when=pgtrigger.Before, + operation=pgtrigger.Insert) + + +@pytest.mark.django_db(transaction=True) +def test_ignore_others(): + """ + Test that igore others succesfully ignores the + `read_only_comment_count` trigger on `models.Topic` + + Test that the custom trigger on `models.Comment` updates the comment counter on `models.Topic` + + Test that `pgtrigger.ignore` works as expected + """ + test_topic = ddf.G(models.Topic) + test_message = ddf.G(models.Comment, topic=test_topic.pk) + + test_topic.refresh_from_db() + assert test_topic.comment_count == 1 + + + with pgtrigger.ignore('tests.Topic:read_only_comment_count'): + ddf.G(models.Comment, topic=test_topic.pk) + test_topic.comment_count = 10 + test_topic.save() + + + with utils.raises_trigger_error('Cannot update'): + test_topic.comment_count = 20 + test_topic.save() + \ No newline at end of file From 6ea49394fea87eb4e21d311427b901d0cbad0168 Mon Sep 17 00:00:00 2001 From: Ilir Kokollari Date: Sat, 7 Dec 2024 18:55:53 +0000 Subject: [PATCH 2/2] Trigger ignore_others + formatting --- pgtrigger/compiler.py | 2 +- pgtrigger/core.py | 71 ++++++++++--------- .../tests/migrations/0014_topic_comment.py | 3 +- pgtrigger/tests/models.py | 24 +++---- pgtrigger/tests/test_core.py | 48 +++++++------ 5 files changed, 73 insertions(+), 75 deletions(-) diff --git a/pgtrigger/compiler.py b/pgtrigger/compiler.py index f393719..f9cf864 100644 --- a/pgtrigger/compiler.py +++ b/pgtrigger/compiler.py @@ -119,7 +119,7 @@ def __init__( condition=_unset, execute=_unset, hash=None, - local_ignore='', + local_ignore="", ): """Initialize the SQL and store it in the `.data` attribute.""" self.kwargs = { diff --git a/pgtrigger/core.py b/pgtrigger/core.py index 2301534..8f315b3 100644 --- a/pgtrigger/core.py +++ b/pgtrigger/core.py @@ -525,22 +525,23 @@ def render(self, model: type[models.Model], trigger: Trigger) -> str: Returns: The rendered SQL. """ - kwargs = dict( - meta=model._meta, - fields=utils.AttrDict({field.name: field for field in model._meta.fields}), - columns=utils.AttrDict({field.name: field.column for field in model._meta.fields}), - reset_ignore=( - ''' + kwargs = { + "meta": model._meta, + "fields": utils.AttrDict({field.name: field for field in model._meta.fields}), + "columns": utils.AttrDict({field.name: field.column for field in model._meta.fields}), + "reset_ignore": ( + """ IF _prev_ignore IS NOT NULL AND (_prev_ignore = '') IS NOT TRUE THEN PERFORM set_config('pgtrigger.ignore', _prev_ignore, true); ELSE PERFORM set_config('pgtrigger.ignore', '', true); END IF; - ''' - if trigger.ignores_others else '' - ) - ) | self.kwargs - + """ + if trigger.ignores_others + else "" + ), + } | self.kwargs + return self.func.format(**kwargs) @@ -632,15 +633,15 @@ def __init__( if not isinstance(self.func, Func): raise ValueError( 'Invalid "func" attribute. Triggers that ignore others must provide ' - f'an instance of pgtrigger.Func(). Received {type(self.func)} instead' + f"an instance of pgtrigger.Func(). Received {type(self.func)} instead" + ) + + if "{reset_ignore}" not in self.func.func: + raise ValueError( + f'Trigger "{self}" ignores other triggers, however, ' + "placeholder {reset_ignore} was not found in the function " + f"body. Please refer to: https://django-pgtrigger.readthedocs.io/en/{ver}/ignoring_triggers/#ignore-other-triggers-within-a-trigger" ) - - if not '{reset_ignore}' in self.func.func: - raise ValueError(f'Trigger "{self}" ignores other triggers, however, ' - 'placeholder {reset_ignore} was not found in the function ' - f'body. Please refer to: https://django-pgtrigger.readthedocs.io/en/{ver}/ignoring_triggers/#ignore-other-triggers-within-a-trigger') - - def __str__(self) -> str: # pragma: no cover return self.name @@ -707,13 +708,13 @@ def get_declare(self, model: type[models.Model]) -> List[Tuple[str, str]]: be shown in the DECLARE. For example [('row_data', 'JSONB')] """ declare = self.declare or [] - + if self.ignore_others is not None: - declare.append(('_prev_ignore', 'text')) + declare.append(("_prev_ignore", "text")) declare.append(self.declare_local_ignore(self.ignore_others)) - + return declare - + def declare_local_ignore(self, ignore: list[str]) -> tuple[str, str]: """Given a list of trigger URIs compile the value for `_local_ignore` variable of the trigger function @@ -728,12 +729,16 @@ def declare_local_ignore(self, ignore: list[str]) -> tuple[str, str]: tuple[str, str] `_local_ignore` variable declaration and initial value for the DECLARE block """ - local_ignore = '{' + ','.join( - f'{model._meta.db_table}:{(pgid:=trigger.get_pgid(model))},{pgid}' - for model, trigger in registry.registered(*ignore) - ) + '}' - return ('_local_ignore', f"text[] = '{local_ignore}'") - + local_ignore = ( + "{" + + ",".join( + f"{model._meta.db_table}:{(pgid:=trigger.get_pgid(model))},{pgid}" + for model, trigger in registry.registered(*ignore) + ) + + "}" + ) + return ("_local_ignore", f"text[] = '{local_ignore}'") + @property def ignores_others(self) -> bool: """True if the trigger is initialized with local trigger ignores""" @@ -741,8 +746,7 @@ def ignores_others(self) -> bool: def render_local_ignore(self): if self.ignores_others: - return ( - ''' + return """ BEGIN SELECT CURRENT_SETTING('pgtrigger.ignore', true) INTO _prev_ignore; EXCEPTION WHEN OTHERS THEN @@ -753,9 +757,8 @@ def render_local_ignore(self): END IF; PERFORM set_config('pgtrigger.ignore', _local_ignore::text, true); - ''' - ) - return '' + """ + return "" def get_func(self, model: type[models.Model]) -> Union[str, Func]: """ diff --git a/pgtrigger/tests/migrations/0014_topic_comment.py b/pgtrigger/tests/migrations/0014_topic_comment.py index 88e8d98..e16d0cb 100644 --- a/pgtrigger/tests/migrations/0014_topic_comment.py +++ b/pgtrigger/tests/migrations/0014_topic_comment.py @@ -1,11 +1,10 @@ # Generated by Django 4.2.6 on 2024-12-07 10:47 -from django.db import migrations, models import django.db.models.deletion +from django.db import migrations, models class Migration(migrations.Migration): - dependencies = [ ("tests", "0013_alter_testtrigger_m2m_field_changedcondition"), ] diff --git a/pgtrigger/tests/models.py b/pgtrigger/tests/models.py index 3f3fc7e..a719aad 100644 --- a/pgtrigger/tests/models.py +++ b/pgtrigger/tests/models.py @@ -266,12 +266,7 @@ class Topic(models.Model): comment_count = models.PositiveIntegerField(default=0) class Meta: - triggers = [ - pgtrigger.ReadOnly( - name='read_only_comment_count', - fields=['comment_count'] - ) - ] + triggers = [pgtrigger.ReadOnly(name="read_only_comment_count", fields=["comment_count"])] class Comment(models.Model): @@ -282,22 +277,21 @@ class Meta: triggers = [ pgtrigger.Trigger( func=pgtrigger.Func( - ''' + """ UPDATE "{db_table}" SET "{comment_count}" = "{comment_count}" + 1 WHERE "{db_table}"."{topic_pk}" = NEW."{columns.topic}"; {reset_ignore} RETURN NEW; - ''', - db_table = Topic._meta.db_table, - comment_count = Topic._meta.get_field('comment_count').get_attname_column()[1], - topic_pk = Topic._meta.pk.get_attname_column()[1] - ), - ignore_others=['tests.Topic:read_only_comment_count'], + """, + db_table=Topic._meta.db_table, + comment_count=Topic._meta.get_field("comment_count").get_attname_column()[1], + topic_pk=Topic._meta.pk.get_attname_column()[1], + ), + ignore_others=["tests.Topic:read_only_comment_count"], when=pgtrigger.Before, operation=pgtrigger.Insert, - name='increment_comment_count' + name="increment_comment_count", ), ] - diff --git a/pgtrigger/tests/test_core.py b/pgtrigger/tests/test_core.py index 4c4aabb..0385f28 100644 --- a/pgtrigger/tests/test_core.py +++ b/pgtrigger/tests/test_core.py @@ -10,6 +10,7 @@ from pgtrigger import core from pgtrigger.tests import models, utils + def test_func(): """Tests using custom Func object""" trigger = pgtrigger.Trigger( @@ -728,47 +729,48 @@ def test_func_is_func_when_ignores_others(): when using `ignore_others`. Because the user needs to use `{reset_ignore}` placeholder, so that `pgtrigger.ignore()` works as expected""" - with pytest.raises(ValueError, - match='Invalid "func" attribute. Triggers that ignore others'): - pgtrigger.Trigger(name='ignorant-trigger', - func='RETURN NEW', - ignore_others=['app.Model:trigger'], - when=pgtrigger.Before, - operation=pgtrigger.Insert) - - with pytest.raises(ValueError, - match='placeholder {reset_ignore} was not found in the function body'): - pgtrigger.Trigger(name='ignorant-trigger', - func=pgtrigger.Func('RETURN NEW'), - ignore_others=['app.Model:trigger'], - when=pgtrigger.Before, - operation=pgtrigger.Insert) - + with pytest.raises(ValueError, match='Invalid "func" attribute. Triggers that ignore others'): + pgtrigger.Trigger( + name="ignorant-trigger", + func="RETURN NEW", + ignore_others=["app.Model:trigger"], + when=pgtrigger.Before, + operation=pgtrigger.Insert, + ) + + with pytest.raises( + ValueError, match="placeholder {reset_ignore} was not found in the function body" + ): + pgtrigger.Trigger( + name="ignorant-trigger", + func=pgtrigger.Func("RETURN NEW"), + ignore_others=["app.Model:trigger"], + when=pgtrigger.Before, + operation=pgtrigger.Insert, + ) + @pytest.mark.django_db(transaction=True) def test_ignore_others(): """ Test that igore others succesfully ignores the `read_only_comment_count` trigger on `models.Topic` - + Test that the custom trigger on `models.Comment` updates the comment counter on `models.Topic` Test that `pgtrigger.ignore` works as expected """ test_topic = ddf.G(models.Topic) - test_message = ddf.G(models.Comment, topic=test_topic.pk) + ddf.G(models.Comment, topic=test_topic.pk) test_topic.refresh_from_db() assert test_topic.comment_count == 1 - - with pgtrigger.ignore('tests.Topic:read_only_comment_count'): + with pgtrigger.ignore("tests.Topic:read_only_comment_count"): ddf.G(models.Comment, topic=test_topic.pk) test_topic.comment_count = 10 test_topic.save() - - with utils.raises_trigger_error('Cannot update'): + with utils.raises_trigger_error("Cannot update"): test_topic.comment_count = 20 test_topic.save() - \ No newline at end of file