diff --git a/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py b/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py index a6d1ad2a78..ec834da52d 100644 --- a/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py +++ b/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py @@ -29,6 +29,7 @@ from google.adk.sessions import _session_util from google.adk.sessions.migration import _schema_check_utils from google.adk.sessions.schemas import v1 +from google.adk.sessions.schemas._safe_unpickle import safe_loads as _safe_pickle_loads from google.genai import types import sqlalchemy from sqlalchemy import create_engine @@ -59,7 +60,7 @@ def _row_to_event(row: dict) -> Event: if actions_val is not None: try: if isinstance(actions_val, bytes): - actions = pickle.loads(actions_val) + actions = _safe_pickle_loads(actions_val) else: # for spanner - it might return object directly actions = actions_val except Exception as e: diff --git a/src/google/adk/sessions/schemas/_safe_unpickle.py b/src/google/adk/sessions/schemas/_safe_unpickle.py new file mode 100644 index 0000000000..a61fcfb5b6 --- /dev/null +++ b/src/google/adk/sessions/schemas/_safe_unpickle.py @@ -0,0 +1,101 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Restricted unpickler for safe deserialization of v0 EventActions data. + +The v0 schema stored EventActions as pickle blobs. This module provides a +safe deserialization path that only allows known ADK and standard types, +blocking arbitrary code execution via crafted pickle payloads. + +See: https://docs.python.org/3/library/pickle.html#restricting-globals +""" + +from __future__ import annotations + +import io +import logging +import os +import pickle +from typing import Any + +logger = logging.getLogger("google_adk." + __name__) + +_ALLOWED_MODULE_PREFIXES: tuple[str, ...] = ( + "google.adk.", + "google.genai.", + "pydantic.", + "pydantic_core.", +) + +_ALLOWED_GLOBALS: dict[str, set[str]] = { + "builtins": { + "dict", + "list", + "set", + "tuple", + "frozenset", + "bytes", + "bytearray", + "True", + "False", + "None", + "type", + "object", + "complex", + "slice", + "range", + "int", + "float", + "str", + "bool", + }, + "collections": {"OrderedDict", "defaultdict"}, + "datetime": {"datetime", "date", "time", "timedelta", "timezone"}, + "copy_reg": {"_reconstructor"}, + "copyreg": {"_reconstructor", "__newobj__"}, + "_codecs": {"encode"}, + "enum": {"__new__", "Enum", "IntEnum", "StrEnum"}, +} + + +class _RestrictedUnpickler(pickle.Unpickler): + """Unpickler that only allows reconstruction of known-safe types.""" + + def find_class(self, module: str, name: str) -> Any: + for prefix in _ALLOWED_MODULE_PREFIXES: + if module.startswith(prefix): + return super().find_class(module, name) + allowed_names = _ALLOWED_GLOBALS.get(module) + if allowed_names and name in allowed_names: + return super().find_class(module, name) + raise pickle.UnpicklingError( + f"Blocked unsafe pickle global: {module}.{name}. " + "If this is a legitimate ADK type, please file an issue at " + "https://github.com/google/adk-python/issues" + ) + + +def safe_loads(data: bytes) -> Any: + """Deserialize pickle bytes using a restricted unpickler. + + If ADK_ALLOW_UNSAFE_V0_PICKLE=1 is set, falls back to unrestricted + pickle.loads() for compatibility. A deprecation warning is logged. + """ + if os.environ.get("ADK_ALLOW_UNSAFE_V0_PICKLE") == "1": + logger.warning( + "ADK_ALLOW_UNSAFE_V0_PICKLE is set - using unrestricted " + "pickle.loads(). This is unsafe and will be removed in a " + "future release. Migrate to the v1 JSON schema." + ) + return pickle.loads(data) # noqa: S301 + return _RestrictedUnpickler(io.BytesIO(data)).load() diff --git a/src/google/adk/sessions/schemas/v0.py b/src/google/adk/sessions/schemas/v0.py index e4a4368c6d..486791b790 100644 --- a/src/google/adk/sessions/schemas/v0.py +++ b/src/google/adk/sessions/schemas/v0.py @@ -57,6 +57,7 @@ from ...events.event import Event from ...events.event_actions import EventActions from ..session import Session +from ._safe_unpickle import safe_loads as _safe_pickle_loads from .shared import DEFAULT_MAX_KEY_LENGTH from .shared import DEFAULT_MAX_VARCHAR_LENGTH from .shared import DynamicJSON @@ -110,11 +111,24 @@ def process_bind_param(self, value, dialect): return pickle.dumps(value) return value + def result_processor(self, dialect, coltype): + if dialect.name in ("mysql", "spanner+spanner"): + return super().result_processor(dialect, coltype) + + def process(value): + if value is None: + return None + if isinstance(value, memoryview): + value = bytes(value) + return _safe_pickle_loads(value) + + return process + def process_result_value(self, value, dialect): """Ensures the raw bytes from the database are unpickled back into a Python object.""" if value is not None: if dialect.name in ("spanner+spanner", "mysql"): - return pickle.loads(value) + return _safe_pickle_loads(value) return value diff --git a/tests/unittests/sessions/test_safe_unpickle.py b/tests/unittests/sessions/test_safe_unpickle.py new file mode 100644 index 0000000000..e28c98ed12 --- /dev/null +++ b/tests/unittests/sessions/test_safe_unpickle.py @@ -0,0 +1,139 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for _safe_unpickle RestrictedUnpickler.""" + +from __future__ import annotations + +import io +import os +import pickle +import struct +import unittest + +from google.adk.sessions.schemas._safe_unpickle import safe_loads + + +def _make_global_payload(module: str, func: str, *args: str) -> bytes: + """Craft a pickle stream that calls module.func(*args).""" + buf = io.BytesIO() + buf.write(pickle.PROTO + struct.pack("B", 2)) + buf.write(b"c" + f"{module}\n{func}\n".encode()) + buf.write(b"(") + for arg in args: + encoded = arg.encode("utf-8") + buf.write( + pickle.SHORT_BINUNICODE + struct.pack(" safe_loads.""" + + def _round_trip(self, obj): + return safe_loads(pickle.dumps(obj)) + + def test_string_values(self): + original = {"state_delta": {"key": "value"}, "artifact_delta": {}} + self.assertEqual(self._round_trip(original), original) + + def test_nested_dict(self): + original = { + "state_delta": { + "user_prefs": {"theme": "dark", "lang": "en"}, + "counter": 42, + }, + "artifact_delta": {"files": ["a.txt", "b.txt"]}, + } + self.assertEqual(self._round_trip(original), original) + + def test_none_and_bool(self): + original = { + "skip_summarization": True, + "requested_auth_configs": None, + "escalate": False, + } + self.assertEqual(self._round_trip(original), original) + + def test_empty_dict(self): + self.assertEqual(self._round_trip({}), {}) + + +class TestEnvVarFallback(unittest.TestCase): + """ADK_ALLOW_UNSAFE_V0_PICKLE=1 must bypass RestrictedUnpickler.""" + + _ENV_KEY = "ADK_ALLOW_UNSAFE_V0_PICKLE" + _PAYLOAD = _make_global_payload("collections", "Counter") + + def test_blocked_without_env_var(self): + old = os.environ.pop(self._ENV_KEY, None) + try: + with self.assertRaises(pickle.UnpicklingError): + safe_loads(self._PAYLOAD) + finally: + if old is not None: + os.environ[self._ENV_KEY] = old + + def test_allowed_with_env_var(self): + old = os.environ.get(self._ENV_KEY) + try: + os.environ[self._ENV_KEY] = "1" + from collections import Counter + result = safe_loads(self._PAYLOAD) + self.assertIsInstance(result, Counter) + finally: + if old is None: + os.environ.pop(self._ENV_KEY, None) + else: + os.environ[self._ENV_KEY] = old + + +if __name__ == "__main__": + unittest.main()