Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 49 additions & 2 deletions djongo/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
"""
MongoDB database backend for Django
"""

# THIS FILE WAS CHANGED ON - 05 Sep 2022

import traceback
from collections import OrderedDict
from logging import getLogger
from django.db.backends.base.base import BaseDatabaseWrapper
Expand All @@ -10,10 +14,12 @@
from .creation import DatabaseCreation
from . import database as Database
from .cursor import Cursor
from .database import DatabaseError
from .features import DatabaseFeatures
from .introspection import DatabaseIntrospection
from .operations import DatabaseOperations
from .schema import DatabaseSchemaEditor
from .transaction import Transaction

logger = getLogger(__name__)

Expand All @@ -34,9 +40,10 @@ def __contains__(self, item):

class DjongoClient:

def __init__(self, database, enforce_schema=True):
def __init__(self, database, enforce_schema=True, session=None):
self.enforce_schema = enforce_schema
self.cached_collections = CachedCollections(database)
self.session = session


class DatabaseWrapper(BaseDatabaseWrapper):
Expand Down Expand Up @@ -115,6 +122,8 @@ class DatabaseWrapper(BaseDatabaseWrapper):
def __init__(self, *args, **kwargs):
self.client_connection = None
self.djongo_connection = None
self.transaction = None
self.rollbacked = False
super().__init__(*args, **kwargs)

def is_usable(self):
Expand Down Expand Up @@ -187,7 +196,14 @@ def _set_autocommit(self, autocommit):
TODO: For future reference, setting two phase commits and rollbacks
might require populating this method.
"""
pass
self.autocommit = False

def set_autocommit(
self, autocommit, force_begin_transaction_with_broken_autocommit=False
):
result = super().set_autocommit(autocommit, force_begin_transaction_with_broken_autocommit=False)
self.autocommit = False
return result

def init_connection_state(self):
try:
Expand Down Expand Up @@ -220,3 +236,34 @@ def _commit(self):
TODO: two phase commits are not supported yet.
"""
pass

def _savepoint(self, sid):
# add _savepoint method to work with Django's transactions
self.in_atomic_block = True
self.transaction = Transaction(self.client_connection)
connection_params = self.get_connection_params()

name = connection_params.pop('name')

# this will be used in sql2mongo/query.py as session parameter when using pymongo CRUD operations
self.djongo_connection.session = self.transaction.session
# this will be used in models/fields.py as session parameter when using pymongo CRUD operations
# in that file pymongo functions are prefixed with 'mongo_'
self.client_connection[name].__setattr__('session', self.transaction.session)


def _savepoint_commit(self, sid):
# add _savepoint_commit method to work with Django's transactions
# this method is executed even if rollback is executed after it
connection_params = self.get_connection_params()
name = connection_params.pop('name')
if not self.rollbacked:
self.transaction.__exit__(None, None, traceback)
self.djongo_connection = DjongoClient(self.connection, self.get_connection_params().get('enforce_schema'))
self.client_connection[name].__setattr__('session', None)

def _savepoint_rollback(self, sid):
self.rollbacked = True
# We have to pass in some error, but it is not used anywhere as far as known
self.transaction.__exit__('DatabaseError', DatabaseError('Error in transaction; rollbacked'), traceback)

4 changes: 3 additions & 1 deletion djongo/cursor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# THIS FILE WAS CHANGED ON - 05 Sep 2022

from logging import getLogger

from .database import DatabaseError
Expand Down Expand Up @@ -55,7 +57,7 @@ def execute(self, sql, params=None):
sql,
params)
except Exception as e:
db_exe = DatabaseError()
db_exe = DatabaseError(str(e))
raise db_exe from e

def fetchmany(self, size=1):
Expand Down
4 changes: 3 additions & 1 deletion djongo/features.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# THIS FILE WAS CHANGED ON - 05 Sep 2022

from django.db.backends.base.features import BaseDatabaseFeatures


