diff --git a/firebase-sample/app.py b/firebase-sample/app.py new file mode 100644 index 0000000..725c7ab --- /dev/null +++ b/firebase-sample/app.py @@ -0,0 +1,11 @@ +import googleclouddebugger +googleclouddebugger.enable(use_firebase= True) + +from flask import Flask + +app = Flask(__name__) + +@app.route("/") +def hello_world(): + return "

Hello World!

" + diff --git a/firebase-sample/build-and-run.sh b/firebase-sample/build-and-run.sh new file mode 100755 index 0000000..a0cc7b1 --- /dev/null +++ b/firebase-sample/build-and-run.sh @@ -0,0 +1,20 @@ +#!/bin/bash -e + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +cd "${SCRIPT_DIR}/.." + +cd src +./build.sh +cd .. + +python3 -m venv /tmp/cdbg-venv +source /tmp/cdbg-venv/bin/activate +pip3 install -r requirements.txt +pip3 install src/dist/* --force-reinstall + +cd firebase-sample +pip3 install -r requirements.txt +python3 -m flask run +cd .. + +deactivate diff --git a/firebase-sample/requirements.txt b/firebase-sample/requirements.txt new file mode 100644 index 0000000..7e10602 --- /dev/null +++ b/firebase-sample/requirements.txt @@ -0,0 +1 @@ +flask diff --git a/requirements.txt b/requirements.txt index 48ab4e6..14ad4db 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ google-auth-httplib2 google-api-python-client google-api-core +firebase_admin pyyaml diff --git a/requirements_dev.txt b/requirements_dev.txt index 14662f3..89aa308 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -1,3 +1,4 @@ -r requirements.txt absl-py pytest +requests-mock diff --git a/src/googleclouddebugger/__init__.py b/src/googleclouddebugger/__init__.py index 00cd217..259cd88 100644 --- a/src/googleclouddebugger/__init__.py +++ b/src/googleclouddebugger/__init__.py @@ -30,6 +30,7 @@ from . import collector from . import error_data_visibility_policy from . import gcp_hub_client +from . import firebase_client from . import glob_data_visibility_policy from . import yaml_data_visibility_config_reader from . import cdbg_native @@ -38,24 +39,38 @@ __version__ = version.__version__ _flags = None -_hub_client = None +_backend_client = None _breakpoints_manager = None def _StartDebugger(): """Configures and starts the debugger.""" - global _hub_client + global _backend_client global _breakpoints_manager cdbg_native.InitializeModule(_flags) cdbg_native.LogInfo( f'Initializing Cloud Debugger Python agent version: {__version__}') - _hub_client = gcp_hub_client.GcpHubClient() + use_firebase = _flags.get('use_firebase') + if use_firebase: + _backend_client = firebase_client.FirebaseClient() + _backend_client.SetupAuth( + _flags.get('project_id'), _flags.get('service_account_json_file'), + _flags.get('firebase_db_url')) + else: + _backend_client = gcp_hub_client.GcpHubClient() + _backend_client.SetupAuth( + _flags.get('project_id'), _flags.get('project_number'), + _flags.get('service_account_json_file')) + _backend_client.SetupCanaryMode( + _flags.get('breakpoint_enable_canary'), + _flags.get('breakpoint_allow_canary_override')) + visibility_policy = _GetVisibilityPolicy() _breakpoints_manager = breakpoints_manager.BreakpointsManager( - _hub_client, visibility_policy) + _backend_client, visibility_policy) # Set up loggers for logpoints. collector.SetLogger(logging.getLogger()) @@ -63,17 +78,12 @@ def _StartDebugger(): collector.CaptureCollector.pretty_printers.append( appengine_pretty_printers.PrettyPrinter) - _hub_client.on_active_breakpoints_changed = ( + _backend_client.on_active_breakpoints_changed = ( _breakpoints_manager.SetActiveBreakpoints) - _hub_client.on_idle = _breakpoints_manager.CheckBreakpointsExpiration - _hub_client.SetupAuth( - _flags.get('project_id'), _flags.get('project_number'), - _flags.get('service_account_json_file')) - _hub_client.SetupCanaryMode( - _flags.get('breakpoint_enable_canary'), - _flags.get('breakpoint_allow_canary_override')) - _hub_client.InitializeDebuggeeLabels(_flags) - _hub_client.Start() + _backend_client.on_idle = _breakpoints_manager.CheckBreakpointsExpiration + + _backend_client.InitializeDebuggeeLabels(_flags) + _backend_client.Start() def _GetVisibilityPolicy(): diff --git a/src/googleclouddebugger/firebase_client.py b/src/googleclouddebugger/firebase_client.py new file mode 100644 index 0000000..0da2a9e --- /dev/null +++ b/src/googleclouddebugger/firebase_client.py @@ -0,0 +1,569 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Communicates with Firebase RTDB backend.""" + +from collections import deque +import hashlib +import json +import os +import platform +import requests +import socket +import sys +import threading +import traceback + +import firebase_admin +import firebase_admin.credentials +import firebase_admin.db +import firebase_admin.exceptions + +from . import backoff +from . import cdbg_native as native +from . import labels +from . import uniquifier_computer +from . import application_info +from . import version +# This module catches all exception. This is safe because it runs in +# a daemon thread (so we are not blocking Ctrl+C). We need to catch all +# the exception because HTTP client is unpredictable as far as every +# exception it can throw. +# pylint: disable=broad-except + +# Set of all known debuggee labels (passed down as flags). The value of +# a map is optional environment variable that can be used to set the flag +# (flags still take precedence). +_DEBUGGEE_LABELS = { + labels.Debuggee.MODULE: [ + 'GAE_SERVICE', 'GAE_MODULE_NAME', 'K_SERVICE', 'FUNCTION_NAME' + ], + labels.Debuggee.VERSION: [ + 'GAE_VERSION', 'GAE_MODULE_VERSION', 'K_REVISION', + 'X_GOOGLE_FUNCTION_VERSION' + ], + labels.Debuggee.MINOR_VERSION: ['GAE_DEPLOYMENT_ID', 'GAE_MINOR_VERSION'] +} + +# Debuggee labels used to format debuggee description (ordered). The minor +# version is excluded for the sake of consistency with AppEngine UX. +_DESCRIPTION_LABELS = [ + labels.Debuggee.PROJECT_ID, labels.Debuggee.MODULE, labels.Debuggee.VERSION +] + +_METADATA_SERVER_URL = 'http://metadata.google.internal/computeMetadata/v1' + +_TRANSIENT_ERROR_CODES = ('UNKNOWN', 'INTERNAL', 'N/A', 'UNAVAILABLE', + 'DEADLINE_EXCEEDED', 'RESOURCE_EXHAUSTED', + 'UNAUTHENTICATED', 'PERMISSION_DENIED') + + +class NoProjectIdError(Exception): + """Used to indicate the project id cannot be determined.""" + + +class FirebaseClient(object): + """Firebase RTDB Backend client. + + Registers the debuggee, subscribes for active breakpoints and sends breakpoint + updates to the backend. + + This class supports two types of authentication: application default + credentials or a manually provided JSON credentials file for a service + account. + + FirebaseClient creates a worker thread that communicates with the backend. The + thread can be stopped with a Stop function, but it is optional since the + worker thread is marked as daemon. + """ + + def __init__(self): + self.on_active_breakpoints_changed = lambda x: None + self.on_idle = lambda: None + self._debuggee_labels = {} + self._credentials = None + self._project_id = None + self._database_url = None + self._debuggee_id = None + self._canary_mode = None + self._breakpoints = {} + self._main_thread = None + self._transmission_thread = None + self._transmission_thread_startup_lock = threading.Lock() + self._transmission_queue = deque(maxlen=100) + self._new_updates = threading.Event() + self._breakpoint_subscription = None + + # Events for unit testing. + self.registration_complete = threading.Event() + self.subscription_complete = threading.Event() + + # + # Configuration options (constants only modified by unit test) + # + + # Delay before retrying failed request. + self.register_backoff = backoff.Backoff() # Register debuggee. + self.update_backoff = backoff.Backoff() # Update breakpoint. + + # Maximum number of times that the message is re-transmitted before it + # is assumed to be poisonous and discarded + self.max_transmit_attempts = 10 + + def InitializeDebuggeeLabels(self, flags): + """Initialize debuggee labels from environment variables and flags. + + The caller passes all the flags that the debuglet got. This function + will only use the flags used to label the debuggee. Flags take precedence + over environment variables. + + Debuggee description is formatted from available flags. + + Args: + flags: dictionary of debuglet command line flags. + """ + self._debuggee_labels = {} + + for (label, var_names) in _DEBUGGEE_LABELS.items(): + # var_names is a list of possible environment variables that may contain + # the label value. Find the first one that is set. + for name in var_names: + value = os.environ.get(name) + if value: + # Special case for module. We omit the "default" module + # to stay consistent with AppEngine. + if label == labels.Debuggee.MODULE and value == 'default': + break + self._debuggee_labels[label] = value + break + + # Special case when FUNCTION_NAME is set and X_GOOGLE_FUNCTION_VERSION + # isn't set. We set the version to 'unversioned' to be consistent with other + # agents. + # TODO: Stop assigning 'unversioned' to a GCF and find the + # actual version. + if ('FUNCTION_NAME' in os.environ and + labels.Debuggee.VERSION not in self._debuggee_labels): + self._debuggee_labels[labels.Debuggee.VERSION] = 'unversioned' + + if flags: + self._debuggee_labels.update({ + name: value + for (name, value) in flags.items() + if name in _DEBUGGEE_LABELS + }) + + self._debuggee_labels[labels.Debuggee.PROJECT_ID] = self._project_id + + platform_enum = application_info.GetPlatform() + self._debuggee_labels[labels.Debuggee.PLATFORM] = platform_enum.value + + if platform_enum == application_info.PlatformType.CLOUD_FUNCTION: + region = application_info.GetRegion() + if region: + self._debuggee_labels[labels.Debuggee.REGION] = region + + def SetupAuth(self, + project_id=None, + service_account_json_file=None, + database_url=None): + """Sets up authentication with Google APIs. + + This will use the credentials from service_account_json_file if provided, + falling back to application default credentials. + See https://cloud.google.com/docs/authentication/production. + + Args: + project_id: GCP project ID (e.g. myproject). If not provided, will attempt + to retrieve it from the credentials. + service_account_json_file: JSON file to use for credentials. If not + provided, will default to application default credentials. + database_url: Firebase realtime database URL to be used. If not + provided, will default to https://{project_id}-cdbg.firebaseio.com + Raises: + NoProjectIdError: If the project id cannot be determined. + """ + if service_account_json_file: + self._credentials = firebase_admin.credentials.Certificate( + service_account_json_file) + if not project_id: + with open(service_account_json_file, encoding='utf-8') as f: + project_id = json.load(f).get('project_id') + else: + if not project_id: + try: + r = requests.get( + f'{_METADATA_SERVER_URL}/project/project-id', + headers={'Metadata-Flavor': 'Google'}) + project_id = r.text + except requests.exceptions.RequestException: + native.LogInfo('Metadata server not available') + + if not project_id: + raise NoProjectIdError( + 'Unable to determine the project id from the API credentials. ' + 'Please specify the project id using the --project_id flag.') + + self._project_id = project_id + + if database_url: + self._database_url = database_url + else: + self._database_url = f'https://{self._project_id}-cdbg.firebaseio.com' + + def Start(self): + """Starts the worker thread.""" + self._shutdown = False + + # Spin up the main thread which will create the other necessary threads. + self._main_thread = threading.Thread(target=self._MainThreadProc) + self._main_thread.name = 'Cloud Debugger main worker thread' + self._main_thread.daemon = True + self._main_thread.start() + + def Stop(self): + """Signals the worker threads to shut down and waits until it exits.""" + self._shutdown = True + self._new_updates.set() # Wake up the transmission thread. + + if self._main_thread is not None: + self._main_thread.join() + self._main_thread = None + + if self._transmission_thread is not None: + self._transmission_thread.join() + self._transmission_thread = None + + if self._breakpoint_subscription is not None: + self._breakpoint_subscription.close() + self._breakpoint_subscription = None + + def EnqueueBreakpointUpdate(self, breakpoint_data): + """Asynchronously updates the specified breakpoint on the backend. + + This function returns immediately. The worker thread is actually doing + all the work. The worker thread is responsible to retry the transmission + in case of transient errors. + + The assumption is that the breakpoint is moving from Active to Final state. + + Args: + breakpoint: breakpoint in either final or non-final state. + """ + with self._transmission_thread_startup_lock: + if self._transmission_thread is None: + self._transmission_thread = threading.Thread( + target=self._TransmissionThreadProc) + self._transmission_thread.name = 'Cloud Debugger transmission thread' + self._transmission_thread.daemon = True + self._transmission_thread.start() + + self._transmission_queue.append((breakpoint_data, 0)) + self._new_updates.set() # Wake up the worker thread to send immediately. + + def _MainThreadProc(self): + """Entry point for the worker thread. + + This thread only serves to register and kick off the firebase subscription + which will run in its own thread. That thread will be owned by + self._breakpoint_subscription. + """ + # Note: if self._credentials is None, default app credentials will be used. + # TODO: Error handling. + firebase_admin.initialize_app(self._credentials, + {'databaseURL': self._database_url}) + + self._RegisterDebuggee() + self.registration_complete.set() + self._SubscribeToBreakpoints() + self.subscription_complete.set() + + def _TransmissionThreadProc(self): + """Entry point for the transmission worker thread.""" + + while not self._shutdown: + self._new_updates.clear() + + delay = self._TransmitBreakpointUpdates() + + self._new_updates.wait(delay) + + def _RegisterDebuggee(self): + """Single attempt to register the debuggee. + + If the registration succeeds, sets self._debuggee_id to the registered + debuggee ID. + + Args: + service: client to use for API calls + + Returns: + (registration_required, delay) tuple + """ + try: + debuggee = self._GetDebuggee() + self._debuggee_id = debuggee['id'] + + try: + debuggee_path = f'cdbg/debuggees/{self._debuggee_id}' + native.LogInfo( + f'registering at {self._database_url}, path: {debuggee_path}') + firebase_admin.db.reference(debuggee_path).set(debuggee) + native.LogInfo( + f'Debuggee registered successfully, ID: {self._debuggee_id}') + self.register_backoff.Succeeded() + return (False, 0) # Proceed immediately to subscribing to breakpoints. + except BaseException: + native.LogInfo(f'Failed to register debuggee: {traceback.format_exc()}') + except BaseException: + native.LogWarning('Debuggee information not available: ' + + traceback.format_exc()) + + return (True, self.register_backoff.Failed()) + + def _SubscribeToBreakpoints(self): + # Kill any previous subscriptions first. + if self._breakpoint_subscription is not None: + self._breakpoint_subscription.close() + self._breakpoint_subscription = None + + path = f'cdbg/breakpoints/{self._debuggee_id}/active' + native.LogInfo(f'Subscribing to breakpoint updates at {path}') + ref = firebase_admin.db.reference(path) + self._breakpoint_subscription = ref.listen(self._ActiveBreakpointCallback) + + def _ActiveBreakpointCallback(self, event): + if event.event_type == 'put': + if event.data is None: + # Either deleting a breakpoint or initializing with no breakpoints. + # Initializing with no breakpoints is a no-op. + # If deleting, event.path will be /{breakpointid} + if event.path != '/': + breakpoint_id = event.path[1:] + del self._breakpoints[breakpoint_id] + else: + if event.path == '/': + # New set of breakpoints. + self._breakpoints = {} + for (key, value) in event.data.items(): + self._AddBreakpoint(key, value) + else: + # New breakpoint. + breakpoint_id = event.path[1:] + self._AddBreakpoint(breakpoint_id, event.data) + + elif event.event_type == 'patch': + # New breakpoint or breakpoints. + for (key, value) in event.data.items(): + self._AddBreakpoint(key, value) + else: + native.LogWarning('Unexpected event from Firebase: ' + f'{event.event_type} {event.path} {event.data}') + return + + native.LogInfo(f'Breakpoints list changed, {len(self._breakpoints)} active') + self.on_active_breakpoints_changed(list(self._breakpoints.values())) + + def _AddBreakpoint(self, breakpoint_id, breakpoint_data): + breakpoint_data['id'] = breakpoint_id + self._breakpoints[breakpoint_id] = breakpoint_data + + def _TransmitBreakpointUpdates(self): + """Tries to send pending breakpoint updates to the backend. + + Sends all the pending breakpoint updates. In case of transient failures, + the breakpoint is inserted back to the top of the queue. Application + failures are not retried (for example updating breakpoint in a final + state). + + Each pending breakpoint maintains a retry counter. After repeated transient + failures the breakpoint is discarded and dropped from the queue. + + Args: + service: client to use for API calls + + Returns: + (reconnect, timeout) tuple. The first element ("reconnect") is set to + true on unexpected HTTP responses. The caller should discard the HTTP + connection and create a new one. The second element ("timeout") is + set to None if all pending breakpoints were sent successfully. Otherwise + returns time interval in seconds to stall before retrying. + """ + retry_list = [] + + # There is only one consumer, so two step pop is safe. + while self._transmission_queue: + breakpoint_data, retry_count = self._transmission_queue.popleft() + + bp_id = breakpoint_data['id'] + + try: + # Something has changed on the breakpoint. + # It should be going from active to final, but let's make sure. + if not breakpoint_data['isFinalState']: + raise BaseException( + f'Unexpected breakpoint update requested: {breakpoint_data}') + + # If action is missing, it should be set to 'CAPTURE' + is_logpoint = breakpoint_data.get('action') == 'LOG' + is_snapshot = not is_logpoint + if is_snapshot: + breakpoint_data['action'] = 'CAPTURE' + + # Set the completion time on the server side using a magic value. + breakpoint_data['finalTimeUnixMsec'] = {'.sv': 'timestamp'} + + # First, remove from the active breakpoints. + bp_ref = firebase_admin.db.reference( + f'cdbg/breakpoints/{self._debuggee_id}/active/{bp_id}') + bp_ref.delete() + + # Save snapshot data for snapshots only. + if is_snapshot: + # Note that there may not be snapshot data. + bp_ref = firebase_admin.db.reference( + f'cdbg/breakpoints/{self._debuggee_id}/snapshots/{bp_id}') + bp_ref.set(breakpoint_data) + + # Now strip potential snapshot data. + breakpoint_data.pop('evaluatedExpressions', None) + breakpoint_data.pop('stackFrames', None) + breakpoint_data.pop('variableTable', None) + + # Then add it to the list of final breakpoints. + bp_ref = firebase_admin.db.reference( + f'cdbg/breakpoints/{self._debuggee_id}/final/{bp_id}') + bp_ref.set(breakpoint_data) + + native.LogInfo(f'Breakpoint {bp_id} update transmitted successfully') + + except firebase_admin.exceptions.FirebaseError as err: + if err.code in _TRANSIENT_ERROR_CODES: + if retry_count < self.max_transmit_attempts - 1: + native.LogInfo(f'Failed to send breakpoint {bp_id} update: ' + f'{traceback.format_exc()}') + retry_list.append((breakpoint_data, retry_count + 1)) + else: + native.LogWarning( + f'Breakpoint {bp_id} retry count exceeded maximum') + else: + # This is very common if multiple instances are sending final update + # simultaneously. + native.LogInfo(f'{err}, breakpoint: {bp_id}') + except socket.error as err: + if retry_count < self.max_transmit_attempts - 1: + native.LogInfo(f'Socket error {err.errno} while sending breakpoint ' + f'{bp_id} update: {traceback.format_exc()}') + retry_list.append((breakpoint_data, retry_count + 1)) + else: + native.LogWarning(f'Breakpoint {bp_id} retry count exceeded maximum') + # Socket errors shouldn't persist like this; reconnect. + #reconnect = True + except BaseException: + native.LogWarning(f'Fatal error sending breakpoint {bp_id} update: ' + f'{traceback.format_exc()}') + + self._transmission_queue.extend(retry_list) + + if not self._transmission_queue: + self.update_backoff.Succeeded() + # Nothing to send, wait until next breakpoint update. + return None + else: + return self.update_backoff.Failed() + + def _GetDebuggee(self): + """Builds the debuggee structure.""" + major_version = version.__version__.split('.', maxsplit=1)[0] + python_version = ''.join(platform.python_version().split('.')[:2]) + agent_version = f'google.com/python{python_version}-gcp/v{major_version}' + + debuggee = { + 'description': self._GetDebuggeeDescription(), + 'labels': self._debuggee_labels, + 'agentVersion': agent_version, + } + + source_context = self._ReadAppJsonFile('source-context.json') + if source_context: + debuggee['sourceContexts'] = [source_context] + + debuggee['uniquifier'] = self._ComputeUniquifier(debuggee) + + debuggee['id'] = self._ComputeDebuggeeId(debuggee) + + return debuggee + + def _ComputeDebuggeeId(self, debuggee): + """Computes a debuggee ID. + + The debuggee ID has to be identical on all instances. Therefore the + ID should not include any random elements or elements that may be + different on different instances. + + Args: + debuggee: complete debuggee message (including uniquifier) + + Returns: + Debuggee ID meeting the criteria described above. + """ + fullhash = hashlib.sha1(json.dumps(debuggee, + sort_keys=True).encode()).hexdigest() + return f'd-{fullhash[:8]}' + + def _GetDebuggeeDescription(self): + """Formats debuggee description based on debuggee labels.""" + return '-'.join(self._debuggee_labels[label] + for label in _DESCRIPTION_LABELS + if label in self._debuggee_labels) + + def _ComputeUniquifier(self, debuggee): + """Computes debuggee uniquifier. + + The debuggee uniquifier has to be identical on all instances. Therefore the + uniquifier should not include any random numbers and should only be based + on inputs that are guaranteed to be the same on all instances. + + Args: + debuggee: complete debuggee message without the uniquifier + + Returns: + Hex string of SHA1 hash of project information, debuggee labels and + debuglet version. + """ + uniquifier = hashlib.sha1() + + # Compute hash of application files if we don't have source context. This + # way we can still distinguish between different deployments. + if ('minorversion' not in debuggee.get('labels', []) and + 'sourceContexts' not in debuggee): + uniquifier_computer.ComputeApplicationUniquifier(uniquifier) + + return uniquifier.hexdigest() + + def _ReadAppJsonFile(self, relative_path): + """Reads JSON file from an application directory. + + Args: + relative_path: file name relative to application root directory. + + Returns: + Parsed JSON data or None if the file does not exist, can't be read or + not a valid JSON file. + """ + try: + with open( + os.path.join(sys.path[0], relative_path), 'r', encoding='utf-8') as f: + return json.load(f) + except (IOError, ValueError): + return None diff --git a/src/googleclouddebugger/python_breakpoint.py b/src/googleclouddebugger/python_breakpoint.py index 7b61fd1..1339ebe 100644 --- a/src/googleclouddebugger/python_breakpoint.py +++ b/src/googleclouddebugger/python_breakpoint.py @@ -420,13 +420,13 @@ def _BreakpointEvent(self, event, frame): self._CompleteBreakpoint({'status': error_status}) return - collector = collector.CaptureCollector(self.definition, + capture_collector = collector.CaptureCollector(self.definition, self.data_visibility_policy) # TODO: This is a temporary try/except. All exceptions should be # caught inside Collect and converted into breakpoint error messages. try: - collector.Collect(frame) + capture_collector.Collect(frame) except BaseException as e: # pylint: disable=broad-except native.LogInfo('Internal error during data capture: %s' % repr(e)) error_status = { @@ -448,4 +448,4 @@ def _BreakpointEvent(self, event, frame): self._CompleteBreakpoint({'status': error_status}) return - self._CompleteBreakpoint(collector.breakpoint, is_incremental=False) + self._CompleteBreakpoint(capture_collector.breakpoint, is_incremental=False) diff --git a/src/setup.py b/src/setup.py index 64a5d46..40b0191 100644 --- a/src/setup.py +++ b/src/setup.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Python Cloud Debugger build and packaging script.""" from configparser import ConfigParser @@ -49,8 +48,7 @@ def ReadConfig(section, value, default): 'For more details please see ' 'https://github.com/GoogleCloudPlatform/cloud-debug-python\n') -lib_dirs = ReadConfig('build_ext', - 'library_dirs', +lib_dirs = ReadConfig('build_ext', 'library_dirs', sysconfig.get_config_var('LIBDIR')).split(':') extra_compile_args = ReadConfig('cc_options', 'extra_compile_args', '').split() extra_link_args = ReadConfig('cc_options', 'extra_link_args', '').split() @@ -65,9 +63,10 @@ def ReadConfig(section, value, default): assert len(static_libs) == len(deps), (static_libs, deps, lib_dirs) cvars = sysconfig.get_config_vars() -cvars['OPT'] = str.join(' ', RemovePrefixes( - cvars.get('OPT').split(), - ['-g', '-O', '-Wstrict-prototypes'])) +cvars['OPT'] = str.join( + ' ', + RemovePrefixes( + cvars.get('OPT').split(), ['-g', '-O', '-Wstrict-prototypes'])) # Determine the current version of the package. The easiest way would be to # import "googleclouddebugger" and read its __version__ attribute. diff --git a/tests/firebase_client_test.py b/tests/firebase_client_test.py new file mode 100644 index 0000000..c1690b2 --- /dev/null +++ b/tests/firebase_client_test.py @@ -0,0 +1,425 @@ +"""Unit tests for firebase_client module.""" + +import errno +import os +import socket +import sys +import tempfile +from unittest import mock +from unittest.mock import MagicMock +from unittest.mock import call +from unittest.mock import patch +import requests +import requests_mock + +from googleapiclient.errors import HttpError +from googleclouddebugger import version +from googleclouddebugger import firebase_client + +from absl.testing import absltest +from absl.testing import parameterized + +import firebase_admin.credentials + +TEST_PROJECT_ID = 'test-project-id' +METADATA_PROJECT_URL = ('http://metadata.google.internal/computeMetadata/' + 'v1/project/project-id') + + +class FakeEvent: + + def __init__(self, event_type, path, data): + self.event_type = event_type + self.path = path + self.data = data + + +class FakeReference: + + def __init__(self): + self.subscriber = None + + def listen(self, callback): + self.subscriber = callback + + def update(self, event_type, path, data): + if self.subscriber: + event = FakeEvent(event_type, path, data) + self.subscriber(event) + + +class FirebaseClientTest(parameterized.TestCase): + """Simulates service account authentication.""" + + def setUp(self): + version.__version__ = 'test' + + self._client = firebase_client.FirebaseClient() + + self.breakpoints_changed_count = 0 + self.breakpoints = {} + + def tearDown(self): + self._client.Stop() + + def testSetupAuthDefault(self): + # By default, we try getting the project id from the metadata server. + # Note that actual credentials are not fetched. + with requests_mock.Mocker() as m: + m.get(METADATA_PROJECT_URL, text=TEST_PROJECT_ID) + + self._client.SetupAuth() + + self.assertEqual(TEST_PROJECT_ID, self._client._project_id) + self.assertEqual(f'https://{TEST_PROJECT_ID}-cdbg.firebaseio.com', + self._client._database_url) + + def testSetupAuthOverrideProjectIdNumber(self): + # If a project id is provided, we use it. + project_id = 'project2' + self._client.SetupAuth(project_id=project_id) + + self.assertEqual(project_id, self._client._project_id) + self.assertEqual(f'https://{project_id}-cdbg.firebaseio.com', + self._client._database_url) + + def testSetupAuthServiceAccountJsonAuth(self): + # We'll load credentials from the provided file (mocked for simplicity) + with mock.patch.object(firebase_admin.credentials, + 'Certificate') as firebase_certificate: + json_file = tempfile.NamedTemporaryFile() + # And load the project id from the file as well. + with open(json_file.name, 'w', encoding='utf-8') as f: + f.write(f'{{"project_id": "{TEST_PROJECT_ID}"}}') + self._client.SetupAuth(service_account_json_file=json_file.name) + + firebase_certificate.assert_called_with(json_file.name) + self.assertEqual(TEST_PROJECT_ID, self._client._project_id) + + def testSetupAuthNoProjectId(self): + # There will be an exception raised if we try to contact the metadata + # server on a non-gcp machine. + with requests_mock.Mocker() as m: + m.get(METADATA_PROJECT_URL, exc=requests.exceptions.RequestException) + + with self.assertRaises(firebase_client.NoProjectIdError): + self._client.SetupAuth() + + @patch('firebase_admin.db.reference') + @patch('firebase_admin.initialize_app') + def testStart(self, mock_initialize_app, mock_db_ref): + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + self._client.Start() + self._client.subscription_complete.wait() + + debuggee_id = self._client._debuggee_id + + mock_initialize_app.assert_called_with( + None, {'databaseURL': f'https://{TEST_PROJECT_ID}-cdbg.firebaseio.com'}) + self.assertEqual([ + call(f'cdbg/debuggees/{debuggee_id}'), + call(f'cdbg/breakpoints/{debuggee_id}/active') + ], mock_db_ref.call_args_list) + + # TODO: testStartRegisterRetry + # TODO: testStartSubscribeRetry + # - Note: failures don't require retrying registration. + + @patch('firebase_admin.db.reference') + @patch('firebase_admin.initialize_app') + def testBreakpointSubscription(self, mock_initialize_app, mock_db_ref): + mock_register_ref = MagicMock() + fake_subscribe_ref = FakeReference() + mock_db_ref.side_effect = [mock_register_ref, fake_subscribe_ref] + + # This class will keep track of the breakpoint updates and will check + # them against expectations. + class ResultChecker: + + def __init__(self, expected_results, test): + self._expected_results = expected_results + self._test = test + self._change_count = 0 + + def callback(self, new_breakpoints): + self._test.assertEqual(self._expected_results[self._change_count], + new_breakpoints) + self._change_count += 1 + + breakpoints = [ + { + 'id': 'breakpoint-0', + 'location': { + 'path': 'foo.py', + 'line': 18 + } + }, + { + 'id': 'breakpoint-1', + 'location': { + 'path': 'bar.py', + 'line': 23 + } + }, + { + 'id': 'breakpoint-2', + 'location': { + 'path': 'baz.py', + 'line': 45 + } + }, + ] + + expected_results = [[breakpoints[0]], [breakpoints[0], breakpoints[1]], + [breakpoints[0], breakpoints[1], breakpoints[2]], + [breakpoints[1], breakpoints[2]]] + result_checker = ResultChecker(expected_results, self) + + self._client.on_active_breakpoints_changed = result_checker.callback + + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + self._client.Start() + self._client.subscription_complete.wait() + + # Send in updates to trigger the subscription callback. + fake_subscribe_ref.update('put', '/', + {breakpoints[0]['id']: breakpoints[0]}) + fake_subscribe_ref.update('patch', '/', + {breakpoints[1]['id']: breakpoints[1]}) + fake_subscribe_ref.update('put', f'/{breakpoints[2]["id"]}', breakpoints[2]) + fake_subscribe_ref.update('put', f'/{breakpoints[0]["id"]}', None) + + self.assertEqual(len(expected_results), result_checker._change_count) + + def _TestInitializeLabels(self, module_var, version_var, minor_var): + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + + self._client.InitializeDebuggeeLabels({ + 'module': 'my_module', + 'version': '1', + 'minorversion': '23', + 'something_else': 'irrelevant' + }) + self.assertEqual( + { + 'projectid': 'test-project-id', + 'module': 'my_module', + 'version': '1', + 'minorversion': '23', + 'platform': 'default' + }, self._client._debuggee_labels) + self.assertEqual('test-project-id-my_module-1', + self._client._GetDebuggeeDescription()) + + uniquifier1 = self._client._ComputeUniquifier( + {'labels': self._client._debuggee_labels}) + self.assertTrue(uniquifier1) # Not empty string. + + try: + os.environ[module_var] = 'env_module' + os.environ[version_var] = '213' + os.environ[minor_var] = '3476734' + self._client.InitializeDebuggeeLabels(None) + self.assertEqual( + { + 'projectid': 'test-project-id', + 'module': 'env_module', + 'version': '213', + 'minorversion': '3476734', + 'platform': 'default' + }, self._client._debuggee_labels) + self.assertEqual('test-project-id-env_module-213', + self._client._GetDebuggeeDescription()) + + os.environ[module_var] = 'default' + os.environ[version_var] = '213' + os.environ[minor_var] = '3476734' + self._client.InitializeDebuggeeLabels({'minorversion': 'something else'}) + self.assertEqual( + { + 'projectid': 'test-project-id', + 'version': '213', + 'minorversion': 'something else', + 'platform': 'default' + }, self._client._debuggee_labels) + self.assertEqual('test-project-id-213', + self._client._GetDebuggeeDescription()) + + finally: + del os.environ[module_var] + del os.environ[version_var] + del os.environ[minor_var] + + def testInitializeLegacyDebuggeeLabels(self): + self._TestInitializeLabels('GAE_MODULE_NAME', 'GAE_MODULE_VERSION', + 'GAE_MINOR_VERSION') + + def testInitializeDebuggeeLabels(self): + self._TestInitializeLabels('GAE_SERVICE', 'GAE_VERSION', + 'GAE_DEPLOYMENT_ID') + + def testInitializeCloudRunDebuggeeLabels(self): + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + + try: + os.environ['K_SERVICE'] = 'env_module' + os.environ['K_REVISION'] = '213' + self._client.InitializeDebuggeeLabels(None) + self.assertEqual( + { + 'projectid': 'test-project-id', + 'module': 'env_module', + 'version': '213', + 'platform': 'default' + }, self._client._debuggee_labels) + self.assertEqual('test-project-id-env_module-213', + self._client._GetDebuggeeDescription()) + + finally: + del os.environ['K_SERVICE'] + del os.environ['K_REVISION'] + + def testInitializeCloudFunctionDebuggeeLabels(self): + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + + try: + os.environ['FUNCTION_NAME'] = 'fcn-name' + os.environ['X_GOOGLE_FUNCTION_VERSION'] = '213' + self._client.InitializeDebuggeeLabels(None) + self.assertEqual( + { + 'projectid': 'test-project-id', + 'module': 'fcn-name', + 'version': '213', + 'platform': 'cloud_function' + }, self._client._debuggee_labels) + self.assertEqual('test-project-id-fcn-name-213', + self._client._GetDebuggeeDescription()) + + finally: + del os.environ['FUNCTION_NAME'] + del os.environ['X_GOOGLE_FUNCTION_VERSION'] + + def testInitializeCloudFunctionUnversionedDebuggeeLabels(self): + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + + try: + os.environ['FUNCTION_NAME'] = 'fcn-name' + self._client.InitializeDebuggeeLabels(None) + self.assertEqual( + { + 'projectid': 'test-project-id', + 'module': 'fcn-name', + 'version': 'unversioned', + 'platform': 'cloud_function' + }, self._client._debuggee_labels) + self.assertEqual('test-project-id-fcn-name-unversioned', + self._client._GetDebuggeeDescription()) + + finally: + del os.environ['FUNCTION_NAME'] + + def testInitializeCloudFunctionWithRegionDebuggeeLabels(self): + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + + try: + os.environ['FUNCTION_NAME'] = 'fcn-name' + os.environ['FUNCTION_REGION'] = 'fcn-region' + self._client.InitializeDebuggeeLabels(None) + self.assertEqual( + { + 'projectid': 'test-project-id', + 'module': 'fcn-name', + 'version': 'unversioned', + 'platform': 'cloud_function', + 'region': 'fcn-region' + }, self._client._debuggee_labels) + self.assertEqual('test-project-id-fcn-name-unversioned', + self._client._GetDebuggeeDescription()) + + finally: + del os.environ['FUNCTION_NAME'] + del os.environ['FUNCTION_REGION'] + + def testAppFilesUniquifierNoMinorVersion(self): + """Verify that uniquifier_computer is used if minor version not defined.""" + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + + root = tempfile.mkdtemp('', 'fake_app_') + sys.path.insert(0, root) + try: + uniquifier1 = self._client._ComputeUniquifier({}) + + with open(os.path.join(root, 'app.py'), 'w', encoding='utf-8') as f: + f.write('hello') + uniquifier2 = self._client._ComputeUniquifier({}) + finally: + del sys.path[0] + + self.assertNotEqual(uniquifier1, uniquifier2) + + def testAppFilesUniquifierWithMinorVersion(self): + """Verify that uniquifier_computer not used if minor version is defined.""" + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + + root = tempfile.mkdtemp('', 'fake_app_') + + os.environ['GAE_MINOR_VERSION'] = '12345' + sys.path.insert(0, root) + try: + self._client.InitializeDebuggeeLabels(None) + + uniquifier1 = self._client._GetDebuggee()['uniquifier'] + + with open(os.path.join(root, 'app.py'), 'w', encoding='utf-8') as f: + f.write('hello') + uniquifier2 = self._client._GetDebuggee()['uniquifier'] + finally: + del os.environ['GAE_MINOR_VERSION'] + del sys.path[0] + + self.assertEqual(uniquifier1, uniquifier2) + + def testSourceContext(self): + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + + root = tempfile.mkdtemp('', 'fake_app_') + source_context_path = os.path.join(root, 'source-context.json') + + sys.path.insert(0, root) + try: + debuggee_no_source_context1 = self._client._GetDebuggee() + + with open(source_context_path, 'w', encoding='utf-8') as f: + f.write('not a valid JSON') + debuggee_bad_source_context = self._client._GetDebuggee() + + with open(os.path.join(root, 'fake_app.py'), 'w', encoding='utf-8') as f: + f.write('pretend') + debuggee_no_source_context2 = self._client._GetDebuggee() + + with open(source_context_path, 'w', encoding='utf-8') as f: + f.write('{"what": "source context"}') + debuggee_with_source_context = self._client._GetDebuggee() + + os.remove(source_context_path) + finally: + del sys.path[0] + + self.assertNotIn('sourceContexts', debuggee_no_source_context1) + self.assertNotIn('sourceContexts', debuggee_bad_source_context) + self.assertListEqual([{ + 'what': 'source context' + }], debuggee_with_source_context['sourceContexts']) + + uniquifiers = set() + uniquifiers.add(debuggee_no_source_context1['uniquifier']) + uniquifiers.add(debuggee_with_source_context['uniquifier']) + uniquifiers.add(debuggee_bad_source_context['uniquifier']) + self.assertLen(uniquifiers, 1) + uniquifiers.add(debuggee_no_source_context2['uniquifier']) + self.assertLen(uniquifiers, 2) + + +if __name__ == '__main__': + absltest.main()