diff --git a/README.rst b/README.rst index f4809dc71..031f6c3d2 100644 --- a/README.rst +++ b/README.rst @@ -110,7 +110,7 @@ A simple synchronous example: class BlogPost(Document): title = StringField(required=True, max_length=200) - posted = DateTimeField(default=datetime.datetime.utcnow) + posted = DateTimeField(default=lambda: datetime.datetime.now(datetime.UTC)) tags = ListField(StringField(max_length=50)) post = BlogPost( diff --git a/docs/apireference.rst b/docs/apireference.rst index 02f2240c0..1406bd854 100644 --- a/docs/apireference.rst +++ b/docs/apireference.rst @@ -107,7 +107,7 @@ Fields .. autoclass:: mongoengine.fields.BooleanField .. autoclass:: mongoengine.fields.DateTimeField .. autoclass:: mongoengine.fields.ComplexDateTimeField -.. autoclass:: mongoengine.fields.ZonedDateTimeField +.. autoclass:: mongoengine.fields.AwareDateTimeField .. autoclass:: mongoengine.fields.EmbeddedDocumentField .. autoclass:: mongoengine.fields.GenericEmbeddedDocumentField .. autoclass:: mongoengine.fields.DynamicField diff --git a/docs/code/tumblelog.py b/docs/code/tumblelog.py index 6ce14402c..a52b7a74d 100644 --- a/docs/code/tumblelog.py +++ b/docs/code/tumblelog.py @@ -80,7 +80,7 @@ class LinkPost(Post): async def run_async_tumblelog(): - await async_connect("tumblelog") + async_connect("tumblelog") await Post.adrop_collection() diff --git a/docs/guide/connecting.rst b/docs/guide/connecting.rst index 71818fd5f..39cecf82d 100644 --- a/docs/guide/connecting.rst +++ b/docs/guide/connecting.rst @@ -16,7 +16,7 @@ function. The first argument is the name of the database to connect to:: The asynchronous alternative is :func:`~mongoengine.async_connect`:: from mongoengine import async_connect - await async_connect('project1') + async_connect('project1') By default, MongoEngine assumes that the :program:`mongod` instance is running on **localhost** on port **27017**. @@ -55,7 +55,7 @@ The asynchronous alternative is as follows:: # Connects to 'my_db' database by authenticating # with given credentials against that same database - await async_connect(host="mongodb://my_user:my_password@127.0.0.1:27017/my_db?authSource=my_db") + async_connect(host="mongodb://my_user:my_password@127.0.0.1:27017/my_db?authSource=my_db") The URI string can also be used to configure advanced parameters like ssl, replicaSet, etc. For more information or example about URI string, you can refer to the `official doc `_:: @@ -79,7 +79,7 @@ and :attr:`authentication_source` arguments should be provided:: The asynchronous alternative is as follows:: - await async_connect('my_db', username='my_user', password='my_password', authentication_source='admin') + async_connect('my_db', username='my_user', password='my_password', authentication_source='admin') The set of attributes that :func:`~mongoengine.connect` recognizes includes but is not limited to: :attr:`host`, :attr:`port`, :attr:`read_preference`, :attr:`username`, :attr:`password`, :attr:`authentication_source`, :attr:`authentication_mechanism`, @@ -171,11 +171,11 @@ connection globally:: The asynchronous alternative is :func:`~mongoengine.async_disconnect`:: from mongoengine import async_connect, async_disconnect - await async_connect('a_db', alias='db1') + async_connect('a_db', alias='db1') await async_disconnect(alias='db1') - await async_connect('another_db', alias='db1') + async_connect('another_db', alias='db1') .. note:: Calling :func:`~mongoengine.disconnect` without argument will disconnect the "default" connection diff --git a/docs/guide/defining-documents.rst b/docs/guide/defining-documents.rst index bbce33596..6983f3777 100644 --- a/docs/guide/defining-documents.rst +++ b/docs/guide/defining-documents.rst @@ -22,7 +22,7 @@ objects** as class attributes to the document class:: class Page(Document): title = StringField(max_length=200, required=True) - date_modified = DateTimeField(default=datetime.datetime.utcnow) + date_modified = DateTimeField(default=lambda: datetime.datetime.now(datetime.UTC)) As BSON (the binary format for storing data in mongodb) is order dependent, documents are serialized based on their field order. @@ -82,7 +82,7 @@ are as follows: * :class:`~mongoengine.fields.BooleanField` * :class:`~mongoengine.fields.ComplexDateTimeField` * :class:`~mongoengine.fields.DateTimeField` -* :class:`~mongoengine.fields.ZonedDateTimeField` +* :class:`~mongoengine.fields.AwareDateTimeField` * :class:`~mongoengine.fields.DecimalField` * :class:`~mongoengine.fields.DictField` * :class:`~mongoengine.fields.DynamicField` diff --git a/docs/guide/logging-monitoring.rst b/docs/guide/logging-monitoring.rst index a7ee5b6cc..58204ab0c 100644 --- a/docs/guide/logging-monitoring.rst +++ b/docs/guide/logging-monitoring.rst @@ -67,7 +67,7 @@ The following snippet provides a basic logging of all command events: # The asynchronous alternative is as follows: async def async_logging_example(): - await async_connect() + async_connect() log.info('GO ASYNC!') diff --git a/docs/tutorial.rst b/docs/tutorial.rst index 74f031d5a..a02aa7fa0 100644 --- a/docs/tutorial.rst +++ b/docs/tutorial.rst @@ -33,7 +33,7 @@ The asynchronous alternative is :func:`~mongoengine.async_connect`:: from mongoengine import async_connect - await async_connect('tumblelog') + async_connect('tumblelog') There are lots of options for connecting to MongoDB, for more information about them see the :ref:`guide-connecting` guide. diff --git a/mongoengine/asynchronous/connection.py b/mongoengine/asynchronous/connection.py index 6225de04b..e252562ae 100644 --- a/mongoengine/asynchronous/connection.py +++ b/mongoengine/asynchronous/connection.py @@ -1,5 +1,5 @@ from pymongo import AsyncMongoClient, ReadPreference -from pymongo.asynchronous import uri_parser +from pymongo.synchronous import uri_parser from pymongo.asynchronous.database import AsyncDatabase from pymongo.common import _UUID_REPRESENTATIONS from pymongo.driver_info import DriverInfo @@ -31,7 +31,7 @@ _dbs = {} -async def _async_get_connection_settings( +def _async_get_connection_settings( db=None, name=None, host=None, @@ -74,7 +74,7 @@ async def _async_get_connection_settings( resolved_hosts.append(entity) continue - uri_info = await uri_parser.parse_uri(entity) + uri_info = uri_parser.parse_uri(entity) resolved_hosts.append(entity) # override DB name from URI if provided @@ -122,7 +122,7 @@ async def _async_get_connection_settings( return conn_settings -async def async_register_connection( +def async_register_connection( alias, db=None, name=None, @@ -158,7 +158,7 @@ async def async_register_connection( for example, maxpoolsize, tz_aware, etc. See the documentation for pymongo's `MongoClient` for a full list. """ - conn_settings = await _async_get_connection_settings( + conn_settings = _async_get_connection_settings( db=db, name=name, host=host, @@ -218,13 +218,8 @@ def _create_connection(alias, mongo_client_class, **connection_settings): raise ConnectionFailure(f"Cannot connect to database {alias} :\n{e}") -async def async_get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): +def async_get_connection(alias=DEFAULT_CONNECTION_NAME): """Return a connection with a given alias.""" - - # Connect to the database if not already connected - if reconnect: - await async_disconnect(alias) - # If the requested alias already exists in the _connections list, return # it immediately. if alias in _connections and isinstance(_connections[alias], AsyncMongoClient): @@ -313,7 +308,7 @@ async def async_get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False) -> AsyncD await async_disconnect(alias) if alias not in _dbs or not isinstance(_dbs[alias], AsyncDatabase): - conn = await async_get_connection(alias) + conn = async_get_connection(alias) conn_settings = _connection_settings[alias] db = conn[conn_settings["name"]] # Authenticate if necessary @@ -321,7 +316,7 @@ async def async_get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False) -> AsyncD return _dbs[alias] -async def async_connect(db=None, alias=DEFAULT_CONNECTION_NAME, **kwargs): +def async_connect(db=None, alias=DEFAULT_CONNECTION_NAME, **kwargs): """Connect to the database specified by the 'db' argument. Connection settings may be provided here as well if the database is not @@ -339,7 +334,7 @@ async def async_connect(db=None, alias=DEFAULT_CONNECTION_NAME, **kwargs): """ if alias in _connections: prev_conn_setting = _connection_settings[alias] - new_conn_settings = await _async_get_connection_settings(db, **kwargs) + new_conn_settings = _async_get_connection_settings(db, **kwargs) if new_conn_settings != prev_conn_setting: err_msg = ( "A different connection with alias `{}` was already " @@ -347,6 +342,6 @@ async def async_connect(db=None, alias=DEFAULT_CONNECTION_NAME, **kwargs): ).format(alias) raise ConnectionFailure(err_msg) else: - await async_register_connection(alias, db, **kwargs) + async_register_connection(alias, db, **kwargs) - return await async_get_connection(alias) + return async_get_connection(alias) diff --git a/mongoengine/context_managers.py b/mongoengine/context_managers.py index 590557fe3..1a306255f 100644 --- a/mongoengine/context_managers.py +++ b/mongoengine/context_managers.py @@ -538,7 +538,7 @@ def __exit__(self, exc_type, exc, tb): # Async context manager # ------------------------------------------------------------------ async def __aenter__(self): - conn = await async_get_connection(self.alias) + conn = async_get_connection(self.alias) self._async_session_cm = conn.start_session(**self.session_kwargs) self._async_session = await self._async_session_cm.__aenter__() diff --git a/mongoengine/document.py b/mongoengine/document.py index cf0fcbd63..9c903d473 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -477,7 +477,7 @@ async def _aget_timeseries_collection(cls): db, include_system_collections=True ): collection = db[collection_name] - collection.options() + await collection.options() return collection opts = {"expireAfterSeconds": timeseries_opts.pop("expireAfterSeconds", None)} diff --git a/mongoengine/fields/datetime/aware_datetime_field.py b/mongoengine/fields/datetime/aware_datetime_field.py index f5a274916..49e0bf327 100644 --- a/mongoengine/fields/datetime/aware_datetime_field.py +++ b/mongoengine/fields/datetime/aware_datetime_field.py @@ -81,6 +81,8 @@ class Event (Document): meta = {'indexes': ['start_time.tz']} # Explicit nested field """ + AVAILABLE_TIMEZONES = available_timezones() # Compute this once, CPU intensive call + def __init__(self, *args, **kwargs): """ Initialize AwareDateTimeField. @@ -188,30 +190,35 @@ def to_python(self, value): return None if isinstance(value, datetime.datetime): - # Already a datetime object return value - if isinstance(value, dict) and "utc" in value and "tz" in value: - # Stored format: {"utc": datetime, "tz": "Asia/Kolkata"} - utc_dt = value["utc"] - tz_name = value["tz"] + if not (isinstance(value, dict) and "utc" in value and "tz" in value): + return None + + utc_dt = value["utc"] + tz_name = value["tz"] - if not isinstance(utc_dt, datetime.datetime): - return None + if not isinstance(utc_dt, datetime.datetime): + return None - # Ensure UTC datetime is timezone-aware - if utc_dt.tzinfo is None: - utc_dt = utc_dt.replace(tzinfo=UTC) + try: + tz = ZoneInfo(tz_name) + except Exception: + return utc_dt.replace(tzinfo=UTC) if utc_dt.tzinfo is None else utc_dt - # Convert from UTC to original timezone + # Prefer the ISO string: it preserves microseconds and the exact UTC offset. + iso_str = value.get("iso") + if isinstance(iso_str, str): try: - tz = ZoneInfo(tz_name) - return utc_dt.astimezone(tz) - except Exception: - # If timezone is invalid, return UTC - return utc_dt + tz_aware_dt = datetime.datetime.fromisoformat(iso_str) + if tz_aware_dt.tzinfo is None: + tz_aware_dt = tz_aware_dt.replace(tzinfo=UTC) + return tz_aware_dt.astimezone(tz) + except (ValueError, TypeError): + pass # fall through to UTC path - return None + # Fallback for documents stored before the iso field was added + return utc_dt.astimezone(tz) def to_mongo(self, value): """Convert Python datetime to MongoDB storage format.""" @@ -234,25 +241,19 @@ def to_mongo(self, value): "Use datetime.now(ZoneInfo('Asia/Kolkata')) or similar." ) - # Get timezone name + # Resolve the IANA timezone name from the tzinfo object tz_name = None if hasattr(value.tzinfo, "key"): - # pytz timezone + # pytz: zone.key == "Asia/Kolkata" tz_name = value.tzinfo.key - elif hasattr(value.tzinfo, "tzname"): - # Could be ZoneInfo or other - tz_name_str = value.tzinfo.tzname(value) - # For ZoneInfo, try to get the actual zone name - if hasattr(value.tzinfo, "__str__"): - # ZoneInfo's __str__ returns the zone name - zone_str = str(value.tzinfo) - # ZoneInfo zones are in available_timezones - if zone_str in available_timezones(): - tz_name = zone_str - else: - tz_name = tz_name_str + else: + # ZoneInfo: str(zone) == "Asia/Kolkata" + zone_str = str(value.tzinfo) + if zone_str in self.AVAILABLE_TIMEZONES: + tz_name = zone_str else: - tz_name = tz_name_str + # Last resort: tzname() — covers datetime.timezone.utc → "UTC" + tz_name = value.tzinfo.tzname(value) if not tz_name: self.error( @@ -260,11 +261,10 @@ def to_mongo(self, value): "Use ZoneInfo('Asia/Kolkata') or pytz.timezone('Asia/Kolkata')" ) - # Convert to UTC for storage - utc_dt = value.astimezone(UTC) - return { - "utc": utc_dt, + "utc": value.astimezone(UTC), + # ISO string preserves the original UTC offset and microseconds exactly + "iso": value.isoformat(), "tz": tz_name, } @@ -296,6 +296,14 @@ def lookup_member(self, member_name): field.db_field = "utc" field.name = member_name return field + elif member_name == "iso": + from mongoengine.fields.string import StringField + + # Return field type for nested iso string + field = StringField() + field.db_field = "iso" + field.name = member_name + return field elif member_name == "tz": from mongoengine.fields.string import StringField diff --git a/mongoengine/mongodb_support.py b/mongoengine/mongodb_support.py index 41296872d..a5d9d4b99 100644 --- a/mongoengine/mongodb_support.py +++ b/mongoengine/mongodb_support.py @@ -43,7 +43,7 @@ async def async_get_mongodb_version(alias: str = DEFAULT_CONNECTION_NAME): cached = _VERSION_CACHE.get(alias) if cached is not None: return cached - conn = await async_get_connection(alias=alias) + conn = async_get_connection(alias=alias) version = tuple((await conn.server_info())["versionArray"][:2]) _VERSION_CACHE[alias] = version return version diff --git a/pyproject.toml b/pyproject.toml index 8a2266b76..2370860f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,7 @@ test = [ "pillow (>=12.2)", "tox (>=4.54)", "tox-uv>=1.35.2", + "mongomock>=4.3.0" ] [project.urls] diff --git a/tests/asynchronous/all_warnings/test_warnings.py b/tests/asynchronous/all_warnings/test_warnings.py index 04f32ddf6..c6b01db28 100644 --- a/tests/asynchronous/all_warnings/test_warnings.py +++ b/tests/asynchronous/all_warnings/test_warnings.py @@ -4,26 +4,23 @@ top level and called first by the test suite. """ -import unittest import warnings from mongoengine import * from mongoengine.base.common import _document_registry -from tests.asynchronous.utils import reset_async_connections -from tests.utils import MONGO_TEST_DB +from tests.asynchronous.utils import MongoDBAsyncTestCase -class TestAllWarnings(unittest.IsolatedAsyncioTestCase): +class TestAllWarnings(MongoDBAsyncTestCase): async def asyncSetUp(self): - await async_connect(db=MONGO_TEST_DB) self.warning_list = [] self.showwarning_default = warnings.showwarning warnings.showwarning = self.append_to_warning_list + await super().asyncSetUp() async def asyncTearDown(self): warnings.showwarning = self.showwarning_default - await async_disconnect_all() - await reset_async_connections() + await super().asyncTearDown() def append_to_warning_list(self, message, category, *args): self.warning_list.append({"message": message, "category": category}) diff --git a/tests/asynchronous/document/test_class_methods.py b/tests/asynchronous/document/test_class_methods.py index b27c20e5e..fd9d8a6a2 100644 --- a/tests/asynchronous/document/test_class_methods.py +++ b/tests/asynchronous/document/test_class_methods.py @@ -1,16 +1,14 @@ import unittest from mongoengine import * -from mongoengine.pymongo_support import async_list_collection_names from mongoengine.base.queryset import NULLIFY, PULL -from tests.asynchronous.utils import reset_async_connections -from tests.utils import MONGO_TEST_DB +from mongoengine.pymongo_support import async_list_collection_names +from tests.asynchronous.utils import MongoDBAsyncTestCase -class TestClassMethods(unittest.IsolatedAsyncioTestCase): +class TestClassMethods(MongoDBAsyncTestCase): async def asyncSetUp(self): - await async_connect(db=MONGO_TEST_DB) - self.db = await async_get_db() + await super().asyncSetUp() class Person(Document): name = StringField() @@ -23,10 +21,7 @@ class Person(Document): self.Person = Person async def asyncTearDown(self): - for collection in await async_list_collection_names(self.db): - self.db.drop_collection(collection) - await async_disconnect() - await reset_async_connections() + await super().asyncTearDown() def test_definition(self): """Ensure that document may be defined using fields.""" diff --git a/tests/asynchronous/document/test_delta.py b/tests/asynchronous/document/test_delta.py index f7fe2cb98..f54d88761 100644 --- a/tests/asynchronous/document/test_delta.py +++ b/tests/asynchronous/document/test_delta.py @@ -1,13 +1,9 @@ from bson import SON from mongoengine import * -from mongoengine.pymongo_support import ( - async_list_collection_names, -) from tests.asynchronous.utils import ( MongoDBAsyncTestCase, async_get_as_pymongo, - reset_async_connections, ) @@ -25,12 +21,6 @@ class Person(Document): self.Person = Person - async def asyncTearDown(self): - for collection in await async_list_collection_names(self.db): - self.db.drop_collection(collection) - await async_disconnect() - await reset_async_connections() - async def test_delta(self): await self.delta(Document) await self.delta(DynamicDocument) diff --git a/tests/asynchronous/document/test_indexes.py b/tests/asynchronous/document/test_indexes.py index ce7128deb..be48e73bf 100644 --- a/tests/asynchronous/document/test_indexes.py +++ b/tests/asynchronous/document/test_indexes.py @@ -1,38 +1,36 @@ -import unittest from datetime import datetime import pytest from pymongo.collation import Collation from mongoengine import ( + DateTimeField, + DictField, Document, - StringField, - IntField, + DynamicDocument, EmbeddedDocument, EmbeddedDocumentField, + EmbeddedDocumentListField, + IntField, ListField, SortedListField, - DictField, - DynamicDocument, - DateTimeField, - EmbeddedDocumentListField, + StringField, ) -from mongoengine.asynchronous import async_connect, async_get_db, async_disconnect_all -from mongoengine.errors import OperationError, NotUniqueError +from mongoengine.asynchronous import async_connect +from mongoengine.errors import NotUniqueError, OperationError from mongoengine.mongodb_support import ( MONGODB_42, MONGODB_80, async_get_mongodb_version, ) -from mongoengine.registry import _CollectionRegistry -from tests.asynchronous.utils import reset_async_connections +from tests.asynchronous.utils import MongoDBAsyncTestCase from tests.utils import MONGO_TEST_DB -class TestIndexes(unittest.IsolatedAsyncioTestCase): +class TestIndexes(MongoDBAsyncTestCase): async def asyncSetUp(self): - self.connection = await async_connect(db=MONGO_TEST_DB) - self.db = async_get_db() + await super().asyncSetUp() + self.connection = self._connection class Person(Document): name = StringField() @@ -45,10 +43,10 @@ class Person(Document): self.Person = Person async def asyncTearDown(self): - await self.Person.adrop_collection() - await async_disconnect_all() - await reset_async_connections() - _CollectionRegistry.clear() + try: + await self.Person.adrop_collection() + finally: + await super().asyncTearDown() async def test_indexes_document(self): """Ensure that indexes are used when meta[indexes] is specified for @@ -996,7 +994,7 @@ async def test_indexes_after_database_drop(self): # Use a new connection and database since dropping the database could # cause concurrent tests to fail. tmp_alias = "test_indexes_after_database_drop" - connection = await async_connect(db=f"{MONGO_TEST_DB}_tempdb", alias=tmp_alias) + connection = async_connect(db=f"{MONGO_TEST_DB}_tempdb", alias=tmp_alias) class BlogPost(Document): slug = StringField(unique=True) diff --git a/tests/asynchronous/document/test_inheritance.py b/tests/asynchronous/document/test_inheritance.py index 9c39f6a5d..b19a3d61e 100644 --- a/tests/asynchronous/document/test_inheritance.py +++ b/tests/asynchronous/document/test_inheritance.py @@ -13,8 +13,8 @@ StringField, ) from mongoengine.pymongo_support import async_list_collection_names -from tests.fixtures import Base from tests.asynchronous.utils import MongoDBAsyncTestCase +from tests.fixtures import Base class TestInheritance(MongoDBAsyncTestCase): diff --git a/tests/asynchronous/document/test_instance.py b/tests/asynchronous/document/test_instance.py index 96cb001fe..5cf340605 100644 --- a/tests/asynchronous/document/test_instance.py +++ b/tests/asynchronous/document/test_instance.py @@ -22,15 +22,16 @@ from mongoengine import * from mongoengine import signals from mongoengine.asynchronous import ( - async_get_db, async_disconnect, + async_get_db, async_register_connection, ) from mongoengine.base import _DocumentRegistry +from mongoengine.base.queryset import CASCADE, DENY, NULLIFY, PULL, Q from mongoengine.context_managers import ( - switch_db, async_query_counter, switch_collection, + switch_db, ) from mongoengine.errors import ( FieldDoesNotExist, @@ -43,15 +44,8 @@ from mongoengine.pymongo_support import ( async_list_collection_names, ) -from mongoengine.base.queryset import NULLIFY, Q, CASCADE, PULL, DENY from mongoengine.registry import _CollectionRegistry from tests import fixtures -from tests.fixtures import ( - PickleDynamicEmbedded, - PickleDynamicTest, - PickleEmbedded, - PickleTest, -) from tests.asynchronous.fixtures import PickleSignalsTest from tests.asynchronous.utils import ( MongoDBAsyncTestCase, @@ -60,6 +54,12 @@ requires_mongodb_gte_44, reset_async_connections, ) +from tests.fixtures import ( + PickleDynamicEmbedded, + PickleDynamicTest, + PickleEmbedded, + PickleTest, +) from tests.utils import MONGO_TEST_DB TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), "../fields/mongoengine.png") @@ -87,7 +87,7 @@ class Person(Document): async def asyncTearDown(self): for collection in await async_list_collection_names(self.db): - self.db.drop_collection(collection) + await self.db.drop_collection(collection) await super().asyncTearDown() await reset_async_connections() _CollectionRegistry.clear() @@ -2945,9 +2945,9 @@ async def test_db_alias_tests(self): """DB Alias tests.""" # mongoenginetest - Is default connection alias from setUp() # Register Aliases - await async_register_connection("testdb-1", f"{MONGO_TEST_DB}_2") - await async_register_connection("testdb-2", f"{MONGO_TEST_DB}_3") - await async_register_connection("testdb-3", f"{MONGO_TEST_DB}_4") + async_register_connection("testdb-1", f"{MONGO_TEST_DB}_2") + async_register_connection("testdb-2", f"{MONGO_TEST_DB}_3") + async_register_connection("testdb-3", f"{MONGO_TEST_DB}_4") class User(Document): name = StringField() @@ -3017,7 +3017,7 @@ class AuthorBooks(Document): async def test_db_alias_overrides(self): """Test db_alias can be overriden.""" # Register a connection with db_alias testdb-2 - await async_register_connection("testdb-2", f"{MONGO_TEST_DB}_2") + async_register_connection("testdb-2", f"{MONGO_TEST_DB}_2") class A(Document): """Uses default db_alias""" @@ -3039,7 +3039,7 @@ class B(A): async def test_db_alias_propagates(self): """db_alias propagates?""" - await async_register_connection("testdb-1", f"{MONGO_TEST_DB}_2") + async_register_connection("testdb-1", f"{MONGO_TEST_DB}_2") class A(Document): name = StringField() @@ -3137,7 +3137,7 @@ def __str__(self): assert [str(b) async for b in custom_qs] == ["1", "2"] async def test_switch_db_instance(self): - await async_register_connection("testdb-1", f"{MONGO_TEST_DB}_2") + async_register_connection("testdb-1", f"{MONGO_TEST_DB}_2") class Group(Document): name = StringField() @@ -3186,8 +3186,8 @@ class Group(Document): assert "hello - default" == group.name async def test_switch_db_multiple_documents_same_context(self): - await async_register_connection("testdb-1", f"{MONGO_TEST_DB}_2") - await async_register_connection("testdb-2", f"{MONGO_TEST_DB}_3") + async_register_connection("testdb-1", f"{MONGO_TEST_DB}_2") + async_register_connection("testdb-2", f"{MONGO_TEST_DB}_3") class Group(Document): name = StringField() @@ -3243,7 +3243,7 @@ class Post(Document): assert p2.title == "post-testdb-2" async def test_switch_db_and_switch_collection_instance(self): - await async_register_connection("testdb-1", f"{MONGO_TEST_DB}_2") + async_register_connection("testdb-1", f"{MONGO_TEST_DB}_2") class Group(Document): name = StringField() @@ -3306,8 +3306,8 @@ class Group(Document): assert "hello - default" == g0.name async def test_switch_multiple_db_and_multiple_collection_same_time(self): - await async_register_connection("tenantA", f"{MONGO_TEST_DB}_2") - await async_register_connection("tenantB", f"{MONGO_TEST_DB}_2") + async_register_connection("tenantA", f"{MONGO_TEST_DB}_2") + async_register_connection("tenantB", f"{MONGO_TEST_DB}_2") class User(Document): name = StringField() @@ -3789,7 +3789,7 @@ class Test(Document): async def test_default_values_dont_get_override_upon_save_when_only_is_used(self): class Person(Document): - created_on = DateTimeField(default=lambda: datetime.utcnow()) + created_on = DateTimeField(default=lambda: datetime.now(UTC)) name = StringField() p = Person(name="alon") @@ -3805,7 +3805,7 @@ class Person(Document): assert orig_created_on == p3.created_on class Person(Document): - created_on = DateTimeField(default=lambda: datetime.utcnow()) + created_on = DateTimeField(default=lambda: datetime.now(UTC)) name = StringField() height = IntField(default=189) diff --git a/tests/asynchronous/document/test_timeseries_collection.py b/tests/asynchronous/document/test_timeseries_collection.py index 281a0757e..4f8d9df14 100644 --- a/tests/asynchronous/document/test_timeseries_collection.py +++ b/tests/asynchronous/document/test_timeseries_collection.py @@ -1,7 +1,5 @@ import asyncio -import unittest from datetime import datetime, timedelta -from tests.utils import MONGO_TEST_DB try: # Python 3.11+ @@ -18,15 +16,12 @@ FloatField, StringField, ) -from mongoengine.asynchronous import async_connect, async_get_db, async_disconnect -from mongoengine.registry import _CollectionRegistry -from tests.asynchronous.utils import requires_mongodb_gte_50 +from tests.asynchronous.utils import MongoDBAsyncTestCase, requires_mongodb_gte_50 -class TestTimeSeriesCollections(unittest.IsolatedAsyncioTestCase): +class TestTimeSeriesCollections(MongoDBAsyncTestCase): async def asyncSetUp(self): - await async_connect(db=MONGO_TEST_DB) - self.db = await async_get_db() + await super().asyncSetUp() class SensorData(Document): timestamp = DateTimeField(required=True) @@ -45,20 +40,18 @@ class SensorData(Document): self.SensorData = SensorData async def asyncTearDown(self): - await super().asyncTearDown() - _CollectionRegistry.clear() + try: + for collection_name in await self.db.list_collection_names(): + if not collection_name.startswith("system."): + await self.db.drop_collection(collection_name) + finally: + await super().asyncTearDown() async def test_get_db(self): """Ensure that get_db returns the expected db.""" db = await self.SensorData._async_get_db() assert self.db == db - async def asyncTearDown(self): - for collection_name in await self.db.list_collection_names(): - if not collection_name.startswith("system."): - await self.db.drop_collection(collection_name) - await async_disconnect() - async def test_definition(self): """Ensure that document may be defined using fields.""" assert ["id", "temperature", "timestamp"] == sorted( diff --git a/tests/asynchronous/fields/test_aware_datetime_field.py b/tests/asynchronous/fields/test_aware_datetime_field.py index 066b4a392..0905e037e 100644 --- a/tests/asynchronous/fields/test_aware_datetime_field.py +++ b/tests/asynchronous/fields/test_aware_datetime_field.py @@ -30,9 +30,7 @@ class Event(Document): await Event.adrop_collection() # Create event with Asia/Kolkata timezone - kolkata_time = datetime.datetime( - 2024, 6, 15, 14, 30, tzinfo=ZoneInfo("Asia/Kolkata") - ) + kolkata_time = datetime.datetime.now().astimezone(ZoneInfo("Asia/Kolkata")) event = Event(start_time=kolkata_time) await event.asave() @@ -40,6 +38,7 @@ class Event(Document): raw = await async_get_as_pymongo(event) assert "start_time" in raw assert "utc" in raw["start_time"] + assert "iso" in raw["start_time"] assert "tz" in raw["start_time"] assert raw["start_time"]["tz"] == "Asia/Kolkata" @@ -59,24 +58,11 @@ class Event(Document): await Event.adrop_collection() # Create events in different timezones + now = datetime.datetime.now(UTC) timezones = [ - ( - "Asia/Kolkata", - datetime.datetime(2024, 6, 15, 14, 30, tzinfo=ZoneInfo("Asia/Kolkata")), - ), - ( - "America/New_York", - datetime.datetime( - 2024, 6, 15, 9, 0, tzinfo=ZoneInfo("America/New_York") - ), - ), - ( - "Europe/London", - datetime.datetime(2024, 6, 15, 15, 0, tzinfo=ZoneInfo("Europe/London")), - ), - ("UTC", datetime.datetime(2024, 6, 15, 12, 0, tzinfo=UTC)), + (tz, now.astimezone(ZoneInfo(tz))) + for tz in ["Asia/Kolkata", "America/New_York", "Europe/London", "UTC"] ] - for tz_name, dt in timezones: await Event(name=tz_name, start_time=dt).asave() @@ -99,7 +85,7 @@ class Event(Document): winter = Event( name="Winter", start_time=datetime.datetime( - 2024, 1, 15, 10, 0, tzinfo=ZoneInfo("America/New_York") + 2024, 1, 15, 10, 0, 30, 500000, tzinfo=ZoneInfo("America/New_York") ), ) await winter.asave() @@ -108,7 +94,7 @@ class Event(Document): summer = Event( name="Summer", start_time=datetime.datetime( - 2024, 7, 15, 10, 0, tzinfo=ZoneInfo("America/New_York") + 2024, 7, 15, 10, 0, 30, 500000, tzinfo=ZoneInfo("America/New_York") ), ) await summer.asave() @@ -140,19 +126,19 @@ class Event(Document): await Event( name="Early", start_time=datetime.datetime( - 2024, 6, 15, 8, 0, tzinfo=ZoneInfo("Asia/Kolkata") + 2024, 6, 15, 8, 0, 30, 500000, tzinfo=ZoneInfo("Asia/Kolkata") ), ).asave() # Late: 18:00 Asia/Kolkata (UTC+5:30) = 12:30 UTC await Event( name="Late", start_time=datetime.datetime( - 2024, 6, 15, 18, 0, tzinfo=ZoneInfo("Asia/Kolkata") + 2024, 6, 15, 18, 0, 30, 500000, tzinfo=ZoneInfo("Asia/Kolkata") ), ).asave() # Query by UTC time - should find only the Late event - utc_noon = datetime.datetime(2024, 6, 15, 12, 0, tzinfo=UTC) + utc_noon = datetime.datetime(2024, 6, 15, 12, 0, 30, 500000, tzinfo=UTC) events_after_noon = Event.aobjects(start_time__utc__gte=utc_noon) assert await events_after_noon.count() == 1 @@ -170,12 +156,12 @@ class Event(Document): # Create events in different timezones await Event( start_time=datetime.datetime( - 2024, 6, 15, 14, 30, tzinfo=ZoneInfo("Asia/Kolkata") + 2024, 6, 15, 14, 30, 30, 500000, tzinfo=ZoneInfo("Asia/Kolkata") ) ).asave() await Event( start_time=datetime.datetime( - 2024, 6, 15, 9, 0, tzinfo=ZoneInfo("America/New_York") + 2024, 6, 15, 9, 0, 30, 500000, tzinfo=ZoneInfo("America/New_York") ) ).asave() @@ -198,13 +184,13 @@ class Event(Document): await Event( name="First", start_time=datetime.datetime( - 2024, 6, 15, 10, 0, tzinfo=ZoneInfo("Asia/Kolkata") + 2024, 6, 15, 10, 0, 30, 500000, tzinfo=ZoneInfo("Asia/Kolkata") ), # 04:30 UTC ).asave() await Event( name="Second", start_time=datetime.datetime( - 2024, 6, 15, 9, 0, tzinfo=ZoneInfo("America/New_York") + 2024, 6, 15, 9, 0, 30, 500000, tzinfo=ZoneInfo("America/New_York") ), # 13:00 UTC ).asave() @@ -247,7 +233,7 @@ class Event(Document): await Event.adrop_collection() # Naive datetime should raise validation error - naive_dt = datetime.datetime(2024, 6, 15, 14, 30) + naive_dt = datetime.datetime(2024, 6, 15, 14, 30, 30, 500000) event = Event(start_time=naive_dt) with pytest.raises(ValidationError): @@ -283,7 +269,7 @@ async def test_default_value(self): """Test default values work correctly.""" def get_default_time(): - return datetime.datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC) + return datetime.datetime(2024, 1, 1, 0, 0, 30, 500000, tzinfo=UTC) class Event(Document): start_time = AwareDateTimeField(default=get_default_time) @@ -307,7 +293,7 @@ class Event(Document): # Create event in Kolkata timezone kolkata_time = datetime.datetime( - 2024, 6, 15, 14, 30, tzinfo=ZoneInfo("Asia/Kolkata") + 2024, 6, 15, 14, 30, 30, 500000, tzinfo=ZoneInfo("Asia/Kolkata") ) event = Event(start_time=kolkata_time) await event.asave() @@ -371,3 +357,100 @@ class Event(Document): assert desc_idx is not None assert desc_idx["key"][0] == ("start_time.utc", -1) + + async def test_iso_field_stored_in_mongodb(self): + """Test that the iso field is stored alongside utc and tz.""" + + class Event(Document): + start_time = AwareDateTimeField(required=True) + + await Event.adrop_collection() + + dt = datetime.datetime( + 2024, 6, 15, 14, 30, 30, 500000, tzinfo=ZoneInfo("Asia/Kolkata") + ) + event = await Event(start_time=dt).asave() + + raw = await async_get_as_pymongo(event) + assert "iso" in raw["start_time"] + assert isinstance(raw["start_time"]["iso"], str) + # ISO string must round-trip back to the original datetime + assert datetime.datetime.fromisoformat(raw["start_time"]["iso"]) == dt + + async def test_microsecond_precision_preserved_via_iso(self): + """Test that microseconds survive the MongoDB round-trip via the iso field.""" + + class Event(Document): + start_time = AwareDateTimeField(required=True) + + await Event.adrop_collection() + + dt = datetime.datetime( + 2024, 3, 10, 8, 45, 17, 987654, tzinfo=ZoneInfo("Europe/London") + ) + await Event(start_time=dt).asave() + + retrieved = await Event.aobjects.first() + assert retrieved.start_time.second == 17 + assert retrieved.start_time.microsecond == 987654 + assert retrieved.start_time == dt + + async def test_iso_field_queryable(self): + """Test that start_time__iso can be used in queries.""" + + class Event(Document): + name = StringField() + start_time = AwareDateTimeField(required=True) + + await Event.adrop_collection() + + dt = datetime.datetime( + 2024, 6, 15, 14, 30, 30, 500000, tzinfo=ZoneInfo("Asia/Kolkata") + ) + await Event(name="kolkata", start_time=dt).asave() + + iso_str = dt.isoformat() + result = await Event.aobjects(start_time__iso=iso_str).first() + assert result is not None + assert result.name == "kolkata" + + async def test_iso_field_contains_timezone_offset(self): + """Test that the stored iso string includes the UTC offset.""" + + class Event(Document): + start_time = AwareDateTimeField(required=True) + + await Event.adrop_collection() + + # Kolkata is UTC+5:30 + dt = datetime.datetime( + 2024, 6, 15, 14, 30, 30, 500000, tzinfo=ZoneInfo("Asia/Kolkata") + ) + event = await Event(start_time=dt).asave() + + raw = await async_get_as_pymongo(event) + iso_str = raw["start_time"]["iso"] + # ISO string must encode the +05:30 offset + assert "+05:30" in iso_str or "05:30" in iso_str + + async def test_half_hour_offset_precision(self): + """Test that UTC+5:30 (Kolkata) microseconds convert correctly to/from UTC.""" + + class Event(Document): + start_time = AwareDateTimeField(required=True) + + await Event.adrop_collection() + + # 14:30:30.500000 Kolkata = 09:00:30.500000 UTC + dt = datetime.datetime( + 2024, 6, 15, 14, 30, 30, 500000, tzinfo=ZoneInfo("Asia/Kolkata") + ) + await Event(start_time=dt).asave() + + retrieved = await Event.aobjects.first() + utc_dt = retrieved.start_time.astimezone(UTC) + + assert utc_dt.hour == 9 + assert utc_dt.minute == 0 + assert utc_dt.second == 30 + assert utc_dt.microsecond == 500000 diff --git a/tests/asynchronous/fields/test_datetime_field.py b/tests/asynchronous/fields/test_datetime_field.py index 287e8b391..f7c56308c 100644 --- a/tests/asynchronous/fields/test_datetime_field.py +++ b/tests/asynchronous/fields/test_datetime_field.py @@ -1,10 +1,15 @@ import datetime as dt +import unittest import pytest from mongoengine import * -from mongoengine.asynchronous import async_connect, connection -from tests.asynchronous.utils import MongoDBAsyncTestCase, async_get_as_pymongo +from mongoengine.asynchronous import async_connect +from tests.asynchronous.utils import ( + MongoDBAsyncTestCase, + async_get_as_pymongo, + reset_async_connections, +) from tests.utils import MONGO_TEST_DB try: @@ -239,14 +244,15 @@ class DTDoc(Document): dtd.validate() -class TestDateTimeTzAware(MongoDBAsyncTestCase): - async def test_datetime_tz_aware_mark_as_changed(self): - # Reset the connections - connection._connection_settings = {} - connection._connections = {} - connection._dbs = {} +class TestDateTimeTzAware(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + await reset_async_connections() + + async def asyncTearDown(self): + await reset_async_connections() - await async_connect(db=MONGO_TEST_DB, tz_aware=True) + async def test_datetime_tz_aware_mark_as_changed(self): + async_connect(db=MONGO_TEST_DB, tz_aware=True) class LogEntry(Document): time = DateTimeField() diff --git a/tests/asynchronous/fields/test_fields.py b/tests/asynchronous/fields/test_fields.py index b4d28c6a0..b7d850dc6 100644 --- a/tests/asynchronous/fields/test_fields.py +++ b/tests/asynchronous/fields/test_fields.py @@ -95,7 +95,7 @@ class Person(Document): name = StringField() age = IntField(default=30, required=False) userid = StringField(default=lambda: "test", required=True) - created = DateTimeField(default=datetime.datetime.utcnow) + created = DateTimeField(default=lambda: datetime.datetime.now(UTC)) day = DateField(default=datetime.date.today) person = Person(name="Ross") @@ -175,7 +175,7 @@ class Person(Document): name = StringField() age = IntField(default=30, required=False) userid = StringField(default=lambda: "test", required=True) - created = DateTimeField(default=datetime.datetime.utcnow) + created = DateTimeField(default=lambda: datetime.datetime.now(UTC)) # Trying setting values to None person = Person(name=None, age=None, userid=None, created=None) @@ -277,7 +277,7 @@ class Person(Document): name = StringField() age = IntField(default=30, required=False) userid = StringField(default=lambda: "test", required=True) - created = DateTimeField(default=datetime.datetime.utcnow) + created = DateTimeField(default=lambda: datetime.datetime.now(UTC)) person = Person( name="Ross", @@ -342,7 +342,7 @@ class HandleNoneFields(Document): doc.str_fld = "spam ham egg" doc.int_fld = 42 doc.flt_fld = 4.2 - doc.com_dt_fld = datetime.datetime.utcnow() + doc.com_dt_fld = datetime.datetime.now(UTC) await doc.asave() res = await HandleNoneFields.aobjects(id=doc.id).update( diff --git a/tests/asynchronous/fields/test_file_field.py b/tests/asynchronous/fields/test_file_field.py index 8748dd70a..f86ba2c3c 100644 --- a/tests/asynchronous/fields/test_file_field.py +++ b/tests/asynchronous/fields/test_file_field.py @@ -7,7 +7,7 @@ import pytest from mongoengine import * -from mongoengine.asynchronous import async_register_connection, async_get_db +from mongoengine.asynchronous import async_get_db, async_register_connection from mongoengine.base.queryset import Q try: @@ -20,7 +20,6 @@ from tests.asynchronous.utils import MongoDBAsyncTestCase from tests.utils import MONGO_TEST_DB - require_pil = pytest.mark.skipif(not HAS_PIL, reason="PIL not installed") TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), "mongoengine.png") @@ -517,7 +516,7 @@ class TestImage(Document): await t.image.adelete() async def test_file_multidb(self): - await async_register_connection("test_files", f"{MONGO_TEST_DB}_test_files") + async_register_connection("test_files", f"{MONGO_TEST_DB}_test_files") class TestFile(Document): name = StringField() @@ -542,11 +541,11 @@ class TestFile(Document): assert await test_file.the_file.aread() == b"Hello, World!" test_file = await TestFile.aobjects.first() - test_file.the_file.aput(b"Hello, World!") + await test_file.the_file.areplace(b"Hello, World 2!") await test_file.asave() test_file = await TestFile.aobjects.first() - assert await test_file.the_file.aread() == b"Hello, World!" + assert await test_file.the_file.aread() == b"Hello, World 2!" async def test_copyable(self): class PutFile(Document): diff --git a/tests/asynchronous/fields/test_sequence_field.py b/tests/asynchronous/fields/test_sequence_field.py index ecad72e8d..48d807818 100644 --- a/tests/asynchronous/fields/test_sequence_field.py +++ b/tests/asynchronous/fields/test_sequence_field.py @@ -239,10 +239,9 @@ class Bar(Base): assert "base.counter" in await self.db["mongoengine.counters"].find().distinct( "_id" ) - assert not ( - ("foo.counter" or "bar.counter") - in await self.db["mongoengine.counters"].find().distinct("_id") - ) + assert ("foo.counter" or "bar.counter") not in await self.db[ + "mongoengine.counters" + ].find().distinct("_id") assert foo.counter != bar.counter assert foo._fields["counter"].owner_document == Base assert bar._fields["counter"].owner_document == Base diff --git a/tests/asynchronous/queryset/test_field_list.py b/tests/asynchronous/queryset/test_field_list.py index c28288bd9..11c62e92c 100644 --- a/tests/asynchronous/queryset/test_field_list.py +++ b/tests/asynchronous/queryset/test_field_list.py @@ -1,17 +1,12 @@ -import unittest - import pytest from mongoengine import * -from mongoengine.asynchronous import async_connect, async_disconnect -from mongoengine.registry import _CollectionRegistry -from tests.asynchronous.utils import reset_async_connections -from tests.utils import MONGO_TEST_DB +from tests.asynchronous.utils import MongoDBAsyncTestCase -class TestOnlyExcludeAll(unittest.IsolatedAsyncioTestCase): +class TestOnlyExcludeAll(MongoDBAsyncTestCase): async def asyncSetUp(self): - await async_connect(db=MONGO_TEST_DB) + await super().asyncSetUp() class Person(Document): name = StringField() @@ -22,9 +17,7 @@ class Person(Document): self.Person = Person async def asyncTearDown(self): - await async_disconnect() - await reset_async_connections() - _CollectionRegistry.clear() + await super().asyncTearDown() def test_mixing_only_exclude(self): class MyDoc(Document): diff --git a/tests/asynchronous/queryset/test_geo.py b/tests/asynchronous/queryset/test_geo.py index f25827651..a04c82807 100644 --- a/tests/asynchronous/queryset/test_geo.py +++ b/tests/asynchronous/queryset/test_geo.py @@ -384,7 +384,7 @@ class Road(Document): name = StringField() line = LineStringField() - Road.adrop_collection() + await Road.adrop_collection() road = Road(name="66", line=[[40, 5], [41, 6]]) await road.asave() diff --git a/tests/asynchronous/queryset/test_modify.py b/tests/asynchronous/queryset/test_modify.py index ba62e7688..88b79a58a 100644 --- a/tests/asynchronous/queryset/test_modify.py +++ b/tests/asynchronous/queryset/test_modify.py @@ -6,10 +6,7 @@ ListField, StringField, ) -from mongoengine.asynchronous import async_connect, async_disconnect -from mongoengine.registry import _CollectionRegistry -from tests.asynchronous.utils import reset_async_connections -from tests.utils import MONGO_TEST_DB +from tests.asynchronous.utils import MongoDBAsyncTestCase class Doc(Document): @@ -17,15 +14,13 @@ class Doc(Document): value = IntField() -class TestOnlyExcludeAll(unittest.IsolatedAsyncioTestCase): +class TestOnlyExcludeAll(MongoDBAsyncTestCase): async def asyncSetUp(self): - await async_connect(db=MONGO_TEST_DB) + await super().asyncSetUp() await Doc.adrop_collection() async def asyncTearDown(self): - await async_disconnect() - await reset_async_connections() - _CollectionRegistry.clear() + await super().asyncTearDown() async def _assert_db_equal(self, docs): assert await (await Doc._aget_collection()).find().sort("id").to_list() == docs diff --git a/tests/asynchronous/queryset/test_pickable.py b/tests/asynchronous/queryset/test_pickable.py index 3c28d878a..7e646d291 100644 --- a/tests/asynchronous/queryset/test_pickable.py +++ b/tests/asynchronous/queryset/test_pickable.py @@ -1,9 +1,7 @@ import pickle from mongoengine import Document, IntField, StringField -from mongoengine.asynchronous import async_disconnect -from mongoengine.registry import _CollectionRegistry -from tests.asynchronous.utils import MongoDBAsyncTestCase, reset_async_connections +from tests.asynchronous.utils import MongoDBAsyncTestCase class Person(Document): @@ -21,12 +19,6 @@ async def asyncSetUp(self): await super().asyncSetUp() self.john = await Person.aobjects.create(name="John", age=21) - async def asyncTearDown(self): - await Person.adrop_collection() - await async_disconnect() - await reset_async_connections() - _CollectionRegistry.clear() - async def test_picke_simple_qs(self): qs = Person.aobjects.all() pickle.dumps(qs) diff --git a/tests/asynchronous/queryset/test_queryset.py b/tests/asynchronous/queryset/test_queryset.py index 972273c54..680725477 100644 --- a/tests/asynchronous/queryset/test_queryset.py +++ b/tests/asynchronous/queryset/test_queryset.py @@ -12,20 +12,20 @@ from mongoengine import * from mongoengine.base import LazyReference +from mongoengine.base.queryset import ( + CASCADE, + DENY, + NULLIFY, + PULL, + QuerySetManager, + queryset_manager, +) from mongoengine.context_managers import async_query_counter, switch_db from mongoengine.errors import InvalidQueryError from mongoengine.mongodb_support import ( async_get_mongodb_version, ) from mongoengine.pymongo_support import PYMONGO_VERSION -from mongoengine.base.queryset import ( - QuerySetManager, - queryset_manager, - CASCADE, - NULLIFY, - DENY, - PULL, -) from mongoengine.registry import _CollectionRegistry from tests.asynchronous.utils import ( async_db_ops_tracker, @@ -52,8 +52,10 @@ def get_key_compat(mongo_ver): class TestQueryset(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): - await async_connect(db=MONGO_TEST_DB) - await async_connect(db=f"{MONGO_TEST_DB}_2", alias="test2") + await reset_async_connections() + _CollectionRegistry.clear() + async_connect(db=MONGO_TEST_DB, alias="default") + async_connect(db=f"{MONGO_TEST_DB}_2", alias="test2") class PersonMeta(EmbeddedDocument): weight = IntField() @@ -72,10 +74,13 @@ class Person(Document): self.mongodb_version = await async_get_mongodb_version() async def asyncTearDown(self): - await async_disconnect(alias="default") - await async_disconnect(alias="test2") - await reset_async_connections() - _CollectionRegistry.clear() + await super().asyncTearDown() + try: + await async_disconnect(alias="default") + await async_disconnect(alias="test2") + finally: + await reset_async_connections() + _CollectionRegistry.clear() async def test_initialisation(self): """Ensure that a QuerySet is correctly initialised by AsyncQuerySetManager.""" @@ -229,7 +234,8 @@ async def test_skip(self): async def test___getitem___invalid_index(self): """Ensure slicing a queryset works as expected.""" with pytest.raises(TypeError): - await self.Person.aobjects().to_list()["a"] + results = await self.Person.aobjects().to_list() + assert results["a"] async def test_find_one(self): """Ensure that a query using find_one returns a valid result.""" @@ -1667,7 +1673,7 @@ class Dummy(Document): other = await Dummy(reference=base).asave() other2 = await Dummy(reference=other).asave() base.reference = other - base.asave() + await base.asave() await cat.adelete() @@ -2994,7 +3000,7 @@ class Link(Document): await Link.adrop_collection() - now = datetime.datetime.utcnow() + now = datetime.datetime.now(UTC) # Note: Test data taken from a custom Reddit homepage on # Fri, 12 Feb 2010 14:36:00 -0600. Link ordering should @@ -3616,7 +3622,7 @@ class News(Document): assert await qs1.to_list() == await qs2.to_list() async def test_distinct_handles_references_to_alias(self): - await async_register_connection("testdb", f"{MONGO_TEST_DB}_2") + async_register_connection("testdb", f"{MONGO_TEST_DB}_2") class Bar(Document): text = StringField() @@ -3945,9 +3951,6 @@ class BlogPost(Document): await BlogPost.adrop_collection() - async def tearDown(self): - await self.Person.adrop_collection() - async def test_custom_querysets(self): """Ensure that custom QuerySet classes may be used.""" @@ -3964,7 +3967,7 @@ class Post(Document): assert not await Post.aobjects.not_empty() await Post().asave() - assert Post.aobjects.not_empty() + assert await Post.aobjects.not_empty() await Post.adrop_collection() @@ -5579,8 +5582,8 @@ class Test(Document): if not test: raise AssertionError("Cursor has data and returned False") - anext(queryset) - if not queryset.exists(): + await anext(queryset) + if not await queryset.exists(): raise AssertionError( "Cursor has data and it must returns True, even in the last item." ) diff --git a/tests/asynchronous/queryset/test_queryset_aggregation.py b/tests/asynchronous/queryset/test_queryset_aggregation.py index 623a5d3be..1c5b89b28 100644 --- a/tests/asynchronous/queryset/test_queryset_aggregation.py +++ b/tests/asynchronous/queryset/test_queryset_aggregation.py @@ -3,7 +3,7 @@ from mongoengine import Document, IntField, PointField, StringField from mongoengine.mongodb_support import async_get_mongodb_version -from tests.asynchronous.utils import async_db_ops_tracker, MongoDBAsyncTestCase +from tests.asynchronous.utils import MongoDBAsyncTestCase, async_db_ops_tracker from tests.utils import MONGO_TEST_DB diff --git a/tests/asynchronous/queryset/test_transform.py b/tests/asynchronous/queryset/test_transform.py index fb4eed4ad..1ffe7822c 100644 --- a/tests/asynchronous/queryset/test_transform.py +++ b/tests/asynchronous/queryset/test_transform.py @@ -2,8 +2,8 @@ from bson.son import SON from mongoengine import * -from mongoengine.common import _async_queryset_to_values from mongoengine.base.queryset import Q, transform +from mongoengine.common import _async_queryset_to_values from tests.asynchronous.utils import MongoDBAsyncTestCase diff --git a/tests/asynchronous/queryset/test_visitor.py b/tests/asynchronous/queryset/test_visitor.py index afea97392..fd737ca8c 100644 --- a/tests/asynchronous/queryset/test_visitor.py +++ b/tests/asynchronous/queryset/test_visitor.py @@ -6,17 +6,15 @@ from bson import ObjectId from mongoengine import * +from mongoengine.base.queryset import Q from mongoengine.common import _async_queryset_to_values from mongoengine.errors import InvalidQueryError -from mongoengine.base.queryset import Q -from mongoengine.registry import _CollectionRegistry -from tests.asynchronous.utils import reset_async_connections -from tests.utils import MONGO_TEST_DB +from tests.asynchronous.utils import MongoDBAsyncTestCase -class TestQ(unittest.IsolatedAsyncioTestCase): +class TestQ(MongoDBAsyncTestCase): async def asyncSetUp(self): - await async_connect(db=MONGO_TEST_DB) + await super().asyncSetUp() class Person(Document): name = StringField() @@ -27,9 +25,7 @@ class Person(Document): self.Person = Person async def asyncTearDown(self): - await async_disconnect() - await reset_async_connections() - _CollectionRegistry.clear() + await super().asyncTearDown() async def test_empty_q(self): """Ensure that empty Q objects won't hurt.""" diff --git a/tests/asynchronous/test_connection.py b/tests/asynchronous/test_connection.py index 2d762464f..3c7c875e0 100644 --- a/tests/asynchronous/test_connection.py +++ b/tests/asynchronous/test_connection.py @@ -2,25 +2,25 @@ import unittest import uuid -import pymongo -import pymongo.database -import pymongo.mongo_client -import pytest from bson import UuidRepresentation from bson.tz_util import utc -from pymongo import ReadPreference, AsyncMongoClient +import pymongo +from pymongo import AsyncMongoClient, ReadPreference from pymongo.asynchronous.database import AsyncDatabase +import pymongo.database from pymongo.errors import ( + ConnectionFailure, InvalidName, InvalidOperation, OperationFailure, ) +import pymongo.mongo_client +import pytest from mongoengine import ( DateTimeField, StringField, ) -from pymongo.errors import ConnectionFailure from mongoengine.asynchronous import ( async_connect, async_disconnect, @@ -46,42 +46,44 @@ def get_tz_awareness(connection_): return connection_.codec_options.tz_aware -class AsyncConnectionTest(unittest.IsolatedAsyncioTestCase): +class ConnectionTest(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): - await async_disconnect_all() + await reset_async_connections() async def asyncTearDown(self): - await async_disconnect_all() - await reset_async_connections() - _DocumentRegistry.clear() - _CollectionRegistry.clear() + try: + await async_disconnect_all() + finally: + await reset_async_connections() + _DocumentRegistry.clear() + _CollectionRegistry.clear() @pytest.mark.asyncio async def test_async_connect(self): """Ensure that the connect() method works properly.""" - await async_connect(MONGO_TEST_DB) + async_connect(MONGO_TEST_DB) - conn = await async_get_connection() + conn = async_get_connection() assert isinstance(conn, pymongo.AsyncMongoClient) db = await async_get_db() assert isinstance(db, AsyncDatabase) assert db.name == MONGO_TEST_DB - await async_connect(f"{MONGO_TEST_DB}_2", alias="testdb") - conn = await async_get_connection("testdb") + async_connect(f"{MONGO_TEST_DB}_2", alias="testdb") + conn = async_get_connection("testdb") assert isinstance(conn, pymongo.AsyncMongoClient) - await async_connect( + async_connect( f"{MONGO_TEST_DB}_2", alias="testdb3", mongo_client_class=pymongo.AsyncMongoClient, ) - conn = await async_get_connection("testdb") + conn = async_get_connection("testdb") assert isinstance(conn, pymongo.AsyncMongoClient) @pytest.mark.asyncio - async def test_async_connect_disconnect_works_properly(self): + async def test_connect_disconnect_works_properly(self): class History1(Document): name = StringField() meta = {"db_alias": "db1"} @@ -90,8 +92,8 @@ class History2(Document): name = StringField() meta = {"db_alias": "db2"} - await async_connect(f"{MONGO_TEST_DB}_db1", alias="db1") - await async_connect(f"{MONGO_TEST_DB}_db2", alias="db2") + async_connect(f"{MONGO_TEST_DB}_db1", alias="db1") + async_connect(f"{MONGO_TEST_DB}_db2", alias="db2") await History1.adrop_collection() await History2.adrop_collection() @@ -115,8 +117,8 @@ class History2(Document): with pytest.raises(ConnectionFailure): await History2.aobjects().as_pymongo().to_list() - await async_connect(f"{MONGO_TEST_DB}_db1", alias="db1") - await async_connect(f"{MONGO_TEST_DB}_db2", alias="db2") + async_connect(f"{MONGO_TEST_DB}_db1", alias="db1") + async_connect(f"{MONGO_TEST_DB}_db2", alias="db2") assert await History1.aobjects().as_pymongo().to_list() == [ {"_id": h.id, "name": "default"} @@ -126,7 +128,7 @@ class History2(Document): ] @pytest.mark.asyncio - async def test_async_connect_different_documents_to_different_database(self): + async def test_connect_different_documents_to_different_database(self): class History(Document): name = StringField() @@ -138,9 +140,9 @@ class History2(Document): name = StringField() meta = {"db_alias": "db2"} - await async_connect(MONGO_TEST_DB) - await async_connect(f"{MONGO_TEST_DB}_db1", alias="db1") - await async_connect(f"{MONGO_TEST_DB}_db2", alias="db2") + async_connect(MONGO_TEST_DB) + async_connect(f"{MONGO_TEST_DB}_db1", alias="db1") + async_connect(f"{MONGO_TEST_DB}_db2", alias="db2") await History.adrop_collection() await History1.adrop_collection() @@ -169,22 +171,22 @@ class History2(Document): ] @pytest.mark.asyncio - async def test_async_connect_fails_if_connect_2_times_with_default_alias(self): - await async_connect(MONGO_TEST_DB) + async def test_connect_fails_if_connect_2_times_with_default_alias(self): + async_connect(MONGO_TEST_DB) with pytest.raises(ConnectionFailure) as exc_info: - await async_connect(f"{MONGO_TEST_DB}_2") + async_connect(f"{MONGO_TEST_DB}_2") assert ( "A different connection with alias `default` was already registered. Use async_disconnect() first" == str(exc_info.value) ) @pytest.mark.asyncio - async def test_async_connect_fails_if_async_connect_2_times_with_custom_alias(self): - await async_connect(MONGO_TEST_DB, alias="alias1") + async def test_connect_fails_if_async_connect_2_times_with_custom_alias(self): + async_connect(MONGO_TEST_DB, alias="alias1") with pytest.raises(ConnectionFailure) as exc_info: - await async_connect(f"{MONGO_TEST_DB}_2", alias="alias1") + async_connect(f"{MONGO_TEST_DB}_2", alias="alias1") assert ( "A different connection with alias `alias1` was already registered. Use async_disconnect() first" @@ -192,70 +194,68 @@ async def test_async_connect_fails_if_async_connect_2_times_with_custom_alias(se ) @pytest.mark.asyncio - async def test_async_connect_fails_if_similar_connection_settings_arent_defined_the_same_way( + async def test_connect_fails_if_similar_connection_settings_arent_defined_the_same_way( self, ): """Intended to keep the detection function simple but robust""" db_name = MONGO_TEST_DB db_alias = "alias1" - await async_connect(db=db_name, alias=db_alias, host="localhost", port=27017) + async_connect(db=db_name, alias=db_alias, host="localhost", port=27017) with pytest.raises(ConnectionFailure): - await async_connect( - host="mongodb://localhost:27017/%s" % db_name, alias=db_alias - ) + async_connect(host="mongodb://localhost:27017/%s" % db_name, alias=db_alias) @pytest.mark.asyncio - async def test_async_connect_passes_silently_connect_multiple_times_with_same_config( + async def test_connect_passes_silently_connect_multiple_times_with_same_config( self, ): # test default async connection to `test` - await async_connect() - await async_connect() + async_connect() + async_connect() assert len(connection._connections) == 1 - await async_connect(f"{MONGO_TEST_DB}01", alias="test01") - await async_connect(f"{MONGO_TEST_DB}01", alias="test01") + async_connect(f"{MONGO_TEST_DB}01", alias="test01") + async_connect(f"{MONGO_TEST_DB}01", alias="test01") assert len(connection._connections) == 2 - await async_connect( + async_connect( host=f"mongodb://localhost:27017/{MONGO_TEST_DB}02", alias="test02" ) - await async_connect( + async_connect( host=f"mongodb://localhost:27017/{MONGO_TEST_DB}02", alias="test02" ) assert len(connection._connections) == 3 @pytest.mark.asyncio - async def test_async_connect_with_invalid_db_name(self): + async def test_connect_with_invalid_db_name(self): """Ensure that the async_connect() method fails fast if the db name is invalid""" with pytest.raises(InvalidName): - await async_connect("mongodb://localhost") + async_connect("mongodb://localhost") @pytest.mark.asyncio - async def test_async_connect_with_db_name_external(self): + async def test_connect_with_db_name_external(self): """Ensure that async_connect() works if the db name is $external""" """Ensure that the async_connect() method works properly.""" - await async_connect("$external") + async_connect("$external") - conn = await async_get_connection() + conn = async_get_connection() assert isinstance(conn, AsyncMongoClient) db = await async_get_db() assert isinstance(db, AsyncDatabase) assert db.name == "$external" - await async_connect("$external", alias="testdb") - conn = await async_get_connection("testdb") + async_connect("$external", alias="testdb") + conn = async_get_connection("testdb") assert isinstance(conn, AsyncMongoClient) @pytest.mark.asyncio - async def test_async_connect_with_invalid_db_name_type(self): + async def test_connect_with_invalid_db_name_type(self): """Ensure that the async_connect() method fails fast if db name has invalid type""" with pytest.raises(TypeError): non_string_db_name = ["e. g. list instead of a string"] - await async_connect(non_string_db_name) + async_connect(non_string_db_name) @pytest.mark.asyncio - async def test_async_disconnect_cleans_globals(self): + async def test_disconnect_cleans_globals(self): """Ensure that the async_disconnect() method cleans the globals objects""" await reset_async_connections() await async_disconnect_all() @@ -263,7 +263,7 @@ async def test_async_disconnect_cleans_globals(self): dbs = connection._dbs connection_settings = connection._connection_settings - await async_connect(MONGO_TEST_DB) + async_connect(MONGO_TEST_DB) assert len(connections._connections) == 1 assert len(dbs) == 0 @@ -281,11 +281,11 @@ class TestDoc(Document): assert len(connection_settings) == 0 @pytest.mark.asyncio - async def test_async_disconnect_cleans_cached_collection_attribute_in_document( + async def test_disconnect_cleans_cached_collection_attribute_in_document( self, ): """Ensure that the async_disconnect() method works properly""" - await async_connect(MONGO_TEST_DB) + async_connect(MONGO_TEST_DB) class History(Document): pass @@ -306,7 +306,7 @@ class History(Document): assert "You have not defined a default connection" == str(exc_info.value) @pytest.mark.asyncio - async def test_async_connect_disconnect_works_on_same_document(self): + async def test_connect_disconnect_works_on_same_document(self): """Ensure that the async_connect/async_disconnect works properly with a single Document""" db1 = f"{MONGO_TEST_DB}_db1" db2 = f"{MONGO_TEST_DB}_db2" @@ -317,7 +317,7 @@ async def test_async_connect_disconnect_works_on_same_document(self): await client.drop_database(db2) # Save in db1 - await async_connect(db1) + async_connect(db1) class User(Document): name = StringField() @@ -330,7 +330,7 @@ class User(Document): await User(name="Wont work").asave() # Save in db2 - await async_connect(db2) + async_connect(db2) user2 = await User(name="Bob is in db2").asave() await async_disconnect() @@ -338,19 +338,18 @@ class User(Document): assert db1_users == [{"_id": user1.id, "name": "John is in db1"}] db2_users = await client[db2].user.find().to_list() assert db2_users == [{"_id": user2.id, "name": "Bob is in db2"}] + await client.close() @pytest.mark.asyncio - async def test_async_disconnect_silently_pass_if_alias_does_not_exist(self): + async def test_disconnect_silently_pass_if_alias_does_not_exist(self): assert len(connection._connections) == 0 await async_disconnect(alias="not_exist") @pytest.mark.asyncio - async def test_async_disconnect_does_not_close_client_used_by_another_alias(self): - client1 = await async_connect(alias="disconnect_reused_client_test_1") - client2 = await async_connect(alias="disconnect_reused_client_test_2") - client3 = await async_connect( - alias="disconnect_reused_client_test_3", maxPoolSize=10 - ) + async def test_disconnect_does_not_close_client_used_by_another_alias(self): + client1 = async_connect(alias="disconnect_reused_client_test_1") + client2 = async_connect(alias="disconnect_reused_client_test_2") + client3 = async_connect(alias="disconnect_reused_client_test_3", maxPoolSize=10) assert client1 is client2 assert client1 is not client3 await client1.admin.command("ping") @@ -372,14 +371,14 @@ async def test_async_disconnect_does_not_close_client_used_by_another_alias(self await client3.admin.command("ping") @pytest.mark.asyncio - async def test_async_disconnect_all(self): + async def test_disconnect_all(self): await reset_async_connections() await async_disconnect_all() dbs = connection._dbs connection_settings = connection._connection_settings - await async_connect(MONGO_TEST_DB) - await async_connect(f"{MONGO_TEST_DB}_2", alias="db1") + async_connect(MONGO_TEST_DB) + async_connect(f"{MONGO_TEST_DB}_2", alias="db1") class History(Document): pass @@ -419,26 +418,26 @@ class History1(Document): await History1.aobjects.first() @pytest.mark.asyncio - async def test_async_disconnect_all_silently_pass_if_no_connection_exist(self): + async def test_disconnect_all_silently_pass_if_no_connection_exist(self): await async_disconnect_all() @pytest.mark.asyncio async def test_sharing_async_connections(self): """Ensure that connections are shared when the connection settings are exactly the same""" - await async_connect(MONGO_TEST_DB, alias="testdb1") - expected_connection = await async_get_connection("testdb1") + async_connect(MONGO_TEST_DB, alias="testdb1") + expected_connection = async_get_connection("testdb1") - await async_connect(MONGO_TEST_DB, alias="testdb2") - actual_connection = await async_get_connection("testdb2") + async_connect(MONGO_TEST_DB, alias="testdb2") + actual_connection = async_get_connection("testdb2") await expected_connection.server_info() assert expected_connection == actual_connection @pytest.mark.asyncio - async def test_async_connect_uri(self): + async def test_connect_uri(self): """Ensure that the async_connect() method works properly with URIs.""" - c = await async_connect(db=MONGO_TEST_DB, alias="admin") + c = async_connect(db=MONGO_TEST_DB, alias="admin") admin_username = f"admin_{uuid.uuid4().hex[:8]}" user_username = f"user_{uuid.uuid4().hex[:8]}" @@ -451,17 +450,17 @@ async def test_async_connect_uri(self): ) adminadmin_settings["username"] = admin_username adminadmin_settings["password"] = "password" - ca = await async_connect(db=MONGO_TEST_DB, alias="adminadmin") + ca = async_connect(db=MONGO_TEST_DB, alias="adminadmin") await ca.admin.command( "createUser", user_username, pwd="password", roles=["dbOwner"] ) - await async_connect( + async_connect( f"{MONGO_TEST_DB}_testdb_uri", host=f"mongodb://username:password@localhost/{MONGO_TEST_DB}", ) - conn = await async_get_connection() + conn = async_get_connection() assert isinstance(conn, pymongo.AsyncMongoClient) db = await async_get_db() @@ -472,13 +471,13 @@ async def test_async_connect_uri(self): await c.admin.command("dropUser", admin_username) @pytest.mark.asyncio - async def test_async_connect_uri_without_db(self): + async def test_connect_uri_without_db(self): """Ensure the async_connect() method works properly if the URI doesn't include a database name. """ - await async_connect(MONGO_TEST_DB, host="mongodb://localhost/") + async_connect(MONGO_TEST_DB, host="mongodb://localhost/") - conn = await async_get_connection() + conn = async_get_connection() assert isinstance(conn, pymongo.AsyncMongoClient) db = await async_get_db() @@ -486,13 +485,13 @@ async def test_async_connect_uri_without_db(self): assert db.name == MONGO_TEST_DB @pytest.mark.asyncio - async def test_async_connect_uri_default_db(self): + async def test_connect_uri_default_db(self): """Ensure async_connect() defaults to the right database name if the URI and the database_name don't explicitly specify it. """ - await async_connect(host="mongodb://localhost/") + async_connect(host="mongodb://localhost/") - conn = await async_get_connection() + conn = async_get_connection() assert isinstance(conn, pymongo.AsyncMongoClient) db = await async_get_db() @@ -504,7 +503,7 @@ async def test_uri_without_credentials_doesnt_override_async_conn_settings(self) """Ensure async_connect() uses the username and password params if the URI doesn't explicitly specify them. """ - await async_connect( + async_connect( host=f"mongodb://localhost/{MONGO_TEST_DB}", username="user", password="pass", @@ -523,18 +522,18 @@ async def test_uri_without_credentials_doesnt_override_async_conn_settings(self) await async_get_db() @pytest.mark.asyncio - async def test_async_connect_uri_with_authsource(self): + async def test_connect_uri_with_authsource(self): """Ensure that the async_connect() method works well with the `authSource` option in the URI. """ # Create users - c = await async_connect(MONGO_TEST_DB) + c = async_connect(MONGO_TEST_DB) username = f"user_{uuid.uuid4().hex[:8]}" await c.admin.command("createUser", username, pwd="password", roles=["dbOwner"]) # Authentication fails without "authSource" - test_conn = await async_connect( + test_conn = async_connect( MONGO_TEST_DB, alias="test1", host=f"mongodb://{username}:password@localhost/{MONGO_TEST_DB}", @@ -543,7 +542,7 @@ async def test_async_connect_uri_with_authsource(self): await test_conn.server_info() # Authentication succeeds with "authSource" - authd_conn = await async_connect( + authd_conn = async_connect( MONGO_TEST_DB, alias="test2", host=( @@ -560,13 +559,13 @@ async def test_async_connect_uri_with_authsource(self): @pytest.mark.asyncio async def test_register_async_connection(self): """Ensure that async connections with different aliases may be registered.""" - await async_register_connection( + async_register_connection( "testdb", f"{MONGO_TEST_DB}_2", mongo_client_class=AsyncMongoClient ) with pytest.raises(ConnectionFailure): - await async_get_connection() - conn = await async_get_connection("testdb") + async_get_connection() + conn = async_get_connection("testdb") assert isinstance(conn, pymongo.AsyncMongoClient) db = await async_get_db("testdb") @@ -576,7 +575,7 @@ async def test_register_async_connection(self): @pytest.mark.asyncio async def test_register_async_connection_defaults(self): """Ensure that defaults are used when the host and port are None.""" - await async_register_connection( + async_register_connection( "testdb", MONGO_TEST_DB, host=None, @@ -584,29 +583,29 @@ async def test_register_async_connection_defaults(self): mongo_client_class=AsyncMongoClient, ) - conn = await async_get_connection("testdb") + conn = async_get_connection("testdb") assert isinstance(conn, pymongo.AsyncMongoClient) @pytest.mark.asyncio - async def test_async_connection_kwargs(self): + async def test_connection_kwargs(self): """Ensure that async connection kwargs get passed to pymongo.""" - await async_connect(MONGO_TEST_DB, alias="t1", tz_aware=True) - conn = await async_get_connection("t1") + async_connect(MONGO_TEST_DB, alias="t1", tz_aware=True) + conn = async_get_connection("t1") assert get_tz_awareness(conn) - await async_connect(f"{MONGO_TEST_DB}_2", alias="t2") - conn = await async_get_connection("t2") + async_connect(f"{MONGO_TEST_DB}_2", alias="t2") + conn = async_get_connection("t2") assert not get_tz_awareness(conn) @pytest.mark.asyncio - async def test_async_connection_pool_via_kwarg(self): + async def test_connection_pool_via_kwarg(self): """Ensure we can specify a max connection pool size using an async connection kwarg. """ pool_size_kwargs = {"maxpoolsize": 100} - conn = await async_connect( + conn = async_connect( MONGO_TEST_DB, alias="max_pool_size_via_kwarg", **pool_size_kwargs ) if PYMONGO_VERSION >= (4,): @@ -615,11 +614,11 @@ async def test_async_connection_pool_via_kwarg(self): assert conn.max_pool_size == 100 @pytest.mark.asyncio - async def test_async_connection_pool_via_uri(self): + async def test_connection_pool_via_uri(self): """Ensure we can specify a max connection pool size using an option in an async connection URI. """ - conn = await async_connect( + conn = async_connect( host="mongodb://localhost/test?maxpoolsize=100", alias="max_pool_size_via_uri", ) @@ -629,33 +628,33 @@ async def test_async_connection_pool_via_uri(self): assert conn.max_pool_size == 100 @pytest.mark.asyncio - async def test_async_write_concern(self): + async def test_write_concern(self): """Ensure write concern can be specified in connect() via a kwarg or as part of the connection URI. """ - conn1 = await async_connect( + conn1 = async_connect( alias="conn1", host="mongodb://localhost/testing?w=1&journal=true" ) - conn2 = await async_connect("testing", alias="conn2", w=1, journal=True) + conn2 = async_connect("testing", alias="conn2", w=1, journal=True) assert conn1.write_concern.document == {"w": 1, "j": True} assert conn2.write_concern.document == {"w": 1, "j": True} @pytest.mark.asyncio - async def test_async_connect_with_replicaset_via_uri(self): + async def test_connect_with_replicaset_via_uri(self): """Ensure connect() works when specifying a replicaSet via the MongoDB URI. """ - await async_connect(host="mongodb://localhost/test?replicaSet=local-rs") + async_connect(host="mongodb://localhost/test?replicaSet=local-rs") db = await async_get_db() assert isinstance(db, AsyncDatabase) assert db.name == "test" @pytest.mark.asyncio - async def test_async_connect_with_replicaset_via_kwargs(self): + async def test_connect_with_replicaset_via_kwargs(self): """Ensure async_connect() works when specifying a replicaSet via the connection kwargs """ - c = await async_connect(replicaset="local-rs") + c = async_connect(replicaset="local-rs") if hasattr(c, "_AsyncMongoClient__options"): assert c._AsyncMongoClient__options.replica_set_name == "local-rs" else: # pymongo >= 4.9 @@ -665,8 +664,8 @@ async def test_async_connect_with_replicaset_via_kwargs(self): assert db.name == "test" @pytest.mark.asyncio - async def test_async_connect_tz_aware(self): - await async_connect(MONGO_TEST_DB, tz_aware=True) + async def test_connect_tz_aware(self): + async_connect(MONGO_TEST_DB, tz_aware=True) d = datetime.datetime(2010, 5, 5, tzinfo=utc) class DateDoc(Document): @@ -679,27 +678,27 @@ class DateDoc(Document): assert d == date_doc.the_date @pytest.mark.asyncio - async def test_async_read_preference_from_parse(self): - conn = await async_connect( + async def test_read_preference_from_parse(self): + conn = async_connect( host="mongodb://a1.vpc,a2.vpc,a3.vpc/prod?readPreference=secondaryPreferred" ) assert conn.read_preference == ReadPreference.SECONDARY_PREFERRED @pytest.mark.asyncio async def test_multiple_async_connection_settings(self): - await async_connect( + async_connect( MONGO_TEST_DB, alias="t1", host="localhost", read_preference=ReadPreference.PRIMARY, ) - await async_connect( + async_connect( f"{MONGO_TEST_DB}_2", alias="t2", host="127.0.0.1", read_preference=ReadPreference.PRIMARY_PREFERRED, ) - mongo_connections = connection._connections + mongo_connections: dict[str, AsyncMongoClient] = connection._connections assert len(mongo_connections.items()) == 2 assert "t1" in mongo_connections.keys() assert "t2" in mongo_connections.keys() @@ -707,8 +706,8 @@ async def test_multiple_async_connection_settings(self): # Handle PyMongo 3+ Async Connection (lazily established) # Ensure we are connected, throws ServerSelectionTimeoutError otherwise. # Purposely not catching exception to fail the test if thrown. - mongo_connections["t1"].server_info() - mongo_connections["t2"].server_info() + await mongo_connections["t1"].server_info() + await mongo_connections["t2"].server_info() assert (await mongo_connections["t1"].address)[0] == "localhost" assert (await mongo_connections["t2"].address)[0] in ( "localhost", @@ -721,29 +720,29 @@ async def test_multiple_async_connection_settings(self): assert mongo_connections["t1"] is not mongo_connections["t2"] @pytest.mark.asyncio - async def test_async_connect_2_databases_uses_same_client_if_only_dbname_differs( + async def test_connect_2_databases_uses_same_client_if_only_dbname_differs( self, ): - c1 = await async_connect(alias="testdb1", db="testdb1") - c2 = await async_connect(alias="testdb2", db="testdb2") + c1 = async_connect(alias="testdb1", db="testdb1") + c2 = async_connect(alias="testdb2", db="testdb2") assert c1 is c2 @pytest.mark.asyncio - async def test_async_connect_2_databases_uses_different_client_if_different_parameters( + async def test_connect_2_databases_uses_different_client_if_different_parameters( self, ): - c1 = await async_connect( + c1 = async_connect( alias="testdb1", db="testdb1", username="u1", password="pass" ) - c2 = await async_connect( + c2 = async_connect( alias="testdb2", db="testdb2", username="u2", password="pass" ) assert c1 is not c2 @pytest.mark.asyncio - async def test_async_connect_uri_uuidrepresentation_set_in_uri(self): + async def test_connect_uri_uuidrepresentation_set_in_uri(self): rand = random_str() - tmp_conn = await async_connect( + tmp_conn = async_connect( alias=rand, host=f"mongodb://localhost:27017/{rand}?uuidRepresentation=csharpLegacy", ) @@ -754,11 +753,9 @@ async def test_async_connect_uri_uuidrepresentation_set_in_uri(self): await async_disconnect(rand) @pytest.mark.asyncio - async def test_async_connect_uri_uuidrepresentation_set_as_arg(self): + async def test_connect_uri_uuidrepresentation_set_as_arg(self): rand = random_str() - tmp_conn = await async_connect( - alias=rand, db=rand, uuidRepresentation="javaLegacy" - ) + tmp_conn = async_connect(alias=rand, db=rand, uuidRepresentation="javaLegacy") assert ( tmp_conn.options.codec_options.uuid_representation == pymongo.common._UUID_REPRESENTATIONS["javaLegacy"] @@ -766,11 +763,11 @@ async def test_async_connect_uri_uuidrepresentation_set_as_arg(self): await async_disconnect(rand) @pytest.mark.asyncio - async def test_async_connect_uri_uuidrepresentation_set_both_arg_and_uri_arg_prevail( + async def test_connect_uri_uuidrepresentation_set_both_arg_and_uri_arg_prevail( self, ): rand = random_str() - tmp_conn = await async_connect( + tmp_conn = async_connect( alias=rand, host=f"mongodb://localhost:27017/{rand}?uuidRepresentation=csharpLegacy", uuidRepresentation="javaLegacy", @@ -782,13 +779,13 @@ async def test_async_connect_uri_uuidrepresentation_set_both_arg_and_uri_arg_pre await async_disconnect(rand) @pytest.mark.asyncio - async def test_async_connect_uuid_representation_defaults_to_unspecified(self): + async def test_connect_uuid_representation_defaults_to_unspecified(self): """ PyMongo >= 4 defaults uuidRepresentation to UNSPECIFIED. Old behavior ('pythonLegacy') is deprecated and removed. """ rand = random_str() - tmp_conn = await async_connect(alias=rand, db=rand) + tmp_conn = async_connect(alias=rand, db=rand) # Assert new PyMongo 4.x behavior assert ( diff --git a/tests/asynchronous/test_context_managers.py b/tests/asynchronous/test_context_managers.py index 7732b497a..de7f25cc8 100644 --- a/tests/asynchronous/test_context_managers.py +++ b/tests/asynchronous/test_context_managers.py @@ -3,28 +3,28 @@ import random import pytest -from pymongo.errors import OperationFailure, InvalidOperation +from pymongo.errors import InvalidOperation, OperationFailure from pymongo.read_concern import ReadConcern from mongoengine import * from mongoengine.asynchronous import ( - async_register_connection, - async_get_db, async_connect, + async_get_db, + async_register_connection, ) -from mongoengine.session import _get_session from mongoengine.context_managers import ( + async_query_counter, no_sub_classes, + run_in_transaction, set_read_write_concern, set_write_concern, switch_collection, switch_db, - async_query_counter, - run_in_transaction, ) from mongoengine.pymongo_support import async_count_documents +from mongoengine.session import _get_session from tests.asynchronous.utils import MongoDBAsyncTestCase -from tests.utils import requires_mongodb_gte_44, MONGO_TEST_DB +from tests.utils import MONGO_TEST_DB, requires_mongodb_gte_44 class TestRollbackError(Exception): @@ -75,7 +75,7 @@ class User(Document): assert original_write_concern.document == collection.write_concern.document async def test_switch_db_context_manager(self): - await async_register_connection("testdb-1", f"{MONGO_TEST_DB}_2") + async_register_connection("testdb-1", f"{MONGO_TEST_DB}_2") class Group(Document): name = StringField() @@ -100,7 +100,7 @@ class Group(Document): assert 1 == await Group.aobjects.count() async def test_switch_collection_context_manager(self): - await async_register_connection(alias="testdb-1", db=f"{MONGO_TEST_DB}_2") + async_register_connection(alias="testdb-1", db=f"{MONGO_TEST_DB}_2") class Group(Document): name = StringField() @@ -274,7 +274,7 @@ async def issue_1_find_query(): async def test_query_counter_alias(self): """query_counter works properly with db aliases?""" # Register a connection with db_alias testdb-1 - await async_register_connection("testdb-1", f"{MONGO_TEST_DB}_2") + async_register_connection("testdb-1", f"{MONGO_TEST_DB}_2") class A(Document): """Uses default db_alias""" @@ -423,8 +423,8 @@ class A(Document): assert await A.aobjects.count() == 0 async def test_transaction_updates_across_databases(self): - await async_connect(MONGO_TEST_DB) - await async_connect(f"{MONGO_TEST_DB}_2", "test2") + async_connect(MONGO_TEST_DB) + async_connect(f"{MONGO_TEST_DB}_2", "test2") class A(Document): name = StringField() @@ -450,8 +450,8 @@ class B(Document): async def test_collection_creation_via_upserts_across_databases_in_transaction( self, ): - await async_connect(MONGO_TEST_DB) - await async_connect(f"{MONGO_TEST_DB}_test2", "test2") + async_connect(MONGO_TEST_DB) + async_connect(f"{MONGO_TEST_DB}_test2", "test2") class A(Document): name = StringField() @@ -481,8 +481,8 @@ class B(Document): async def test_an_exception_raised_in_transactions_across_databases_rolls_back_updates( self, ): - await async_connect(MONGO_TEST_DB) - await async_connect(f"{MONGO_TEST_DB}_2", "test2") + async_connect(MONGO_TEST_DB) + async_connect(f"{MONGO_TEST_DB}_2", "test2") class A(Document): name = StringField() diff --git a/tests/asynchronous/test_dereference.py b/tests/asynchronous/test_dereference.py index 23334ceef..d176dfebf 100644 --- a/tests/asynchronous/test_dereference.py +++ b/tests/asynchronous/test_dereference.py @@ -3,25 +3,13 @@ from bson import DBRef, ObjectId from mongoengine import * -from mongoengine.asynchronous import ( - async_connect, - async_register_connection, - async_disconnect_all, -) +from mongoengine.asynchronous import async_register_connection from mongoengine.context_managers import async_query_counter -from tests.asynchronous.utils import reset_async_connections +from tests.asynchronous.utils import MongoDBAsyncTestCase from tests.utils import MONGO_TEST_DB -class FieldTest(unittest.IsolatedAsyncioTestCase): - async def asyncSetUp(self): - self.db = await async_connect(db=MONGO_TEST_DB) - - async def asyncTearDown(self): - await self.db.drop_database(MONGO_TEST_DB) - await async_disconnect_all() - await reset_async_connections() - +class FieldTest(MongoDBAsyncTestCase): async def test_list_item_dereference(self): """Ensure that DBRef items in ListFields are dereferenced.""" @@ -1328,7 +1316,7 @@ class Group(Document): async def test_objectid_reference_across_databases(self): # mongoenginetest - Is default connection alias from setUp() # Register Aliases - await async_register_connection("testdb-1", f"{MONGO_TEST_DB}_2") + async_register_connection("testdb-1", f"{MONGO_TEST_DB}_2") class User(Document): name = StringField() diff --git a/tests/asynchronous/test_replicaset_connection.py b/tests/asynchronous/test_replicaset_connection.py index 418d21d79..3c5c18588 100644 --- a/tests/asynchronous/test_replicaset_connection.py +++ b/tests/asynchronous/test_replicaset_connection.py @@ -1,30 +1,26 @@ import unittest -from pymongo import MongoClient, ReadPreference +from pymongo import AsyncMongoClient, ReadPreference -import mongoengine from mongoengine.asynchronous.connection import ConnectionFailure, async_connect +from tests.asynchronous.utils import reset_async_connections from tests.utils import MONGO_TEST_DB -CONN_CLASS = MongoClient +CONN_CLASS = AsyncMongoClient READ_PREF = ReadPreference.SECONDARY class ConnectionTest(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): - mongoengine.asynchronous.connection._connection_settings = {} - mongoengine.asynchronous.connection._connections = {} - mongoengine.asynchronous.connection._dbs = {} + await reset_async_connections() async def asyncTearDown(self): - mongoengine.asynchronous.connection._connection_settings = {} - mongoengine.asynchronous.connection._connections = {} - mongoengine.asynchronous.connection._dbs = {} + await reset_async_connections() async def test_replicaset_uri_passes_read_preference(self): """Requires a replica set called "rs" on port 27017""" try: - conn = await async_connect( + conn = async_connect( db=MONGO_TEST_DB, host=f"mongodb://localhost/{MONGO_TEST_DB}?replicaSet=rs", read_preference=READ_PREF, @@ -33,7 +29,6 @@ async def test_replicaset_uri_passes_read_preference(self): return if not isinstance(conn, CONN_CLASS): - # really??? - return + raise TypeError(f"Expected {CONN_CLASS}, got {type(conn)}") assert conn.read_preference == READ_PREF diff --git a/tests/asynchronous/test_signals.py b/tests/asynchronous/test_signals.py index d065684d6..e667c6fcc 100644 --- a/tests/asynchronous/test_signals.py +++ b/tests/asynchronous/test_signals.py @@ -27,7 +27,8 @@ async def get_signal_output(fn, *args, **kwargs): return signal_output async def asyncSetUp(self): - await async_connect(db=MONGO_TEST_DB) + await reset_async_connections() + async_connect(db=MONGO_TEST_DB) class Author(Document): # Make the id deterministic for easier testing @@ -215,46 +216,48 @@ async def post_bulk_insert(cls, sender, documents, **kwargs): signals.post_bulk_insert.connect(Post.post_bulk_insert, sender=Post) async def asyncTearDown(self): - signals.pre_init.disconnect(self.Author.pre_init) - signals.post_init.disconnect(self.Author.post_init) - signals.post_delete.disconnect(self.Author.post_delete) - signals.pre_delete.disconnect(self.Author.pre_delete) - signals.post_save.disconnect(self.Author.post_save) - signals.pre_save_post_validation.disconnect( - self.Author.pre_save_post_validation - ) - signals.pre_save.disconnect(self.Author.pre_save) - signals.pre_bulk_insert.disconnect(self.Author.pre_bulk_insert) - signals.post_bulk_insert.disconnect(self.Author.post_bulk_insert) - - signals.post_delete.disconnect(self.Another.post_delete) - signals.pre_delete.disconnect(self.Another.pre_delete) - - signals.post_save.disconnect(self.ExplicitId.post_save) - - signals.pre_bulk_insert.disconnect(self.Post.pre_bulk_insert) - signals.post_bulk_insert.disconnect(self.Post.post_bulk_insert) - - # Check that all our signals got disconnected properly. - post_signals = ( - len(signals.pre_init.receivers), - len(signals.post_init.receivers), - len(signals.pre_save.receivers), - len(signals.pre_save_post_validation.receivers), - len(signals.post_save.receivers), - len(signals.pre_delete.receivers), - len(signals.post_delete.receivers), - len(signals.pre_bulk_insert.receivers), - len(signals.post_bulk_insert.receivers), - ) - - await self.ExplicitId.aobjects.delete() - - # Note that there is a chance that the following assert fails in case - # some receivers (eventually created in other tests) - # gets garbage collected (https://pythonhosted.org/blinker/#blinker.base.Signal.connect) - assert self.pre_signals == post_signals - await reset_async_connections() + try: + signals.pre_init.disconnect(self.Author.pre_init) + signals.post_init.disconnect(self.Author.post_init) + signals.post_delete.disconnect(self.Author.post_delete) + signals.pre_delete.disconnect(self.Author.pre_delete) + signals.post_save.disconnect(self.Author.post_save) + signals.pre_save_post_validation.disconnect( + self.Author.pre_save_post_validation + ) + signals.pre_save.disconnect(self.Author.pre_save) + signals.pre_bulk_insert.disconnect(self.Author.pre_bulk_insert) + signals.post_bulk_insert.disconnect(self.Author.post_bulk_insert) + + signals.post_delete.disconnect(self.Another.post_delete) + signals.pre_delete.disconnect(self.Another.pre_delete) + + signals.post_save.disconnect(self.ExplicitId.post_save) + + signals.pre_bulk_insert.disconnect(self.Post.pre_bulk_insert) + signals.post_bulk_insert.disconnect(self.Post.post_bulk_insert) + + # Check that all our signals got disconnected properly. + post_signals = ( + len(signals.pre_init.receivers), + len(signals.post_init.receivers), + len(signals.pre_save.receivers), + len(signals.pre_save_post_validation.receivers), + len(signals.post_save.receivers), + len(signals.pre_delete.receivers), + len(signals.post_delete.receivers), + len(signals.pre_bulk_insert.receivers), + len(signals.post_bulk_insert.receivers), + ) + + await self.ExplicitId.aobjects.delete() + + # Note that there is a chance that the following assert fails in case + # some receivers (eventually created in other tests) + # gets garbage collected (https://pythonhosted.org/blinker/#blinker.base.Signal.connect) + assert self.pre_signals == post_signals + finally: + await reset_async_connections() async def test_model_signals(self): """Model saves should throw some signals.""" @@ -424,8 +427,8 @@ async def test_signals_with_switch_collection(self): assert await self.get_signal_output(ei.asave) == ["Is created"] async def test_signals_with_switch_db(self): - await async_connect(MONGO_TEST_DB) - await async_register_connection("testdb-1", f"{MONGO_TEST_DB}_2") + async_connect(MONGO_TEST_DB) + async_register_connection("testdb-1", f"{MONGO_TEST_DB}_2") ei = self.ExplicitId(id=123) ei.switch_db("testdb-1") diff --git a/tests/asynchronous/utils.py b/tests/asynchronous/utils.py index 41a62d6f8..edfd2a1f9 100644 --- a/tests/asynchronous/utils.py +++ b/tests/asynchronous/utils.py @@ -6,16 +6,14 @@ import pytest from mongoengine.asynchronous import ( - async_disconnect_all, async_connect, - async_get_db, async_disconnect, + async_get_db, ) from mongoengine.base import _DocumentRegistry from mongoengine.context_managers import async_query_counter -from mongoengine.mongodb_support import get_mongodb_version, async_get_mongodb_version +from mongoengine.mongodb_support import async_get_mongodb_version, get_mongodb_version from mongoengine.registry import _CollectionRegistry - from tests.utils import MONGO_TEST_DB @@ -25,13 +23,27 @@ class MongoDBAsyncTestCase(unittest.IsolatedAsyncioTestCase): """ async def asyncSetUp(self): - await async_disconnect_all() - self._connection = await async_connect(db=MONGO_TEST_DB) + # 1. Clear out everything from previous runs + await reset_async_connections() + _DocumentRegistry.clear() + _CollectionRegistry.clear() + + # 2. Establish the fresh connection + self._connection = async_connect(db=MONGO_TEST_DB) await self._connection.drop_database(MONGO_TEST_DB) self.db = await async_get_db() async def asyncTearDown(self): - await self._connection.drop_database(MONGO_TEST_DB) + # 1. Grab the connection safely (handles cases where setup failed mid-way) + conn = getattr(self, "_connection", None) + if conn: + await conn.drop_database(MONGO_TEST_DB) + + # 2. Break references immediately so Python collects them while the loop is still alive + self._connection = None + self.db = None + + # 3. Purge registries and disconnect await async_disconnect() await reset_async_connections() _DocumentRegistry.clear() @@ -121,8 +133,8 @@ async def get_ops(self): async def reset_async_connections(): from mongoengine.asynchronous.connection import ( - _connections, _connection_settings, + _connections, _dbs, ) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..c25575651 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,26 @@ +from pymongo import MongoClient +import pytest + +from tests.utils import MONGO_TEST_DB + + +def cleanup_databases(): + print("Cleaning up test databases...") + with MongoClient("localhost", 27017) as client: + db_names = client.list_database_names() + + for db_name in db_names: + if db_name.startswith(MONGO_TEST_DB): + print(f"Dropping test database: {db_name}") + client.drop_database(db_name) + + +@pytest.fixture(scope="session", autouse=True) +def cleanup_mongoengine_databases(): + """ + Session-scoped fixture that runs automatically after all tests finish. + Finds and drops all databases starting with 'mongoengine' using MongoClient. + """ + cleanup_databases() + yield + cleanup_databases() diff --git a/tests/synchronous/all_warnings/test_warnings.py b/tests/synchronous/all_warnings/test_warnings.py index ee72c3369..d70ed4b1e 100644 --- a/tests/synchronous/all_warnings/test_warnings.py +++ b/tests/synchronous/all_warnings/test_warnings.py @@ -4,29 +4,26 @@ top level and called first by the test suite. """ -import unittest import warnings from mongoengine import * -from tests.synchronous.utils import reset_connections from mongoengine.base.common import _document_registry -from tests.utils import MONGO_TEST_DB +from tests.synchronous.utils import MongoDBTestCase -class TestAllWarnings(unittest.TestCase): +class TestAllWarnings(MongoDBTestCase): def setUp(self): - connect(db=MONGO_TEST_DB) self.warning_list = [] self.showwarning_default = warnings.showwarning warnings.showwarning = self.append_to_warning_list + super().setUp() def append_to_warning_list(self, message, category, *args): self.warning_list.append({"message": message, "category": category}) def tearDown(self): - # restore default handling of warnings warnings.showwarning = self.showwarning_default - reset_connections() + super().tearDown() def test_document_collection_syntax_warning(self): class NonAbstractBase(Document): diff --git a/tests/synchronous/document/test_class_methods.py b/tests/synchronous/document/test_class_methods.py index 063ede9a6..ec20024cd 100644 --- a/tests/synchronous/document/test_class_methods.py +++ b/tests/synchronous/document/test_class_methods.py @@ -1,16 +1,14 @@ import unittest from mongoengine import * -from mongoengine.synchronous.connection import get_db -from mongoengine.pymongo_support import list_collection_names from mongoengine.base.queryset import NULLIFY, PULL -from tests.utils import MONGO_TEST_DB +from mongoengine.pymongo_support import list_collection_names +from tests.synchronous.utils import MongoDBTestCase -class TestClassMethods(unittest.TestCase): +class TestClassMethods(MongoDBTestCase): def setUp(self): - connect(db=MONGO_TEST_DB) - self.db = get_db() + super().setUp() class Person(Document): name = StringField() @@ -23,8 +21,7 @@ class Person(Document): self.Person = Person def tearDown(self): - for collection in list_collection_names(self.db): - self.db.drop_collection(collection) + super().tearDown() def test_definition(self): """Ensure that document may be defined using fields.""" diff --git a/tests/synchronous/document/test_delta.py b/tests/synchronous/document/test_delta.py index 0c90fae7c..03eb9c4d1 100644 --- a/tests/synchronous/document/test_delta.py +++ b/tests/synchronous/document/test_delta.py @@ -3,7 +3,6 @@ from bson import SON from mongoengine import * -from mongoengine.pymongo_support import list_collection_names from tests.synchronous.utils import MongoDBTestCase, get_as_pymongo @@ -21,10 +20,6 @@ class Person(Document): self.Person = Person - def tearDown(self): - for collection in list_collection_names(self.db): - self.db.drop_collection(collection) - def test_delta(self): self.delta(Document) self.delta(DynamicDocument) diff --git a/tests/synchronous/document/test_indexes.py b/tests/synchronous/document/test_indexes.py index cf52b971f..3aac6597b 100644 --- a/tests/synchronous/document/test_indexes.py +++ b/tests/synchronous/document/test_indexes.py @@ -6,22 +6,20 @@ from mongoengine import * from mongoengine.errors import NotUniqueError -from mongoengine.registry import _CollectionRegistry -from mongoengine.synchronous.connection import get_db from mongoengine.mongodb_support import ( MONGODB_42, MONGODB_80, get_mongodb_version, ) from mongoengine.pymongo_support import PYMONGO_VERSION -from tests.synchronous.utils import reset_connections +from tests.synchronous.utils import MongoDBTestCase from tests.utils import MONGO_TEST_DB -class TestIndexes(unittest.TestCase): +class TestIndexes(MongoDBTestCase): def setUp(self): - self.connection = connect(db=MONGO_TEST_DB) - self.db = get_db() + super().setUp() + self.connection = self._connection class Person(Document): name = StringField() @@ -34,11 +32,10 @@ class Person(Document): self.Person = Person def tearDown(self): - self.Person.adrop_collection() - self.connection.drop_database(self.db) - disconnect_all() - reset_connections() - _CollectionRegistry.clear() + try: + self.Person.drop_collection() + finally: + super().tearDown() def test_indexes_document(self): """Ensure that indexes are used when meta[indexes] is specified for diff --git a/tests/synchronous/document/test_instance.py b/tests/synchronous/document/test_instance.py index 01a26f3b3..55c30a92f 100644 --- a/tests/synchronous/document/test_instance.py +++ b/tests/synchronous/document/test_instance.py @@ -14,9 +14,8 @@ from mongoengine import * from mongoengine import signals from mongoengine.base import _DocumentRegistry -from mongoengine.registry import _CollectionRegistry -from mongoengine.synchronous.connection import get_db -from mongoengine.context_managers import query_counter, switch_db, switch_collection +from mongoengine.base.queryset import CASCADE, DENY, NULLIFY, PULL, Q +from mongoengine.context_managers import query_counter, switch_collection, switch_db from mongoengine.errors import ( FieldDoesNotExist, InvalidDocumentError, @@ -32,7 +31,8 @@ PYMONGO_VERSION, list_collection_names, ) -from mongoengine.base.queryset import NULLIFY, Q, PULL, CASCADE, DENY +from mongoengine.registry import _CollectionRegistry +from mongoengine.synchronous.connection import get_db from tests import fixtures from tests.fixtures import ( PickleDynamicEmbedded, @@ -50,6 +50,15 @@ ) from tests.utils import MONGO_TEST_DB +try: + # Python 3.11+ + from datetime import UTC +except ImportError: + # Python ≤ 3.10 + from datetime import timezone + + UTC = timezone.utc + TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), "../fields/mongoengine.png") @@ -3766,7 +3775,7 @@ class Test(Document): def test_default_values_dont_get_override_upon_save_when_only_is_used(self): class Person(Document): - created_on = DateTimeField(default=lambda: datetime.utcnow()) + created_on = DateTimeField(default=lambda: datetime.now(UTC)) name = StringField() p = Person(name="alon") @@ -3780,7 +3789,7 @@ class Person(Document): assert orig_created_on == p3.created_on class Person(Document): - created_on = DateTimeField(default=lambda: datetime.utcnow()) + created_on = DateTimeField(default=lambda: datetime.now(UTC)) name = StringField() height = IntField(default=189) diff --git a/tests/synchronous/document/test_timeseries_collection.py b/tests/synchronous/document/test_timeseries_collection.py index a4d38cefa..df9d45944 100644 --- a/tests/synchronous/document/test_timeseries_collection.py +++ b/tests/synchronous/document/test_timeseries_collection.py @@ -7,17 +7,23 @@ Document, FloatField, StringField, - connect, - get_db, ) -from mongoengine.synchronous.connection import disconnect -from tests.utils import requires_mongodb_gte_50, MONGO_TEST_DB +from tests.synchronous.utils import MongoDBTestCase +from tests.utils import requires_mongodb_gte_50 +try: + # Python 3.11+ + from datetime import UTC +except ImportError: + # Python ≤ 3.10 + from datetime import timezone -class TestTimeSeriesCollections(unittest.TestCase): + UTC = timezone.utc + + +class TestTimeSeriesCollections(MongoDBTestCase): def setUp(self): - connect(db=MONGO_TEST_DB) - self.db = get_db() + super().setUp() class SensorData(Document): timestamp = DateTimeField(required=True) @@ -41,10 +47,12 @@ def test_get_db(self): assert self.db == db def tearDown(self): - for collection_name in self.db.list_collection_names(): - if not collection_name.startswith("system."): - self.db.drop_collection(collection_name) - disconnect() + try: + for collection_name in self.db.list_collection_names(): + if not collection_name.startswith("system."): + self.db.drop_collection(collection_name) + finally: + super().tearDown() def test_definition(self): """Ensure that document may be defined using fields.""" @@ -84,7 +92,7 @@ def test_insert_document_into_timeseries_collection(self): assert collection_name in self.db.list_collection_names() # Insert a document and ensure it was inserted - self.SensorData(timestamp=datetime.utcnow(), temperature=23.4).save() + self.SensorData(timestamp=datetime.now(UTC), temperature=23.4).save() assert collection.count_documents({}) == 1 @requires_mongodb_gte_50 @@ -98,7 +106,7 @@ def test_timeseries_expiration(self): assert options.get("timeseries", {}) is not None assert options["expireAfterSeconds"] == 1 - self.SensorData(timestamp=datetime.utcnow(), temperature=23.4).save() + self.SensorData(timestamp=datetime.now(UTC), temperature=23.4).save() assert collection.count_documents({}) == 1 @@ -144,7 +152,7 @@ def test_timeseries_data_insertion_order(self): self.SensorData._get_collection() # Insert documents out of order - now = datetime.utcnow() + now = datetime.now(UTC) self.SensorData(timestamp=now, temperature=23.4).save() self.SensorData(timestamp=now - timedelta(seconds=5), temperature=22.0).save() self.SensorData(timestamp=now + timedelta(seconds=5), temperature=24.0).save() @@ -164,7 +172,7 @@ def test_timeseries_query_by_time_range(self): self.SensorData._get_collection_name() self.SensorData._get_collection() - now = datetime.utcnow() + now = datetime.now(UTC) self.SensorData(timestamp=now - timedelta(seconds=10), temperature=22.0).save() self.SensorData(timestamp=now - timedelta(seconds=5), temperature=23.0).save() self.SensorData(timestamp=now, temperature=24.0).save() diff --git a/tests/synchronous/fields/test_aware_datetime_field.py b/tests/synchronous/fields/test_aware_datetime_field.py index 2e16177eb..8bbb7efaf 100644 --- a/tests/synchronous/fields/test_aware_datetime_field.py +++ b/tests/synchronous/fields/test_aware_datetime_field.py @@ -30,9 +30,7 @@ class Event(Document): Event.drop_collection() # Create event with Asia/Kolkata timezone - kolkata_time = datetime.datetime( - 2024, 6, 15, 14, 30, tzinfo=ZoneInfo("Asia/Kolkata") - ) + kolkata_time = datetime.datetime.now().astimezone(ZoneInfo("Asia/Kolkata")) event = Event(start_time=kolkata_time) event.save() @@ -40,6 +38,7 @@ class Event(Document): raw = get_as_pymongo(event) assert "start_time" in raw assert "utc" in raw["start_time"] + assert "iso" in raw["start_time"] assert "tz" in raw["start_time"] assert raw["start_time"]["tz"] == "Asia/Kolkata" @@ -59,24 +58,11 @@ class Event(Document): Event.drop_collection() # Create events in different timezones + now = datetime.datetime.now(UTC) timezones = [ - ( - "Asia/Kolkata", - datetime.datetime(2024, 6, 15, 14, 30, tzinfo=ZoneInfo("Asia/Kolkata")), - ), - ( - "America/New_York", - datetime.datetime( - 2024, 6, 15, 9, 0, tzinfo=ZoneInfo("America/New_York") - ), - ), - ( - "Europe/London", - datetime.datetime(2024, 6, 15, 15, 0, tzinfo=ZoneInfo("Europe/London")), - ), - ("UTC", datetime.datetime(2024, 6, 15, 12, 0, tzinfo=UTC)), + (tz, now.astimezone(ZoneInfo(tz))) + for tz in ["Asia/Kolkata", "America/New_York", "Europe/London", "UTC"] ] - for tz_name, dt in timezones: Event(name=tz_name, start_time=dt).save() @@ -99,7 +85,7 @@ class Event(Document): winter = Event( name="Winter", start_time=datetime.datetime( - 2024, 1, 15, 10, 0, tzinfo=ZoneInfo("America/New_York") + 2024, 1, 15, 10, 0, 30, 500000, tzinfo=ZoneInfo("America/New_York") ), ) winter.save() @@ -108,7 +94,7 @@ class Event(Document): summer = Event( name="Summer", start_time=datetime.datetime( - 2024, 7, 15, 10, 0, tzinfo=ZoneInfo("America/New_York") + 2024, 7, 15, 10, 0, 30, 500000, tzinfo=ZoneInfo("America/New_York") ), ) summer.save() @@ -140,19 +126,19 @@ class Event(Document): Event( name="Early", start_time=datetime.datetime( - 2024, 6, 15, 8, 0, tzinfo=ZoneInfo("Asia/Kolkata") + 2024, 6, 15, 8, 0, 30, 500000, tzinfo=ZoneInfo("Asia/Kolkata") ), ).save() # Late: 18:00 Asia/Kolkata (UTC+5:30) = 12:30 UTC Event( name="Late", start_time=datetime.datetime( - 2024, 6, 15, 18, 0, tzinfo=ZoneInfo("Asia/Kolkata") + 2024, 6, 15, 18, 0, 30, 500000, tzinfo=ZoneInfo("Asia/Kolkata") ), ).save() # Query by UTC time - should find only the Late event - utc_noon = datetime.datetime(2024, 6, 15, 12, 0, tzinfo=UTC) + utc_noon = datetime.datetime(2024, 6, 15, 12, 0, 30, 500000, tzinfo=UTC) events_after_noon = Event.objects(start_time__utc__gte=utc_noon) assert events_after_noon.count() == 1 @@ -169,12 +155,12 @@ class Event(Document): # Create events in different timezones Event( start_time=datetime.datetime( - 2024, 6, 15, 14, 30, tzinfo=ZoneInfo("Asia/Kolkata") + 2024, 6, 15, 14, 30, 30, 500000, tzinfo=ZoneInfo("Asia/Kolkata") ) ).save() Event( start_time=datetime.datetime( - 2024, 6, 15, 9, 0, tzinfo=ZoneInfo("America/New_York") + 2024, 6, 15, 9, 0, 30, 500000, tzinfo=ZoneInfo("America/New_York") ) ).save() @@ -196,13 +182,13 @@ class Event(Document): Event( name="First", start_time=datetime.datetime( - 2024, 6, 15, 10, 0, tzinfo=ZoneInfo("Asia/Kolkata") + 2024, 6, 15, 10, 0, 30, 500000, tzinfo=ZoneInfo("Asia/Kolkata") ), # 04:30 UTC ).save() Event( name="Second", start_time=datetime.datetime( - 2024, 6, 15, 9, 0, tzinfo=ZoneInfo("America/New_York") + 2024, 6, 15, 9, 0, 30, 500000, tzinfo=ZoneInfo("America/New_York") ), # 13:00 UTC ).save() @@ -245,7 +231,7 @@ class Event(Document): Event.drop_collection() # Naive datetime should raise validation error - naive_dt = datetime.datetime(2024, 6, 15, 14, 30) + naive_dt = datetime.datetime(2024, 6, 15, 14, 30, 30, 500000) event = Event(start_time=naive_dt) with pytest.raises(ValidationError): @@ -281,7 +267,7 @@ def test_default_value(self): """Test default values work correctly.""" def get_default_time(): - return datetime.datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC) + return datetime.datetime(2024, 1, 1, 0, 0, 30, 500000, tzinfo=UTC) class Event(Document): start_time = AwareDateTimeField(default=get_default_time) @@ -305,7 +291,7 @@ class Event(Document): # Create event in Kolkata timezone kolkata_time = datetime.datetime( - 2024, 6, 15, 14, 30, tzinfo=ZoneInfo("Asia/Kolkata") + 2024, 6, 15, 14, 30, 30, 500000, tzinfo=ZoneInfo("Asia/Kolkata") ) event = Event(start_time=kolkata_time) event.save() @@ -369,3 +355,98 @@ class Event(Document): assert desc_idx is not None assert desc_idx["key"][0] == ("start_time.utc", -1) + + def test_iso_field_stored_in_mongodb(self): + """Test that the iso field is stored alongside utc and tz.""" + + class Event(Document): + start_time = AwareDateTimeField(required=True) + + Event.drop_collection() + + dt = datetime.datetime( + 2024, 6, 15, 14, 30, 30, 500000, tzinfo=ZoneInfo("Asia/Kolkata") + ) + event = Event(start_time=dt) + event.save() + + raw = get_as_pymongo(event) + assert "iso" in raw["start_time"] + assert isinstance(raw["start_time"]["iso"], str) + assert datetime.datetime.fromisoformat(raw["start_time"]["iso"]) == dt + + def test_microsecond_precision_preserved_via_iso(self): + """Test that microseconds survive the MongoDB round-trip via the iso field.""" + + class Event(Document): + start_time = AwareDateTimeField(required=True) + + Event.drop_collection() + + dt = datetime.datetime( + 2024, 3, 10, 8, 45, 17, 987654, tzinfo=ZoneInfo("Europe/London") + ) + Event(start_time=dt).save() + + retrieved = Event.objects.first() + assert retrieved.start_time.second == 17 + assert retrieved.start_time.microsecond == 987654 + assert retrieved.start_time == dt + + def test_iso_field_queryable(self): + """Test that start_time__iso can be used in queries.""" + + class Event(Document): + name = StringField() + start_time = AwareDateTimeField(required=True) + + Event.drop_collection() + + dt = datetime.datetime( + 2024, 6, 15, 14, 30, 30, 500000, tzinfo=ZoneInfo("Asia/Kolkata") + ) + Event(name="kolkata", start_time=dt).save() + + iso_str = dt.isoformat() + result = Event.objects(start_time__iso=iso_str).first() + assert result is not None + assert result.name == "kolkata" + + def test_iso_field_contains_timezone_offset(self): + """Test that the stored iso string includes the UTC offset.""" + + class Event(Document): + start_time = AwareDateTimeField(required=True) + + Event.drop_collection() + + dt = datetime.datetime( + 2024, 6, 15, 14, 30, 30, 500000, tzinfo=ZoneInfo("Asia/Kolkata") + ) + event = Event(start_time=dt) + event.save() + + raw = get_as_pymongo(event) + iso_str = raw["start_time"]["iso"] + assert "+05:30" in iso_str or "05:30" in iso_str + + def test_half_hour_offset_precision(self): + """Test that UTC+5:30 (Kolkata) microseconds convert correctly to/from UTC.""" + + class Event(Document): + start_time = AwareDateTimeField(required=True) + + Event.drop_collection() + + dt = datetime.datetime( + 2024, 6, 15, 14, 30, 30, 500000, tzinfo=ZoneInfo("Asia/Kolkata") + ) + Event(start_time=dt).save() + + retrieved = Event.objects.first() + utc_dt = retrieved.start_time.astimezone(UTC) + + assert utc_dt.hour == 9 + assert utc_dt.minute == 0 + assert utc_dt.second == 30 + assert utc_dt.microsecond == 500000 diff --git a/tests/synchronous/fields/test_datetime_field.py b/tests/synchronous/fields/test_datetime_field.py index 70c9a78e6..e4d2bcf22 100644 --- a/tests/synchronous/fields/test_datetime_field.py +++ b/tests/synchronous/fields/test_datetime_field.py @@ -1,10 +1,10 @@ import datetime as dt +import unittest import pytest from mongoengine import * -from mongoengine.synchronous import connection -from tests.synchronous.utils import MongoDBTestCase, get_as_pymongo +from tests.synchronous.utils import MongoDBTestCase, get_as_pymongo, reset_connections from tests.utils import MONGO_TEST_DB try: @@ -12,6 +12,15 @@ except ImportError: dateutil = None +try: + # Python 3.11+ + from datetime import UTC +except ImportError: + # Python ≤ 3.10 + from datetime import timezone + + UTC = timezone.utc + class TestDateTimeField(MongoDBTestCase): def test_datetime_from_empty_string(self): @@ -46,9 +55,9 @@ def test_default_value_utcnow(self): """ class Person(Document): - created = DateTimeField(default=dt.datetime.utcnow) + created = DateTimeField(default=lambda: dt.datetime.now(UTC)) - utcnow = dt.datetime.utcnow() + utcnow = dt.datetime.now(UTC) person = Person() person.validate() person_created_t0 = person.created @@ -227,13 +236,14 @@ class DTDoc(Document): dtd.validate() -class TestDateTimeTzAware(MongoDBTestCase): - def test_datetime_tz_aware_mark_as_changed(self): - # Reset the connections - connection._connection_settings = {} - connection._connections = {} - connection._dbs = {} +class TestDateTimeTzAware(unittest.TestCase): + def setUp(self): + reset_connections() + + def tearDown(self): + reset_connections() + def test_datetime_tz_aware_mark_as_changed(self): connect(db=MONGO_TEST_DB, tz_aware=True) class LogEntry(Document): diff --git a/tests/synchronous/fields/test_fields.py b/tests/synchronous/fields/test_fields.py index 5300d6c6e..40fedfffc 100644 --- a/tests/synchronous/fields/test_fields.py +++ b/tests/synchronous/fields/test_fields.py @@ -35,6 +35,15 @@ from mongoengine.errors import DeprecatedError from tests.synchronous.utils import MongoDBTestCase +try: + # Python 3.11+ + from datetime import UTC +except ImportError: + # Python ≤ 3.10 + from datetime import timezone + + UTC = timezone.utc + class TestField(MongoDBTestCase): def test_constructor_set_historical_behavior_is_kept(self): @@ -86,7 +95,7 @@ class Person(Document): name = StringField() age = IntField(default=30, required=False) userid = StringField(default=lambda: "test", required=True) - created = DateTimeField(default=datetime.datetime.utcnow) + created = DateTimeField(default=lambda: datetime.datetime.now(UTC)) day = DateField(default=datetime.date.today) person = Person(name="Ross") @@ -166,7 +175,7 @@ class Person(Document): name = StringField() age = IntField(default=30, required=False) userid = StringField(default=lambda: "test", required=True) - created = DateTimeField(default=datetime.datetime.utcnow) + created = DateTimeField(default=lambda: datetime.datetime.now(UTC)) # Trying setting values to None person = Person(name=None, age=None, userid=None, created=None) @@ -200,7 +209,7 @@ class Person(Document): name = StringField() age = IntField(default=30, required=False) userid = StringField(default=lambda: "test", required=True) - created = DateTimeField(default=datetime.datetime.utcnow) + created = DateTimeField(default=lambda: datetime.datetime.now(UTC)) person = Person() person.name = None @@ -268,7 +277,7 @@ class Person(Document): name = StringField() age = IntField(default=30, required=False) userid = StringField(default=lambda: "test", required=True) - created = DateTimeField(default=datetime.datetime.utcnow) + created = DateTimeField(default=lambda: datetime.datetime.now(UTC)) person = Person( name="Ross", @@ -333,7 +342,7 @@ class HandleNoneFields(Document): doc.str_fld = "spam ham egg" doc.int_fld = 42 doc.flt_fld = 4.2 - doc.com_dt_fld = datetime.datetime.utcnow() + doc.com_dt_fld = datetime.datetime.now(UTC) doc.save() res = HandleNoneFields.objects(id=doc.id).update( @@ -369,7 +378,7 @@ class HandleNoneFields(Document): doc.str_fld = "spam ham egg" doc.int_fld = 42 doc.flt_fld = 4.2 - doc.comp_dt_fld = datetime.datetime.utcnow() + doc.comp_dt_fld = datetime.datetime.now(UTC) doc.save() # Unset all the fields diff --git a/tests/synchronous/fields/test_file_field.py b/tests/synchronous/fields/test_file_field.py index ee06a9f86..1b2834425 100644 --- a/tests/synchronous/fields/test_file_field.py +++ b/tests/synchronous/fields/test_file_field.py @@ -18,7 +18,7 @@ except ImportError: HAS_PIL = False -from tests.synchronous.utils import MongoDBTestCase, MONGO_TEST_DB +from tests.synchronous.utils import MONGO_TEST_DB, MongoDBTestCase require_pil = pytest.mark.skipif(not HAS_PIL, reason="PIL not installed") diff --git a/tests/synchronous/fields/test_sequence_field.py b/tests/synchronous/fields/test_sequence_field.py index 6932c3b58..ab1468255 100644 --- a/tests/synchronous/fields/test_sequence_field.py +++ b/tests/synchronous/fields/test_sequence_field.py @@ -237,10 +237,9 @@ class Bar(Base): foo.save() assert "base.counter" in self.db["mongoengine.counters"].find().distinct("_id") - assert not ( - ("foo.counter" or "bar.counter") - in self.db["mongoengine.counters"].find().distinct("_id") - ) + assert ("foo.counter" or "bar.counter") not in self.db[ + "mongoengine.counters" + ].find().distinct("_id") assert foo.counter != bar.counter assert foo._fields["counter"].owner_document == Base assert bar._fields["counter"].owner_document == Base diff --git a/tests/synchronous/fixtures.py b/tests/synchronous/fixtures.py index d22abe02c..c35fccbcc 100644 --- a/tests/synchronous/fixtures.py +++ b/tests/synchronous/fixtures.py @@ -2,7 +2,6 @@ from mongoengine import * from mongoengine import signals - from tests.fixtures import PickleEmbedded diff --git a/tests/synchronous/queryset/test_modify.py b/tests/synchronous/queryset/test_modify.py index a8da8da88..7dc7d737a 100644 --- a/tests/synchronous/queryset/test_modify.py +++ b/tests/synchronous/queryset/test_modify.py @@ -5,9 +5,8 @@ IntField, ListField, StringField, - connect, ) -from tests.utils import MONGO_TEST_DB +from tests.synchronous.utils import MongoDBTestCase class Doc(Document): @@ -15,9 +14,9 @@ class Doc(Document): value = IntField() -class TestFindAndModify(unittest.TestCase): +class TestFindAndModify(MongoDBTestCase): def setUp(self): - connect(db=MONGO_TEST_DB) + super().setUp() Doc.drop_collection() def _assert_db_equal(self, docs): diff --git a/tests/synchronous/queryset/test_queryset.py b/tests/synchronous/queryset/test_queryset.py index 2abadcdfa..4de9492db 100644 --- a/tests/synchronous/queryset/test_queryset.py +++ b/tests/synchronous/queryset/test_queryset.py @@ -11,30 +11,39 @@ from mongoengine import * from mongoengine.base import LazyReference -from mongoengine.registry import _CollectionRegistry -from mongoengine.synchronous import QuerySet, QuerySetNoCache -from mongoengine.synchronous.connection import get_db +from mongoengine.base.queryset import ( + CASCADE, + DENY, + NULLIFY, + PULL, + QuerySetManager, + queryset_manager, +) from mongoengine.context_managers import query_counter, switch_db from mongoengine.errors import InvalidQueryError from mongoengine.mongodb_support import ( get_mongodb_version, ) from mongoengine.pymongo_support import PYMONGO_VERSION -from mongoengine.base.queryset import ( - QuerySetManager, - queryset_manager, - NULLIFY, - CASCADE, - DENY, - PULL, -) +from mongoengine.registry import _CollectionRegistry +from mongoengine.synchronous import QuerySet, QuerySetNoCache +from mongoengine.synchronous.connection import get_db from mongoengine.synchronous.queryset.base import BaseQuerySet from tests.synchronous.utils import db_ops_tracker, get_as_pymongo, reset_connections from tests.utils import ( + MONGO_TEST_DB, requires_mongodb_gte_42, requires_mongodb_gte_44, ) -from tests.utils import MONGO_TEST_DB + +try: + # Python 3.11+ + from datetime import UTC +except ImportError: + # Python ≤ 3.10 + from datetime import timezone + + UTC = timezone.utc def get_key_compat(mongo_ver): @@ -45,6 +54,8 @@ def get_key_compat(mongo_ver): class TestQueryset(unittest.TestCase): def setUp(self): + reset_connections() + _CollectionRegistry.clear() connect(db=MONGO_TEST_DB) connect(db=f"{MONGO_TEST_DB}_2", alias="test2") @@ -63,11 +74,13 @@ class Person(Document): self.mongodb_version = get_mongodb_version() - async def tearDown(self): - disconnect(alias="default") - disconnect(alias="test2") - reset_connections() - _CollectionRegistry.clear() + def tearDown(self): + try: + disconnect(alias="default") + disconnect(alias="test2") + finally: + reset_connections() + _CollectionRegistry.clear() def test_initialisation(self): """Ensure that a QuerySet is correctly initialised by QuerySetManager.""" @@ -1559,7 +1572,7 @@ class BlogPost(Document): meta = {"ordering": ["-published_date"]} BlogPost.objects.create( - title="whatever", published_date=datetime.datetime.utcnow() + title="whatever", published_date=datetime.datetime.now(UTC) ) with db_ops_tracker() as q: @@ -3024,7 +3037,7 @@ class Link(Document): Link.drop_collection() - now = datetime.datetime.utcnow() + now = datetime.datetime.now(UTC) # Note: Test data taken from a custom Reddit homepage on # Fri, 12 Feb 2010 14:36:00 -0600. Link ordering should @@ -3942,9 +3955,6 @@ class BlogPost(Document): BlogPost.drop_collection() - def tearDown(self): - self.Person.drop_collection() - def test_custom_querysets(self): """Ensure that custom QuerySet classes may be used.""" diff --git a/tests/synchronous/queryset/test_visitor.py b/tests/synchronous/queryset/test_visitor.py index 5b7593275..aa4cc8f26 100644 --- a/tests/synchronous/queryset/test_visitor.py +++ b/tests/synchronous/queryset/test_visitor.py @@ -6,14 +6,15 @@ from bson import ObjectId from mongoengine import * -from mongoengine.errors import InvalidQueryError from mongoengine.base.queryset import Q +from mongoengine.errors import InvalidQueryError +from tests.synchronous.utils import MongoDBTestCase from tests.utils import MONGO_TEST_DB -class TestQ(unittest.TestCase): +class TestQ(MongoDBTestCase): def setUp(self): - connect(db=MONGO_TEST_DB) + super().setUp() class Person(Document): name = StringField() diff --git a/tests/synchronous/test_connection.py b/tests/synchronous/test_connection.py index beb6fe2dd..af20c0244 100644 --- a/tests/synchronous/test_connection.py +++ b/tests/synchronous/test_connection.py @@ -2,20 +2,19 @@ import unittest import uuid -import pymongo -import pymongo.database -import pymongo.mongo_client -import pytest from bson.tz_util import utc +import pymongo from pymongo import MongoClient, ReadPreference +import pymongo.database from pymongo.errors import ( InvalidName, InvalidOperation, OperationFailure, ) +import pymongo.mongo_client from pymongo.read_preferences import Secondary +import pytest -import mongoengine.synchronous.connection from mongoengine import ( DateTimeField, Document, @@ -25,7 +24,9 @@ register_connection, ) from mongoengine.base import _DocumentRegistry +from mongoengine.pymongo_support import PYMONGO_VERSION from mongoengine.registry import _CollectionRegistry +import mongoengine.synchronous.connection from mongoengine.synchronous.connection import ( ConnectionFailure, _get_connection_settings, @@ -33,7 +34,7 @@ get_connection, get_db, ) -from mongoengine.pymongo_support import PYMONGO_VERSION +from tests.synchronous.utils import reset_connections from tests.utils import MONGO_TEST_DB @@ -46,20 +47,16 @@ def get_tz_awareness(connection): class ConnectionTest(unittest.TestCase): - @classmethod - def setUpClass(cls): - disconnect_all() - - @classmethod - def tearDownClass(cls): - disconnect_all() + def setUp(self): + reset_connections() def tearDown(self): - mongoengine.synchronous.connection._connection_settings = {} - mongoengine.synchronous.connection._connections = {} - mongoengine.synchronous.connection._dbs = {} - _DocumentRegistry.clear() - _CollectionRegistry.clear() + try: + disconnect_all() + finally: + reset_connections() + _DocumentRegistry.clear() + _CollectionRegistry.clear() def test_connect(self): """Ensure that the connect() method works properly.""" @@ -331,6 +328,8 @@ class User(Document): db2_users = list(client[db2].user.find()) assert db2_users == [{"_id": user2.id, "name": "Bob is in db2"}] + client.close() + def test_disconnect_silently_pass_if_alias_does_not_exist(self): connections = mongoengine.synchronous.connection._connections assert len(connections) == 0 diff --git a/tests/synchronous/test_context_managers.py b/tests/synchronous/test_context_managers.py index b95b007dc..77ce9ecda 100644 --- a/tests/synchronous/test_context_managers.py +++ b/tests/synchronous/test_context_managers.py @@ -9,8 +9,6 @@ from pymongo.read_concern import ReadConcern from mongoengine import * -from mongoengine.session import _get_session -from mongoengine.synchronous.connection import get_db from mongoengine.context_managers import ( no_sub_classes, query_counter, @@ -21,8 +19,10 @@ switch_db, ) from mongoengine.pymongo_support import count_documents +from mongoengine.session import _get_session +from mongoengine.synchronous.connection import get_db from tests.synchronous.utils import MongoDBTestCase -from tests.utils import requires_mongodb_gte_44, MONGO_TEST_DB +from tests.utils import MONGO_TEST_DB, requires_mongodb_gte_44 class TestRollbackError(Exception): diff --git a/tests/synchronous/test_dereference.py b/tests/synchronous/test_dereference.py index 6ff87af10..49eaff83a 100644 --- a/tests/synchronous/test_dereference.py +++ b/tests/synchronous/test_dereference.py @@ -4,18 +4,11 @@ from mongoengine import * from mongoengine.context_managers import query_counter +from tests.synchronous.utils import MongoDBTestCase from tests.utils import MONGO_TEST_DB -class FieldTest(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.db = connect(db=MONGO_TEST_DB) - - @classmethod - def tearDownClass(cls): - cls.db.drop_database(MONGO_TEST_DB) - +class FieldTest(MongoDBTestCase): def test_list_item_dereference(self): """Ensure that DBRef items in ListFields are dereferenced.""" diff --git a/tests/synchronous/test_replicaset_connection.py b/tests/synchronous/test_replicaset_connection.py index bd2e19034..56efcb54f 100644 --- a/tests/synchronous/test_replicaset_connection.py +++ b/tests/synchronous/test_replicaset_connection.py @@ -4,6 +4,7 @@ import mongoengine from mongoengine.synchronous.connection import ConnectionFailure +from tests.synchronous.utils import reset_connections from tests.utils import MONGO_TEST_DB CONN_CLASS = MongoClient @@ -12,14 +13,10 @@ class ConnectionTest(unittest.TestCase): def setUp(self): - mongoengine.synchronous.connection._connection_settings = {} - mongoengine.synchronous.connection._connections = {} - mongoengine.synchronous.connection._dbs = {} + reset_connections() def tearDown(self): - mongoengine.synchronous.connection._connection_settings = {} - mongoengine.synchronous.connection._connections = {} - mongoengine.synchronous.connection._dbs = {} + reset_connections() def test_replicaset_uri_passes_read_preference(self): """Requires a replica set called "rs" on port 27017""" @@ -33,8 +30,7 @@ def test_replicaset_uri_passes_read_preference(self): return if not isinstance(conn, CONN_CLASS): - # really??? - return + raise TypeError(f"Expected {CONN_CLASS}, got {type(conn)}") assert conn.read_preference == READ_PREF diff --git a/tests/synchronous/test_signals.py b/tests/synchronous/test_signals.py index d917626a4..64880705c 100644 --- a/tests/synchronous/test_signals.py +++ b/tests/synchronous/test_signals.py @@ -4,6 +4,7 @@ from mongoengine import signals from mongoengine.base import _DocumentRegistry from mongoengine.registry import _CollectionRegistry +from tests.synchronous.utils import reset_connections from tests.utils import MONGO_TEST_DB signal_output = [] @@ -23,6 +24,7 @@ def get_signal_output(fn, *args, **kwargs): return signal_output def setUp(self): + reset_connections() connect(db=MONGO_TEST_DB) class Author(Document): @@ -211,47 +213,50 @@ def post_bulk_insert(cls, sender, documents, **kwargs): signals.post_bulk_insert.connect(Post.post_bulk_insert, sender=Post) def tearDown(self): - signals.pre_init.disconnect(self.Author.pre_init) - signals.post_init.disconnect(self.Author.post_init) - signals.post_delete.disconnect(self.Author.post_delete) - signals.pre_delete.disconnect(self.Author.pre_delete) - signals.post_save.disconnect(self.Author.post_save) - signals.pre_save_post_validation.disconnect( - self.Author.pre_save_post_validation - ) - signals.pre_save.disconnect(self.Author.pre_save) - signals.pre_bulk_insert.disconnect(self.Author.pre_bulk_insert) - signals.post_bulk_insert.disconnect(self.Author.post_bulk_insert) - - signals.post_delete.disconnect(self.Another.post_delete) - signals.pre_delete.disconnect(self.Another.pre_delete) - - signals.post_save.disconnect(self.ExplicitId.post_save) - - signals.pre_bulk_insert.disconnect(self.Post.pre_bulk_insert) - signals.post_bulk_insert.disconnect(self.Post.post_bulk_insert) - - # Check that all our signals got disconnected properly. - post_signals = ( - len(signals.pre_init.receivers), - len(signals.post_init.receivers), - len(signals.pre_save.receivers), - len(signals.pre_save_post_validation.receivers), - len(signals.post_save.receivers), - len(signals.pre_delete.receivers), - len(signals.post_delete.receivers), - len(signals.pre_bulk_insert.receivers), - len(signals.post_bulk_insert.receivers), - ) - - self.ExplicitId.objects.delete() - - # Note that there is a chance that the following assert fails in case - # some receivers (eventually created in other tests) - # gets garbage collected (https://pythonhosted.org/blinker/#blinker.base.Signal.connect) - assert self.pre_signals == post_signals - _DocumentRegistry.clear() - _CollectionRegistry.clear() + try: + signals.pre_init.disconnect(self.Author.pre_init) + signals.post_init.disconnect(self.Author.post_init) + signals.post_delete.disconnect(self.Author.post_delete) + signals.pre_delete.disconnect(self.Author.pre_delete) + signals.post_save.disconnect(self.Author.post_save) + signals.pre_save_post_validation.disconnect( + self.Author.pre_save_post_validation + ) + signals.pre_save.disconnect(self.Author.pre_save) + signals.pre_bulk_insert.disconnect(self.Author.pre_bulk_insert) + signals.post_bulk_insert.disconnect(self.Author.post_bulk_insert) + + signals.post_delete.disconnect(self.Another.post_delete) + signals.pre_delete.disconnect(self.Another.pre_delete) + + signals.post_save.disconnect(self.ExplicitId.post_save) + + signals.pre_bulk_insert.disconnect(self.Post.pre_bulk_insert) + signals.post_bulk_insert.disconnect(self.Post.post_bulk_insert) + + # Check that all our signals got disconnected properly. + post_signals = ( + len(signals.pre_init.receivers), + len(signals.post_init.receivers), + len(signals.pre_save.receivers), + len(signals.pre_save_post_validation.receivers), + len(signals.post_save.receivers), + len(signals.pre_delete.receivers), + len(signals.post_delete.receivers), + len(signals.pre_bulk_insert.receivers), + len(signals.post_bulk_insert.receivers), + ) + + self.ExplicitId.objects.delete() + + # Note that there is a chance that the following assert fails in case + # some receivers (eventually created in other tests) + # gets garbage collected (https://pythonhosted.org/blinker/#blinker.base.Signal.connect) + assert self.pre_signals == post_signals + finally: + reset_connections() + _DocumentRegistry.clear() + _CollectionRegistry.clear() def test_model_signals(self): """Model saves should throw some signals.""" diff --git a/tests/synchronous/utils.py b/tests/synchronous/utils.py index b0e8f3b85..2dc636f36 100644 --- a/tests/synchronous/utils.py +++ b/tests/synchronous/utils.py @@ -7,11 +7,10 @@ from mongoengine import connect from mongoengine.base import _DocumentRegistry +from mongoengine.context_managers import query_counter +from mongoengine.mongodb_support import async_get_mongodb_version, get_mongodb_version from mongoengine.registry import _CollectionRegistry from mongoengine.synchronous.connection import disconnect_all, get_db -from mongoengine.context_managers import query_counter -from mongoengine.mongodb_support import get_mongodb_version, async_get_mongodb_version - from tests.utils import MONGO_TEST_DB @@ -21,14 +20,29 @@ class MongoDBTestCase(unittest.TestCase): """ def setUp(self): - disconnect_all() + # 1. Clear out everything from previous runs + reset_connections() + _DocumentRegistry.clear() + _CollectionRegistry.clear() + + # 2. Establish the fresh connection self._connection = connect(db=MONGO_TEST_DB) self._connection.drop_database(MONGO_TEST_DB) self.db = get_db() def tearDown(self): - self._connection.drop_database(MONGO_TEST_DB) + # 1. Grab the connection safely (handles cases where setup failed mid-way) + conn = getattr(self, "_connection", None) + if conn: + conn.drop_database(MONGO_TEST_DB) + + # 2. Break references immediately so Python collects them + self._connection = None + self.db = None + + # 3. Purge registries and disconnect disconnect_all() + reset_connections() _DocumentRegistry.clear() _CollectionRegistry.clear() @@ -108,8 +122,8 @@ def get_ops(self): def reset_connections(): from mongoengine.synchronous.connection import ( - _connections, _connection_settings, + _connections, _dbs, ) diff --git a/tests/test_pipeline_builder.py b/tests/test_pipeline_builder.py index c74f5e7f0..58d81fad9 100644 --- a/tests/test_pipeline_builder.py +++ b/tests/test_pipeline_builder.py @@ -1,27 +1,22 @@ from mongoengine import ( + DictField, Document, EmbeddedDocument, EmbeddedDocumentField, EmbeddedDocumentListField, + GenericReferenceField, IntField, - StringField, - ReferenceField, ListField, - DictField, MapField, - GenericReferenceField, + ReferenceField, + StringField, ) -from mongoengine.base import _DocumentRegistry from mongoengine.base.queryset.pipeline_builder import PipelineBuilder from mongoengine.base.queryset.pipeline_builder.schema import Schema - from tests.asynchronous.utils import MongoDBAsyncTestCase class TestQuerysetPipelineBuilderStress(MongoDBAsyncTestCase): - def tearDown(self): - _DocumentRegistry.clear() - def test_reference_field_attribute_match(self): class Parent(Document): age = IntField(required=True) diff --git a/tests/utils.py b/tests/utils.py index 618162068..5e2471ae8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -6,7 +6,7 @@ import pymongo import pytest -from mongoengine.mongodb_support import get_mongodb_version, async_get_mongodb_version +from mongoengine.mongodb_support import async_get_mongodb_version, get_mongodb_version PYMONGO_VERSION = tuple(pymongo.version_tuple[:2]) diff --git a/uv.lock b/uv.lock index 4da7d6305..c38cd1788 100644 --- a/uv.lock +++ b/uv.lock @@ -506,6 +506,7 @@ docs = [ test = [ { name = "blinker" }, { name = "coverage" }, + { name = "mongomock" }, { name = "pillow" }, { name = "pytest" }, { name = "pytest-asyncio" }, @@ -534,6 +535,7 @@ docs = [ test = [ { name = "blinker", specifier = ">=1.9" }, { name = "coverage", specifier = ">=7.14" }, + { name = "mongomock", specifier = ">=4.3.0" }, { name = "pillow", specifier = ">=12.2" }, { name = "pytest", specifier = ">=9.0.3" }, { name = "pytest-asyncio", specifier = ">=1.3" }, @@ -542,6 +544,20 @@ test = [ { name = "tox-uv", specifier = ">=1.35.2" }, ] +[[package]] +name = "mongomock" +version = "4.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging" }, + { name = "pytz" }, + { name = "sentinels" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4d/a4/4a560a9f2a0bec43d5f63104f55bc48666d619ca74825c8ae156b08547cf/mongomock-4.3.0.tar.gz", hash = "sha256:32667b79066fabc12d4f17f16a8fd7361b5f4435208b3ba32c226e52212a8c30", size = 135862, upload-time = "2024-11-16T11:23:25.957Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/4d/8bea712978e3aff017a2ab50f262c620e9239cc36f348aae45e48d6a4786/mongomock-4.3.0-py2.py3-none-any.whl", hash = "sha256:5ef86bd12fc8806c6e7af32f21266c61b6c4ba96096f85129852d1c4fec1327e", size = 64891, upload-time = "2024-11-16T11:23:24.748Z" }, +] + [[package]] name = "nodeenv" version = "1.10.0" @@ -844,6 +860,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/6f/a05a317a66fee0aad270011461f1a63a453ed12471249f172f7d2e2bc7b4/python_discovery-1.3.1-py3-none-any.whl", hash = "sha256:ed188687ebb3b82c01a17cd5ac62fc94d9f6487a7f1a0f9dfe89753fec91039c", size = 33185, upload-time = "2026-05-12T20:53:34.969Z" }, ] +[[package]] +name = "pytz" +version = "2026.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ff/46/dd499ec9038423421951e4fad73051febaa13d2df82b4064f87af8b8c0c3/pytz-2026.2.tar.gz", hash = "sha256:0e60b47b29f21574376f218fe21abc009894a2321ea16c6754f3cad6eb7cdd6a", size = 320861, upload-time = "2026-05-04T01:35:29.667Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ec/dd/96da98f892250475bdf2328112d7468abdd4acc7b902b6af23f4ed958ea0/pytz-2026.2-py2.py3-none-any.whl", hash = "sha256:04156e608bee23d3792fd45c94ae47fae1036688e75032eea2e3bf0323d1f126", size = 510141, upload-time = "2026-05-04T01:35:27.408Z" }, +] + [[package]] name = "pyyaml" version = "6.0.3" @@ -971,6 +996,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/62/d5/bc97ff895ec35cf3925d4bd60f3b39d822f377a446906ec9bcc87405e59b/ruff-0.15.14-py3-none-win_arm64.whl", hash = "sha256:ff47b90a9ef6a40c9e2f3b479c1fb78531adf055b94c1eba0a7ba04b31951826", size = 11208607, upload-time = "2026-05-21T14:34:26.525Z" }, ] +[[package]] +name = "sentinels" +version = "1.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6f/9b/07195878aa25fe6ed209ec74bc55ae3e3d263b60a489c6e73fdca3c8fe05/sentinels-1.1.1.tar.gz", hash = "sha256:3c2f64f754187c19e0a1a029b148b74cf58dd12ec27b4e19c0e5d6e22b5a9a86", size = 4393, upload-time = "2025-08-12T07:57:50.26Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/49/65/dea992c6a97074f6d8ff9eab34741298cac2ce23e2b6c74fb7d08afdf85c/sentinels-1.1.1-py3-none-any.whl", hash = "sha256:835d3b28f3b47f5284afa4bf2db6e00f2dc5f80f9923d4b7e7aeeeccf6146a11", size = 3744, upload-time = "2025-08-12T07:57:48.858Z" }, +] + [[package]] name = "snowballstemmer" version = "3.0.1"