Expand All @@ -7,7 +9,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
has_bulk_insert = True
has_native_uuid_field = True
supports_timezones = False
uses_savepoints = False
uses_savepoints = True
can_clone_databases = True
test_db_allows_multiple_connections = False
supports_unspecified_pk = True
Expand Down
14 changes: 9 additions & 5 deletions djongo/models/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
These are the main fields for working with MongoDB.
"""

# THIS FILE WAS CHANGED ON - 28 Mar 2022
# THIS FILE WAS CHANGED ON - 05 Sep 2022

import functools
import json
Expand Down Expand Up @@ -781,7 +781,8 @@ def add(self, *objs):
lh_field.get_attname():
getattr(self.instance, rh_field.get_attname())
}
}
},
session=pymongo_connections[self.db].djongo_connection.session
)
for obj in objs:
fk_field = getattr(obj, lh_field.get_attname())
Expand Down Expand Up @@ -877,7 +878,8 @@ def add(self, *objs):
'$each': list(new_fks)
}
}
}
},
session=pymongo_connections[self.db].djongo_connection.session
)

add.alters_data = True
Expand All @@ -903,7 +905,8 @@ def _remove(self, to_del):
'$in': list(to_del)
}
}
}
},
session=pymongo_connections[self.db].djongo_connection.session
)

def clear(self):
Expand All @@ -914,7 +917,8 @@ def clear(self):
'$set': {
self.field.attname: []
}
}
},
session=pymongo_connections[self.db].djongo_connection.session
)
setattr(self.instance, self.field.attname, set())

Expand Down
63 changes: 38 additions & 25 deletions djongo/sql2mongo/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
SQL constructors.
"""

# THIS FILE WAS CHANGED ON - 19 Aug 2022
# THIS FILE WAS CHANGED ON - 05 Sep 2022

import abc
import re
Expand Down Expand Up @@ -238,7 +238,7 @@ def _needs_column_selection(self):
def _get_cursor(self):
if self._needs_aggregation():
pipeline = self._make_pipeline()
cur = self.db[self.left_table].aggregate(pipeline)
cur = self.db[self.left_table].aggregate(pipeline, session=self.connection_properties.session)
logger.debug(f'Aggregation query: {pipeline}')
else:
kwargs = {}
Expand All @@ -257,7 +257,7 @@ def _get_cursor(self):
if self.offset:
kwargs.update(self.offset.to_mongo())

cur = self.db[self.left_table].find(**kwargs)
cur = self.db[self.left_table].find(**kwargs, session=self.connection_properties.session)
logger.debug(f'Find query: {kwargs}')

return cur
Expand Down Expand Up @@ -330,7 +330,7 @@ def parse(self):

def execute(self):
db = self.db
self.result = db[self.left_table].update_many(**self.kwargs)
self.result = db[self.left_table].update_many(**self.kwargs , session=self.connection_properties.session)
logger.debug(f'update_many: {self.result.modified_count}, matched: {self.result.matched_count}')


Expand Down Expand Up @@ -388,7 +388,8 @@ def execute(self):
}
},
{'$inc': {'auto.seq': num}},
return_document=ReturnDocument.AFTER
return_document=ReturnDocument.AFTER,
session=self.connection_properties.session
)

for i, val in enumerate(self._values):
Expand All @@ -403,7 +404,7 @@ def execute(self):
ins[_field] = value
docs.append(ins)

res = self.db[self.left_table].insert_many(docs, ordered=False)
res = self.db[self.left_table].insert_many(docs, ordered=False, session=self.connection_properties.session)
if auto:
self._result_ref.last_row_id = auto['auto']['seq']
else:
Expand Down Expand Up @@ -479,11 +480,12 @@ def _rename_column(self):
'$rename': {
self._old_name: self._new_name
}
}
},
session=self.connection_properties.session
)

def _rename_collection(self):
self.db[self.left_table].rename(self._new_name)
self.db[self.left_table].rename(self._new_name, session=self.connection_properties.session)

def _alter(self, statement: SQLStatement):
self.execute = lambda: None
Expand All @@ -510,7 +512,7 @@ def _alter(self, statement: SQLStatement):
print_warn(feature)

def _flush(self):
self.db[self.left_table].delete_many({})
self.db[self.left_table].delete_many({}, session=self.connection_properties.session)

