From aa50338d138b3326b2773d871017565bc59d72a0 Mon Sep 17 00:00:00 2001 From: Gregor Lulic Date: Mon, 5 Sep 2022 10:08:29 +0200 Subject: [PATCH 1/3] add transaction to djongo --- djongo/base.py | 48 ++++++++++++++++++++++++++++-- djongo/cursor.py | 2 +- djongo/features.py | 2 +- djongo/models/fields.py | 14 ++++++--- djongo/sql2mongo/query.py | 61 ++++++++++++++++++++++++--------------- djongo/transaction.py | 44 +++++++++++++++++++++++++--- 6 files changed, 135 insertions(+), 36 deletions(-) diff --git a/djongo/base.py b/djongo/base.py index ba6766d7..1afb6b60 100644 --- a/djongo/base.py +++ b/djongo/base.py @@ -1,6 +1,7 @@ """ MongoDB database backend for Django """ +import traceback from collections import OrderedDict from logging import getLogger from django.db.backends.base.base import BaseDatabaseWrapper @@ -10,10 +11,12 @@ from .creation import DatabaseCreation from . import database as Database from .cursor import Cursor +from .database import DatabaseError from .features import DatabaseFeatures from .introspection import DatabaseIntrospection from .operations import DatabaseOperations from .schema import DatabaseSchemaEditor +from .transaction import Transaction logger = getLogger(__name__) @@ -34,9 +37,10 @@ def __contains__(self, item): class DjongoClient: - def __init__(self, database, enforce_schema=True): + def __init__(self, database, enforce_schema=True, session=None): self.enforce_schema = enforce_schema self.cached_collections = CachedCollections(database) + self.session = session class DatabaseWrapper(BaseDatabaseWrapper): @@ -115,6 +119,8 @@ class DatabaseWrapper(BaseDatabaseWrapper): def __init__(self, *args, **kwargs): self.client_connection = None self.djongo_connection = None + self.transaction = None + self.rollbacked = False super().__init__(*args, **kwargs) def is_usable(self): @@ -187,7 +193,14 @@ def _set_autocommit(self, autocommit): TODO: For future reference, setting two phase commits and rollbacks might require populating this method. """ - pass + self.autocommit = False + + def set_autocommit( + self, autocommit, force_begin_transaction_with_broken_autocommit=False + ): + result = super().set_autocommit(autocommit, force_begin_transaction_with_broken_autocommit=False) + self.autocommit = False + return result def init_connection_state(self): try: @@ -220,3 +233,34 @@ def _commit(self): TODO: two phase commits are not supported yet. """ pass + + def _savepoint(self, sid): + self.in_atomic_block = True + self.transaction = Transaction(self.client_connection) + # self.client_connection = self.transaction.session.client + connection_params = self.get_connection_params() + + name = connection_params.pop('name') + # es = connection_params.pop('enforce_schema') + + # connection_params['document_class'] = OrderedDict + + # self.connection = self.transaction.__session.client[name] + self.djongo_connection.session = self.transaction.session + # self.client_connection[name]._session = self.transaction.session # noqa + self.client_connection[name].__setattr__('session', self.transaction.session) + # print(f'\n///////////////////self.client_connection[name].__name: {self.client_connection[name].__getattribute__("session")}//////////////////////\n') + # self.client_connection = self.transaction.__session.client + + + def _savepoint_commit(self, sid): + if not self.rollbacked: + self.transaction.__exit__(None, None, traceback) + self.djongo_connection = DjongoClient(self.connection, self.get_connection_params().get('enforce_schema')) + + def _savepoint_rollback(self, sid): + self.rollbacked = True + self.transaction.__exit__('DatabaseError', DatabaseError('Error in transaction; rollbacked'), traceback) + # self.transaction.session.abort_transaction() + # self.transaction.rollbacked = True + # self.djongo_connection = DjongoClient(self.connection, self.get_connection_params().get('enforce_schema')) diff --git a/djongo/cursor.py b/djongo/cursor.py index f5a4f4b8..f111373d 100644 --- a/djongo/cursor.py +++ b/djongo/cursor.py @@ -55,7 +55,7 @@ def execute(self, sql, params=None): sql, params) except Exception as e: - db_exe = DatabaseError() + db_exe = DatabaseError(str(e)) raise db_exe from e def fetchmany(self, size=1): diff --git a/djongo/features.py b/djongo/features.py index acfee1e3..1ff32515 100644 --- a/djongo/features.py +++ b/djongo/features.py @@ -7,7 +7,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): has_bulk_insert = True has_native_uuid_field = True supports_timezones = False - uses_savepoints = False + uses_savepoints = True can_clone_databases = True test_db_allows_multiple_connections = False supports_unspecified_pk = True diff --git a/djongo/models/fields.py b/djongo/models/fields.py index d34d8aa2..791f5ea6 100644 --- a/djongo/models/fields.py +++ b/djongo/models/fields.py @@ -781,7 +781,8 @@ def add(self, *objs): lh_field.get_attname(): getattr(self.instance, rh_field.get_attname()) } - } + }, + session=pymongo_connections[self.db].djongo_connection.session ) for obj in objs: fk_field = getattr(obj, lh_field.get_attname()) @@ -869,6 +870,8 @@ def add(self, *objs): fks.update(new_fks) db = router.db_for_write(self.instance.__class__, instance=self.instance) + print(f'\n/////////pymongo_connection: {pymongo_connections[self.db]}////////\n') + print(f'\n//////////_client._session: {pymongo_connections[self.db].djongo_connection.session}////////\n') self.instance_manager.db_manager(db).mongo_update_one( self._make_filter(), { @@ -877,7 +880,8 @@ def add(self, *objs): '$each': list(new_fks) } } - } + }, + session=pymongo_connections[self.db].djongo_connection.session ) add.alters_data = True @@ -903,7 +907,8 @@ def _remove(self, to_del): '$in': list(to_del) } } - } + }, + session=pymongo_connections[self.db].djongo_connection.session ) def clear(self): @@ -914,7 +919,8 @@ def clear(self): '$set': { self.field.attname: [] } - } + }, + session=pymongo_connections[self.db].djongo_connection.session ) setattr(self.instance, self.field.attname, set()) diff --git a/djongo/sql2mongo/query.py b/djongo/sql2mongo/query.py index 465c6879..53507f78 100644 --- a/djongo/sql2mongo/query.py +++ b/djongo/sql2mongo/query.py @@ -238,7 +238,7 @@ def _needs_column_selection(self): def _get_cursor(self): if self._needs_aggregation(): pipeline = self._make_pipeline() - cur = self.db[self.left_table].aggregate(pipeline) + cur = self.db[self.left_table].aggregate(pipeline, session=self.connection_properties.session) logger.debug(f'Aggregation query: {pipeline}') else: kwargs = {} @@ -257,7 +257,7 @@ def _get_cursor(self): if self.offset: kwargs.update(self.offset.to_mongo()) - cur = self.db[self.left_table].find(**kwargs) + cur = self.db[self.left_table].find(**kwargs, session=self.connection_properties.session) logger.debug(f'Find query: {kwargs}') return cur @@ -330,7 +330,7 @@ def parse(self): def execute(self): db = self.db - self.result = db[self.left_table].update_many(**self.kwargs) + self.result = db[self.left_table].update_many(**self.kwargs , session=self.connection_properties.session) logger.debug(f'update_many: {self.result.modified_count}, matched: {self.result.matched_count}') @@ -388,7 +388,8 @@ def execute(self): } }, {'$inc': {'auto.seq': num}}, - return_document=ReturnDocument.AFTER + return_document=ReturnDocument.AFTER, + session=self.connection_properties.session ) for i, val in enumerate(self._values): @@ -403,7 +404,7 @@ def execute(self): ins[_field] = value docs.append(ins) - res = self.db[self.left_table].insert_many(docs, ordered=False) + res = self.db[self.left_table].insert_many(docs, ordered=False, session=self.connection_properties.session) if auto: self._result_ref.last_row_id = auto['auto']['seq'] else: @@ -479,11 +480,12 @@ def _rename_column(self): '$rename': { self._old_name: self._new_name } - } + }, + session=self.connection_properties.session ) def _rename_collection(self): - self.db[self.left_table].rename(self._new_name) + self.db[self.left_table].rename(self._new_name, session=self.connection_properties.session) def _alter(self, statement: SQLStatement): self.execute = lambda: None @@ -510,7 +512,7 @@ def _alter(self, statement: SQLStatement): print_warn(feature) def _flush(self): - self.db[self.left_table].delete_many({}) + self.db[self.left_table].delete_many({}, session=self.connection_properties.session) def _table(self, statement: SQLStatement): tok = statement.next() @@ -535,7 +537,7 @@ def _drop(self, statement: SQLStatement): raise SQLDecodeError def _drop_index(self): - self.db[self.left_table].drop_index(self._iden_name) + self.db[self.left_table].drop_index(self._iden_name, session=self.connection_properties.session) def _drop_column(self): self.db[self.left_table].update_many( @@ -544,7 +546,8 @@ def _drop_column(self): '$unset': { self._iden_name: '' } - } + }, + session=self.connection_properties.session ) self.db['__schema__'].update_one( {'name': self.left_table}, @@ -552,7 +555,8 @@ def _drop_column(self): '$unset': { f'fields.{self._iden_name}': '' } - } + }, + session=self.connection_properties.session ) def _add(self, statement: SQLStatement): @@ -618,7 +622,8 @@ def _add_column(self): '$set': { self._iden_name: self._default } - } + }, + session=self.connection_properties.session ) self.db['__schema__'].update_one( {'name': self.left_table}, @@ -628,19 +633,24 @@ def _add_column(self): 'type_code': self._type_code } } - } + }, + session=self.connection_properties.session ) def _index(self): self.db[self.left_table].create_index( self.field_dir, - name=self._iden_name) + name=self._iden_name, + session=self.connection_properties.session + ) def _unique(self): self.db[self.left_table].create_index( self.field_dir, unique=True, - name=self._iden_name) + name=self._iden_name, + session=self.connection_properties.session + ) def _fk(self): pass @@ -653,15 +663,15 @@ def __init__(self, *args): def _create_table(self, statement): if '__schema__' not in self.connection_properties.cached_collections: - self.db.create_collection('__schema__') + self.db.create_collection('__schema__', session=self.connection_properties.session) self.connection_properties.cached_collections.add('__schema__') - self.db['__schema__'].create_index('name', unique=True) - self.db['__schema__'].create_index('auto') + self.db['__schema__'].create_index('name', unique=True, session=self.connection_properties.session) + self.db['__schema__'].create_index('auto', session=self.connection_properties.session) tok = statement.next() table = SQLToken.token2sql(tok, self).table try: - self.db.create_collection(table) + self.db.create_collection(table, session=self.connection_properties.session) except CollectionInvalid: if self.connection_properties.enforce_schema: raise @@ -708,10 +718,12 @@ def _create_table(self, statement): _set['auto.seq'] = 0 if SQLColumnDef.primarykey in col.col_constraints: - self.db[table].create_index(field, unique=True, name='__primary_key__') + self.db[table].create_index(field, unique=True, name='__primary_key__', + session=self.connection_properties.session) if SQLColumnDef.unique in col.col_constraints: - self.db[table].create_index(field, unique=True) + self.db[table].create_index(field, unique=True, + session=self.connection_properties.session) if (SQLColumnDef.not_null in col.col_constraints or SQLColumnDef.null in col.col_constraints): @@ -725,7 +737,8 @@ def _create_table(self, statement): self.db['__schema__'].update_one( filter=_filter, update=update, - upsert=True + upsert=True, + session=self.connection_properties.session, ) def parse(self): @@ -762,7 +775,7 @@ def parse(self): def execute(self): db_con = self.db - self.result = db_con[self.left_table].delete_many(**self.kw) + self.result = db_con[self.left_table].delete_many(**self.kw, session=self.connection_properties.session) logger.debug('delete_many: {}'.format(self.result.deleted_count)) def count(self): @@ -976,7 +989,7 @@ def _drop(self, sm): elif tok.match(tokens.Keyword, 'TABLE'): tok = statement.next() table_name = tok.get_name() - self.db.drop_collection(table_name) + self.db.drop_collection(table_name, session=self.connection_properties.session) else: raise SQLDecodeError('statement:{}'.format(sm)) diff --git a/djongo/transaction.py b/djongo/transaction.py index 0ce03de5..60e228a1 100644 --- a/djongo/transaction.py +++ b/djongo/transaction.py @@ -1,5 +1,41 @@ -from djongo.exceptions import NotSupportedError -from djongo import djongo_access_url +from pymongo import WriteConcern +from pymongo.errors import ConnectionFailure, OperationFailure +from pymongo.read_concern import ReadConcern -print(f'This version of djongo does not support transactions. Visit {djongo_access_url}') -raise NotSupportedError('transactions') + +def commit_or_rollback_with_retry(session, rollbacked=False): + # cannot retry bacause documentDB doesn't support retryable writes + while True: + try: + # Commit uses write concern set at transaction start. + if not rollbacked: + session.commit_transaction() + print("Transaction committed.") + break + except (ConnectionFailure, OperationFailure) as exc: + # Can retry commit + if exc.has_error_label("UnknownTransactionCommitResult"): + print("UnknownTransactionCommitResult, retrying " + "commit operation ...") + continue + else: + print("Error during commit ...") + raise + + +class Transaction: + def __init__(self, mongo_client): + self.mongo_client = mongo_client + self.session = self.mongo_client.start_session().__enter__() + self.rollbacked = False + + self.transaction = self.session.start_transaction( + read_concern=ReadConcern("snapshot"), + write_concern=WriteConcern(w="majority")) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, traceback): + self.transaction.__exit__(exc_type, exc_val, traceback) + self.session.__exit__(exc_type, exc_val, traceback) From 08bc7de1acc69e75310bcc7abc5dc8201ae568e8 Mon Sep 17 00:00:00 2001 From: Gregor Lulic Date: Mon, 5 Sep 2022 10:22:25 +0200 Subject: [PATCH 2/3] add comments and remove redundant --- djongo/base.py | 22 +++++++++++----------- djongo/cursor.py | 2 ++ djongo/features.py | 2 ++ djongo/models/fields.py | 4 +--- djongo/sql2mongo/query.py | 2 +- djongo/transaction.py | 26 +++++--------------------- 6 files changed, 22 insertions(+), 36 deletions(-) diff --git a/djongo/base.py b/djongo/base.py index 1afb6b60..7a687b4a 100644 --- a/djongo/base.py +++ b/djongo/base.py @@ -1,6 +1,9 @@ """ MongoDB database backend for Django """ + +# THIS FILE WAS CHANGED ON - 05 Sep 2022 + import traceback from collections import OrderedDict from logging import getLogger @@ -235,32 +238,29 @@ def _commit(self): pass def _savepoint(self, sid): + # add _savepoint method to work with Django's transactions self.in_atomic_block = True self.transaction = Transaction(self.client_connection) - # self.client_connection = self.transaction.session.client connection_params = self.get_connection_params() name = connection_params.pop('name') - # es = connection_params.pop('enforce_schema') - # connection_params['document_class'] = OrderedDict - - # self.connection = self.transaction.__session.client[name] + # this will be used in sql2mongo/query.py as session parameter when using pymongo CRUD operations self.djongo_connection.session = self.transaction.session - # self.client_connection[name]._session = self.transaction.session # noqa + # this will be used in models/fields.py as session parameter when using pymongo CRUD operations + # in that file pymongo functions are prefixed with 'mongo_' self.client_connection[name].__setattr__('session', self.transaction.session) - # print(f'\n///////////////////self.client_connection[name].__name: {self.client_connection[name].__getattribute__("session")}//////////////////////\n') - # self.client_connection = self.transaction.__session.client def _savepoint_commit(self, sid): + # add _savepoint_commit method to work with Django's transactions + # this method is executed even if rollback is executed after it if not self.rollbacked: self.transaction.__exit__(None, None, traceback) self.djongo_connection = DjongoClient(self.connection, self.get_connection_params().get('enforce_schema')) def _savepoint_rollback(self, sid): self.rollbacked = True + # We have to pass in some error, but it is not used anywhere as far as known self.transaction.__exit__('DatabaseError', DatabaseError('Error in transaction; rollbacked'), traceback) - # self.transaction.session.abort_transaction() - # self.transaction.rollbacked = True - # self.djongo_connection = DjongoClient(self.connection, self.get_connection_params().get('enforce_schema')) + diff --git a/djongo/cursor.py b/djongo/cursor.py index f111373d..c51176ee 100644 --- a/djongo/cursor.py +++ b/djongo/cursor.py @@ -1,3 +1,5 @@ +# THIS FILE WAS CHANGED ON - 05 Sep 2022 + from logging import getLogger from .database import DatabaseError diff --git a/djongo/features.py b/djongo/features.py index 1ff32515..e63d447a 100644 --- a/djongo/features.py +++ b/djongo/features.py @@ -1,3 +1,5 @@ +# THIS FILE WAS CHANGED ON - 05 Sep 2022 + from django.db.backends.base.features import BaseDatabaseFeatures diff --git a/djongo/models/fields.py b/djongo/models/fields.py index 791f5ea6..b21b584b 100644 --- a/djongo/models/fields.py +++ b/djongo/models/fields.py @@ -13,7 +13,7 @@ These are the main fields for working with MongoDB. """ -# THIS FILE WAS CHANGED ON - 28 Mar 2022 +# THIS FILE WAS CHANGED ON - 05 Sep 2022 import functools import json @@ -870,8 +870,6 @@ def add(self, *objs): fks.update(new_fks) db = router.db_for_write(self.instance.__class__, instance=self.instance) - print(f'\n/////////pymongo_connection: {pymongo_connections[self.db]}////////\n') - print(f'\n//////////_client._session: {pymongo_connections[self.db].djongo_connection.session}////////\n') self.instance_manager.db_manager(db).mongo_update_one( self._make_filter(), { diff --git a/djongo/sql2mongo/query.py b/djongo/sql2mongo/query.py index 53507f78..4a634bd6 100644 --- a/djongo/sql2mongo/query.py +++ b/djongo/sql2mongo/query.py @@ -3,7 +3,7 @@ SQL constructors. """ -# THIS FILE WAS CHANGED ON - 19 Aug 2022 +# THIS FILE WAS CHANGED ON - 05 Sep 2022 import abc import re diff --git a/djongo/transaction.py b/djongo/transaction.py index 60e228a1..715eedcb 100644 --- a/djongo/transaction.py +++ b/djongo/transaction.py @@ -1,30 +1,12 @@ +# THIS FILE WAS CHANGED ON - 05 Sep 2022 + from pymongo import WriteConcern -from pymongo.errors import ConnectionFailure, OperationFailure from pymongo.read_concern import ReadConcern -def commit_or_rollback_with_retry(session, rollbacked=False): - # cannot retry bacause documentDB doesn't support retryable writes - while True: - try: - # Commit uses write concern set at transaction start. - if not rollbacked: - session.commit_transaction() - print("Transaction committed.") - break - except (ConnectionFailure, OperationFailure) as exc: - # Can retry commit - if exc.has_error_label("UnknownTransactionCommitResult"): - print("UnknownTransactionCommitResult, retrying " - "commit operation ...") - continue - else: - print("Error during commit ...") - raise - - class Transaction: def __init__(self, mongo_client): + # do initial steps for transaction as noted in mongo documentation for database version 4.0 self.mongo_client = mongo_client self.session = self.mongo_client.start_session().__enter__() self.rollbacked = False @@ -37,5 +19,7 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_val, traceback): + # cannot retry bacause documentDB doesn't support retryable writes + # exit transaction decorators as noted in mongo documentation for database version 4.0 self.transaction.__exit__(exc_type, exc_val, traceback) self.session.__exit__(exc_type, exc_val, traceback) From 8542c7d081d71d5ec4891b2a2bdc8915e3afcd27 Mon Sep 17 00:00:00 2001 From: Gregor Lulic Date: Mon, 5 Sep 2022 11:29:10 +0200 Subject: [PATCH 3/3] next operations work --- djongo/base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/djongo/base.py b/djongo/base.py index 7a687b4a..e10b847f 100644 --- a/djongo/base.py +++ b/djongo/base.py @@ -255,9 +255,12 @@ def _savepoint(self, sid): def _savepoint_commit(self, sid): # add _savepoint_commit method to work with Django's transactions # this method is executed even if rollback is executed after it + connection_params = self.get_connection_params() + name = connection_params.pop('name') if not self.rollbacked: self.transaction.__exit__(None, None, traceback) self.djongo_connection = DjongoClient(self.connection, self.get_connection_params().get('enforce_schema')) + self.client_connection[name].__setattr__('session', None) def _savepoint_rollback(self, sid): self.rollbacked = True