diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index d033f1f234..8849e7715a 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -17,7 +17,6 @@ from contextlib import asynccontextmanager import copy from datetime import datetime -from datetime import timezone import logging from typing import Any from typing import AsyncIterator @@ -59,6 +58,7 @@ from .schemas.v1 import StorageMetadata from .schemas.v1 import StorageSession as StorageSessionV1 from .schemas.v1 import StorageUserState as StorageUserStateV1 +from .schemas.shared import update_time_from_timestamp from .session import Session from .state import State @@ -458,11 +458,10 @@ async def create_session( storage_user_state.state = storage_user_state.state | user_state_delta # Store the session - now = datetime.fromtimestamp(platform_time.get_time(), tz=timezone.utc) - is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT - is_postgresql = self.db_engine.dialect.name == _POSTGRESQL_DIALECT - if is_sqlite or is_postgresql: - now = now.replace(tzinfo=None) + dialect_name = self.db_engine.dialect.name + now = update_time_from_timestamp( + platform_time.get_time(), dialect_name + ) storage_session = schema.StorageSession( app_name=app_name, @@ -480,7 +479,7 @@ async def create_session( storage_app_state.state, storage_user_state.state, session_state ) session = storage_session.to_session( - state=merged_state, is_sqlite=is_sqlite + state=merged_state, dialect_name=dialect_name ) return session @@ -498,6 +497,7 @@ async def get_session( # 2. Get all the events based on session id and filtering config # 3. Convert and return the session schema = self._get_schema_classes() + dialect_name = self.db_engine.dialect.name async with self._rollback_on_exception_session( read_only=True ) as sql_session: @@ -543,9 +543,10 @@ async def get_session( # Convert storage session to session events = [e.to_event() for e in reversed(storage_events)] - is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT session = storage_session.to_session( - state=merged_state, events=events, is_sqlite=is_sqlite + state=merged_state, + events=events, + dialect_name=dialect_name, ) return session @@ -591,13 +592,16 @@ async def list_sessions( user_states_map[storage_user_state.user_id] = storage_user_state.state sessions = [] - is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT + dialect_name = self.db_engine.dialect.name for storage_session in results: session_state = storage_session.state user_state = user_states_map.get(storage_session.user_id, {}) merged_state = _merge_state(app_state, user_state, session_state) sessions.append( - storage_session.to_session(state=merged_state, is_sqlite=is_sqlite) + storage_session.to_session( + state=merged_state, + dialect_name=dialect_name, + ) ) return ListSessionsResponse(sessions=sessions) @@ -632,7 +636,7 @@ async def append_event(self, session: Session, event: Event) -> Event: # 2. Update session attributes based on event config. # 3. Store the new event. schema = self._get_schema_classes() - is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT + dialect_name = self.db_engine.dialect.name use_row_level_locking = self._supports_row_level_locking() state_delta = ( @@ -662,7 +666,7 @@ async def append_event(self, session: Session, event: Event) -> Event: storage_session = storage_session_result.scalars().one_or_none() if storage_session is None: raise ValueError(f"Session {session.id} not found.") - storage_update_time = storage_session.get_update_timestamp(is_sqlite) + storage_update_time = storage_session.get_update_timestamp(dialect_name) storage_update_marker = storage_session.get_update_marker() storage_app_state = await _select_required_state( @@ -728,20 +732,16 @@ async def append_event(self, session: Session, event: Event) -> Event: storage_session.state | state_deltas["session"] ) - if is_sqlite: - update_time = datetime.fromtimestamp( - event.timestamp, timezone.utc - ).replace(tzinfo=None) - else: - update_time = datetime.fromtimestamp(event.timestamp) - storage_session.update_time = update_time + storage_session.update_time = update_time_from_timestamp( + event.timestamp, dialect_name + ) sql_session.add(schema.StorageEvent.from_event(session, event)) await sql_session.commit() # Update timestamp with commit time session.last_update_time = storage_session.get_update_timestamp( - is_sqlite + dialect_name ) session._storage_update_marker = storage_session.get_update_marker() diff --git a/src/google/adk/sessions/schemas/shared.py b/src/google/adk/sessions/schemas/shared.py index 25d4ea9e95..33e7c6fa3f 100644 --- a/src/google/adk/sessions/schemas/shared.py +++ b/src/google/adk/sessions/schemas/shared.py @@ -13,6 +13,8 @@ # limitations under the License. from __future__ import annotations +from datetime import datetime +from datetime import timezone import json from sqlalchemy import Dialect @@ -25,6 +27,33 @@ DEFAULT_MAX_KEY_LENGTH = 128 DEFAULT_MAX_VARCHAR_LENGTH = 256 +# Dialects that store TIMESTAMP values as UTC-naive datetimes and therefore +# require us to reattach UTC tzinfo on read and strip it on write. +_NAIVE_UTC_DIALECTS = frozenset({"sqlite", "postgresql"}) + + +def update_timestamp_from_dt(dt: datetime, dialect_name: str) -> float: + """Converts a DB-returned datetime to a POSIX timestamp. + + SQLite and PostgreSQL store naive datetimes that represent UTC values. + All other dialects return timezone-aware datetimes directly. + """ + if dialect_name in _NAIVE_UTC_DIALECTS: + return dt.replace(tzinfo=timezone.utc).timestamp() + return dt.timestamp() + + +def update_time_from_timestamp(posix_ts: float, dialect_name: str) -> datetime: + """Converts a POSIX timestamp to the datetime format expected by the DB. + + SQLite and PostgreSQL require a UTC-naive datetime; every other dialect + accepts (and prefers) a UTC-aware datetime. + """ + dt = datetime.fromtimestamp(posix_ts, timezone.utc) + if dialect_name in _NAIVE_UTC_DIALECTS: + return dt.replace(tzinfo=None) + return dt + class DynamicJSON(TypeDecorator): """A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON serialization for other databases.""" diff --git a/src/google/adk/sessions/schemas/v0.py b/src/google/adk/sessions/schemas/v0.py index e4a4368c6d..6718083b02 100644 --- a/src/google/adk/sessions/schemas/v0.py +++ b/src/google/adk/sessions/schemas/v0.py @@ -61,6 +61,7 @@ from .shared import DEFAULT_MAX_VARCHAR_LENGTH from .shared import DynamicJSON from .shared import PreciseTimestamp +from .shared import update_timestamp_from_dt logger = logging.getLogger("google_adk." + __name__) @@ -167,21 +168,16 @@ def update_timestamp_tz(self) -> float: This is a compatibility alias for callers that used the pre-`main` API. """ sqlalchemy_session = inspect(self).session - is_sqlite = bool( - sqlalchemy_session - and sqlalchemy_session.bind - and sqlalchemy_session.bind.dialect.name == "sqlite" + dialect_name = ( + sqlalchemy_session.bind.dialect.name + if sqlalchemy_session and sqlalchemy_session.bind + else None ) - return self.get_update_timestamp(is_sqlite=is_sqlite) + return self.get_update_timestamp(dialect_name) - def get_update_timestamp(self, is_sqlite: bool) -> float: - """Returns the time zone aware update timestamp.""" - if is_sqlite: - # SQLite does not support timezone. SQLAlchemy returns a naive datetime - # object without timezone information. We need to convert it to UTC - # manually. - return self.update_time.replace(tzinfo=timezone.utc).timestamp() - return self.update_time.timestamp() + def get_update_timestamp(self, dialect_name: str | None) -> float: + """Returns the update timestamp as a POSIX timestamp.""" + return update_timestamp_from_dt(self.update_time, dialect_name or "") def get_update_marker(self) -> str: """Returns a stable revision marker for optimistic concurrency checks.""" @@ -194,7 +190,7 @@ def to_session( self, state: dict[str, Any] | None = None, events: list[Event] | None = None, - is_sqlite: bool = False, + dialect_name: str | None = None, ) -> Session: """Converts the storage session to a session object.""" if state is None: @@ -208,7 +204,7 @@ def to_session( id=self.id, state=state, events=events, - last_update_time=self.get_update_timestamp(is_sqlite=is_sqlite), + last_update_time=self.get_update_timestamp(dialect_name), ) session._storage_update_marker = self.get_update_marker() return session diff --git a/src/google/adk/sessions/schemas/v1.py b/src/google/adk/sessions/schemas/v1.py index 12d8ee9061..6b0242cbd3 100644 --- a/src/google/adk/sessions/schemas/v1.py +++ b/src/google/adk/sessions/schemas/v1.py @@ -46,6 +46,7 @@ from .shared import DEFAULT_MAX_VARCHAR_LENGTH from .shared import DynamicJSON from .shared import PreciseTimestamp +from .shared import update_timestamp_from_dt class Base(DeclarativeBase): @@ -114,21 +115,16 @@ def update_timestamp_tz(self) -> float: This is a compatibility alias for callers that used the pre-`main` API. """ sqlalchemy_session = inspect(self).session - is_sqlite = bool( - sqlalchemy_session - and sqlalchemy_session.bind - and sqlalchemy_session.bind.dialect.name == "sqlite" + dialect_name = ( + sqlalchemy_session.bind.dialect.name + if sqlalchemy_session and sqlalchemy_session.bind + else None ) - return self.get_update_timestamp(is_sqlite=is_sqlite) + return self.get_update_timestamp(dialect_name) - def get_update_timestamp(self, is_sqlite: bool) -> float: - """Returns the time zone aware update timestamp.""" - if is_sqlite: - # SQLite does not support timezone. SQLAlchemy returns a naive datetime - # object without timezone information. We need to convert it to UTC - # manually. - return self.update_time.replace(tzinfo=timezone.utc).timestamp() - return self.update_time.timestamp() + def get_update_timestamp(self, dialect_name: str | None) -> float: + """Returns the update timestamp as a POSIX timestamp.""" + return update_timestamp_from_dt(self.update_time, dialect_name or "") def get_update_marker(self) -> str: """Returns a stable revision marker for optimistic concurrency checks.""" @@ -141,7 +137,7 @@ def to_session( self, state: dict[str, Any] | None = None, events: list[Event] | None = None, - is_sqlite: bool = False, + dialect_name: str | None = None, ) -> Session: """Converts the storage session to a session object.""" if state is None: @@ -155,7 +151,7 @@ def to_session( id=self.id, state=state, events=events, - last_update_time=self.get_update_timestamp(is_sqlite=is_sqlite), + last_update_time=self.get_update_timestamp(dialect_name), ) session._storage_update_marker = self.get_update_marker() return session diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 2d7d89f15f..3e6f9504d7 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -29,6 +29,8 @@ from google.adk.sessions.base_session_service import GetSessionConfig from google.adk.sessions.database_session_service import DatabaseSessionService from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.sessions.schemas.shared import update_time_from_timestamp +from google.adk.sessions.schemas.shared import update_timestamp_from_dt from google.adk.sessions.sqlite_session_service import SqliteSessionService from google.genai import types import pytest @@ -103,44 +105,63 @@ def fake_create_async_engine(_db_url: str, **kwargs): @pytest.mark.parametrize('dialect_name', ['sqlite', 'postgresql']) -def test_database_session_service_strips_timezone_for_dialect(dialect_name): - """Verifies that timezone-aware datetimes are converted to naive datetimes - for SQLite and PostgreSQL to avoid 'can't subtract offset-naive and - offset-aware datetimes' errors. - - PostgreSQL's default TIMESTAMP type is WITHOUT TIME ZONE, which cannot - accept timezone-aware datetime objects when using asyncpg. SQLite also - requires naive datetimes. - """ - # Simulate the logic in create_session - is_sqlite = dialect_name == 'sqlite' - is_postgres = dialect_name == 'postgresql' +def test_update_time_from_timestamp_strips_timezone_for_naive_utc_dialects( + dialect_name, +): + """update_time_from_timestamp returns a UTC-naive datetime for SQLite and + PostgreSQL, which store TIMESTAMP WITHOUT TIME ZONE values.""" + posix_ts = 1_700_000_000.0 + result = update_time_from_timestamp(posix_ts, dialect_name) + assert result.tzinfo is None + # Value must represent the correct UTC instant. + assert result == datetime.fromtimestamp(posix_ts, timezone.utc).replace( + tzinfo=None + ) + + +def test_update_time_from_timestamp_preserves_timezone_for_other_dialects(): + """update_time_from_timestamp returns a UTC-aware datetime for dialects + that support TIMESTAMP WITH TIME ZONE (e.g. MySQL).""" + posix_ts = 1_700_000_000.0 + result = update_time_from_timestamp(posix_ts, 'mysql') + assert result.tzinfo is not None + assert result == datetime.fromtimestamp(posix_ts, timezone.utc) - now = datetime.now(timezone.utc) - assert now.tzinfo is not None # Starts with timezone - if is_sqlite or is_postgres: - now = now.replace(tzinfo=None) +@pytest.mark.parametrize('dialect_name', ['sqlite', 'postgresql']) +def test_update_timestamp_from_dt_treats_naive_dt_as_utc_for_naive_utc_dialects( + dialect_name, +): + """update_timestamp_from_dt must reattach UTC tzinfo before computing the + POSIX timestamp for SQLite and PostgreSQL. + + This is the core of the bug fixed in commit 0e5790805a2f4d: + PostgreSQL returns a UTC-naive datetime, so calling .timestamp() directly + on a non-UTC host would interpret it as local time and produce a wrong + POSIX value. + """ + posix_ts = 1_700_000_000.0 + # Simulate a naive datetime as returned by PostgreSQL / SQLite. + naive_utc_dt = datetime.fromtimestamp(posix_ts, timezone.utc).replace( + tzinfo=None + ) + assert naive_utc_dt.tzinfo is None - # Both SQLite and PostgreSQL should have timezone stripped - assert now.tzinfo is None + result = update_timestamp_from_dt(naive_utc_dt, dialect_name) + assert result == posix_ts -def test_database_session_service_preserves_timezone_for_other_dialects(): - """Verifies that timezone info is preserved for dialects that support it.""" - # For dialects like MySQL with explicit timezone support, we don't strip - dialect_name = 'mysql' - is_sqlite = dialect_name == 'sqlite' - is_postgres = dialect_name == 'postgresql' - now = datetime.now(timezone.utc) - assert now.tzinfo is not None +def test_update_timestamp_from_dt_uses_tzinfo_for_aware_dialects(): + """update_timestamp_from_dt uses the datetime's own tzinfo for dialects + that return timezone-aware datetimes (e.g. MySQL).""" + posix_ts = 1_700_000_000.0 + aware_dt = datetime.fromtimestamp(posix_ts, timezone.utc) + assert aware_dt.tzinfo is not None - if is_sqlite or is_postgres: - now = now.replace(tzinfo=None) + result = update_timestamp_from_dt(aware_dt, 'mysql') - # MySQL should preserve timezone (if the column type supports it) - assert now.tzinfo is not None + assert result == posix_ts def test_database_session_service_respects_pool_pre_ping_override():