def _table(self, statement: SQLStatement):
tok = statement.next()
Expand All @@ -535,7 +537,7 @@ def _drop(self, statement: SQLStatement):
raise SQLDecodeError

def _drop_index(self):
self.db[self.left_table].drop_index(self._iden_name)
self.db[self.left_table].drop_index(self._iden_name, session=self.connection_properties.session)

def _drop_column(self):
self.db[self.left_table].update_many(
Expand All @@ -544,15 +546,17 @@ def _drop_column(self):
'$unset': {
self._iden_name: ''
}
}
},
session=self.connection_properties.session
)
self.db['__schema__'].update_one(
{'name': self.left_table},
{
'$unset': {
f'fields.{self._iden_name}': ''
}
}
},
session=self.connection_properties.session
)

def _add(self, statement: SQLStatement):
Expand Down Expand Up @@ -618,7 +622,8 @@ def _add_column(self):
'$set': {
self._iden_name: self._default
}
}
},
session=self.connection_properties.session
)
self.db['__schema__'].update_one(
{'name': self.left_table},
Expand All @@ -628,19 +633,24 @@ def _add_column(self):
'type_code': self._type_code
}
}
}
},
session=self.connection_properties.session
)

def _index(self):
self.db[self.left_table].create_index(
self.field_dir,
name=self._iden_name)
name=self._iden_name,
session=self.connection_properties.session
)

def _unique(self):
self.db[self.left_table].create_index(
self.field_dir,
unique=True,
name=self._iden_name)
name=self._iden_name,
session=self.connection_properties.session
)

def _fk(self):
pass
Expand All @@ -653,15 +663,15 @@ def __init__(self, *args):

def _create_table(self, statement):
if '__schema__' not in self.connection_properties.cached_collections:
self.db.create_collection('__schema__')
self.db.create_collection('__schema__', session=self.connection_properties.session)
self.connection_properties.cached_collections.add('__schema__')
self.db['__schema__'].create_index('name', unique=True)
self.db['__schema__'].create_index('auto')
self.db['__schema__'].create_index('name', unique=True, session=self.connection_properties.session)
self.db['__schema__'].create_index('auto', session=self.connection_properties.session)

tok = statement.next()
table = SQLToken.token2sql(tok, self).table
try:
self.db.create_collection(table)
self.db.create_collection(table, session=self.connection_properties.session)
except CollectionInvalid:
if self.connection_properties.enforce_schema:
raise
Expand Down Expand Up @@ -708,10 +718,12 @@ def _create_table(self, statement):
_set['auto.seq'] = 0

if SQLColumnDef.primarykey in col.col_constraints:
self.db[table].create_index(field, unique=True, name='__primary_key__')
self.db[table].create_index(field, unique=True, name='__primary_key__',
session=self.connection_properties.session)

if SQLColumnDef.unique in col.col_constraints:
self.db[table].create_index(field, unique=True)
self.db[table].create_index(field, unique=True,
session=self.connection_properties.session)

if (SQLColumnDef.not_null in col.col_constraints or
SQLColumnDef.null in col.col_constraints):
Expand All @@ -725,7 +737,8 @@ def _create_table(self, statement):
self.db['__schema__'].update_one(
filter=_filter,
update=update,
upsert=True
upsert=True,
session=self.connection_properties.session,
)

def parse(self):
Expand Down Expand Up @@ -762,7 +775,7 @@ def parse(self):

def execute(self):
db_con = self.db
self.result = db_con[self.left_table].delete_many(**self.kw)
self.result = db_con[self.left_table].delete_many(**self.kw, session=self.connection_properties.session)
logger.debug('delete_many: {}'.format(self.result.deleted_count))

def count(self):
Expand Down Expand Up @@ -976,7 +989,7 @@ def _drop(self, sm):
elif tok.match(tokens.Keyword, 'TABLE'):
tok = statement.next()
table_name = tok.get_name()
self.db.drop_collection(table_name)
self.db.drop_collection(table_name, session=self.connection_properties.session)
else:
raise SQLDecodeError('statement:{}'.format(sm))

Expand Down
Loading