diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index 50eb72ffdb..50ad3e62b4 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -2653,7 +2653,14 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: async def _ensure_started(self, **kwargs) -> None: """Ensures that the plugin is started and initialized.""" - if os.getpid() != self._init_pid: + # _init_pid == 0 means the plugin was unpickled and has never been + # initialized in this process (the pickle sentinel set by + # __getstate__). Skip the fork reset in that case — no fork + # happened, and _started is already False so _lazy_setup will run. + # Real forks are caught by os.register_at_fork (line 108) and by + # this check when _init_pid is a real (non-zero) PID from a + # different process. + if self._init_pid != 0 and os.getpid() != self._init_pid: self._reset_runtime_state() if not self._started: # Kept original lock name as it was not explicitly changed. @@ -2665,6 +2672,12 @@ async def _ensure_started(self, **kwargs) -> None: await self._lazy_setup(**kwargs) self._started = True self._startup_error = None + # Record the current PID so fork detection works for + # the rest of this instance's lifetime. Without this, + # an unpickled plugin would keep _init_pid == 0 forever, + # disabling the PID-based fork check. + if self._init_pid == 0: + self._init_pid = os.getpid() except Exception as e: self._startup_error = e logger.error("Failed to initialize BigQuery Plugin: %s", e) diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index 8a05392bec..e2c5160468 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -7408,3 +7408,78 @@ async def test_view_error_still_logged( ) as plugin: await plugin._ensure_started() assert plugin._started + + +# ================================================================ +# TEST CLASS: Fork detection after pickle (Issue #86) +# ================================================================ +class TestForkDetectionAfterPickle: + """Tests that unpickled plugins do not false-positive fork detection.""" + + @pytest.mark.asyncio + async def test_no_reset_after_unpickle( + self, + mock_auth_default, + mock_bq_client, + mock_write_client, + mock_to_arrow_schema, + mock_asyncio_to_thread, + ): + """Unpickled plugin does not trigger _reset_runtime_state.""" + import pickle + + config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig( + create_views=False, + ) + plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + PROJECT_ID, DATASET_ID, table_id=TABLE_ID, config=config + ) + # Pickle round-trip simulates Vertex AI Agent Engine deployment + pickled = pickle.dumps(plugin) + unpickled = pickle.loads(pickled) + + assert unpickled._init_pid == 0 # pickle sentinel + + with mock.patch.object(unpickled, "_reset_runtime_state") as mock_reset: + await unpickled._ensure_started() + # Should NOT have called _reset_runtime_state because + # _init_pid == 0 means "unpickled, never initialized" + mock_reset.assert_not_called() + + assert unpickled._started + # After successful startup, _init_pid should be recorded so + # fork detection works for the rest of this instance's lifetime. + assert unpickled._init_pid == os.getpid() + await unpickled.shutdown() + + @pytest.mark.asyncio + async def test_reset_on_real_fork( + self, + mock_auth_default, + mock_bq_client, + mock_write_client, + mock_to_arrow_schema, + mock_asyncio_to_thread, + ): + """Plugin detects real fork when _init_pid is a real non-zero PID.""" + config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig( + create_views=False, + ) + async with managed_plugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + config=config, + ) as plugin: + await plugin._ensure_started() + # Simulate a fork: set _init_pid to a different real PID + plugin._init_pid = max(os.getpid() - 1, 1) + plugin._started = True # pretend it was started in parent + + with mock.patch.object( + plugin, "_reset_runtime_state", wraps=plugin._reset_runtime_state + ) as mock_reset: + await plugin._ensure_started() + # Should have called _reset_runtime_state because + # _init_pid is a real PID different from os.getpid() + mock_reset.assert_called_once()