diff --git a/bigquery_magics/bigquery.py b/bigquery_magics/bigquery.py index daad88e..b8c8e13 100644 --- a/bigquery_magics/bigquery.py +++ b/bigquery_magics/bigquery.py @@ -117,7 +117,6 @@ import IPython # type: ignore from IPython.core import magic_arguments # type: ignore from IPython.core.getipython import get_ipython -from google.api_core import client_info from google.api_core.exceptions import NotFound from google.cloud import bigquery from google.cloud.bigquery import exceptions @@ -126,13 +125,12 @@ from google.cloud.bigquery.job import QueryJobConfig import pandas -from bigquery_magics import environment from bigquery_magics import line_arg_parser as lap import bigquery_magics._versions_helpers import bigquery_magics.config import bigquery_magics.graph_server as graph_server +from bigquery_magics import core import bigquery_magics.pyformat -import bigquery_magics.version try: from google.cloud import bigquery_storage # type: ignore @@ -147,24 +145,6 @@ context = bigquery_magics.config.context -def _get_user_agent(): - identities = [ - f"ipython-{IPython.__version__}", - f"bigquery-magics/{bigquery_magics.version.__version__}", - ] - - if environment.is_vscode(): - identities.append("vscode") - if environment.is_vscode_google_cloud_code_extension_installed(): - identities.append(environment.GOOGLE_CLOUD_CODE_EXTENSION_NAME) - elif environment.is_jupyter(): - identities.append("jupyter") - if environment.is_jupyter_bigquery_plugin_installed(): - identities.append(environment.BIGQUERY_JUPYTER_PLUGIN_NAME) - - return " ".join(identities) - - def _handle_error(error, destination_var=None): """Process a query execution error. 
@@ -565,23 +545,11 @@ def _query_with_pandas(query: str, params: List[Any], args: Any): def _create_clients(args: Any) -> Tuple[bigquery.Client, Any]: - bigquery_client_options = copy.deepcopy(context.bigquery_client_options) - if args.bigquery_api_endpoint: - if isinstance(bigquery_client_options, dict): - bigquery_client_options["api_endpoint"] = args.bigquery_api_endpoint - else: - bigquery_client_options.api_endpoint = args.bigquery_api_endpoint - - bq_client = bigquery.Client( - project=args.project or context.project, - credentials=context.credentials, - default_query_job_config=context.default_query_job_config, - client_info=client_info.ClientInfo(user_agent=_get_user_agent()), - client_options=bigquery_client_options, + bq_client = core.create_bq_client( + project=args.project, + bigquery_api_endpoint=args.bigquery_api_endpoint, location=args.location, ) - if context._connection: - bq_client._connection = context._connection # Check and instantiate bq storage client if args.use_bqstorage_api is not None: @@ -634,7 +602,7 @@ def _handle_result(result, args): def _colab_query_callback(query: str, params: str): return IPython.core.display.JSON( - graph_server.convert_graph_data(query_results=json.loads(params)) + graph_server.convert_graph_params(json.loads(params)) ) @@ -663,7 +631,11 @@ def _colab_node_expansion_callback(request: dict, params_str: str): singleton_server_thread: threading.Thread = None -def _add_graph_widget(query_result): +MAX_GRAPH_VISUALIZATION_SIZE = 2_000_000 +MAX_GRAPH_VISUALIZATION_QUERY_RESULT_SIZE = 100_000 + + +def _add_graph_widget(query_result: pandas.DataFrame, query_job: Any, args: Any): try: from spanner_graphs.graph_visualization import generate_visualization_html except ImportError as err: @@ -700,10 +672,36 @@ def _add_graph_widget(query_result): port = graph_server.graph_server.port # Create html to invoke the graph server + args_dict = { + "bigquery_api_endpoint": args.bigquery_api_endpoint, + "project": args.project, + 
"location": args.location, + } + + estimated_size = query_result.memory_usage(index=True, deep=True).sum() + if estimated_size > MAX_GRAPH_VISUALIZATION_SIZE: + IPython.display.display( + IPython.core.display.HTML( + "Error: The query result is too large for graph visualization." + ) + ) + return + + table_dict = { + "projectId": query_job.configuration.destination.project, + "datasetId": query_job.configuration.destination.dataset_id, + "tableId": query_job.configuration.destination.table_id, + } + + params_dict = {"destination_table": table_dict, "args": args_dict} + if estimated_size < MAX_GRAPH_VISUALIZATION_QUERY_RESULT_SIZE: + params_dict["query_result"] = json.loads(query_result.to_json()) + + params_str = json.dumps(params_dict) html_content = generate_visualization_html( query="placeholder query", port=port, - params=query_result.to_json().replace("\\", "\\\\").replace('"', '\\"'), + params=params_str.replace("\\", "\\\\").replace('"', '\\"'), ) html_content = html_content.replace( '"graph_visualization.Query"', '"bigquery.graph_visualization.Query"' @@ -819,7 +817,7 @@ def _make_bq_query( result = result.to_dataframe(**dataframe_kwargs) if args.graph and _supports_graph_widget(result): - _add_graph_widget(result) + _add_graph_widget(result, query_job, args) return _handle_result(result, args) @@ -913,7 +911,7 @@ def _make_bqstorage_client(client, client_options): return client._ensure_bqstorage_client( client_options=client_options, - client_info=gapic_client_info.ClientInfo(user_agent=_get_user_agent()), + client_info=gapic_client_info.ClientInfo(user_agent=core._get_user_agent()), ) diff --git a/bigquery_magics/core.py b/bigquery_magics/core.py new file mode 100644 index 0000000..4b7287e --- /dev/null +++ b/bigquery_magics/core.py @@ -0,0 +1,75 @@ +# Copyright 2026 Google LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +from google.api_core import client_info +from google.cloud import bigquery +import IPython # type: ignore + +from bigquery_magics import environment +import bigquery_magics.config +import bigquery_magics.version + +context = bigquery_magics.config.context + + +def _get_user_agent(): + identities = [ + f"ipython-{IPython.__version__}", + f"bigquery-magics/{bigquery_magics.version.__version__}", + ] + + if environment.is_vscode(): + identities.append("vscode") + if environment.is_vscode_google_cloud_code_extension_installed(): + identities.append(environment.GOOGLE_CLOUD_CODE_EXTENSION_NAME) + elif environment.is_jupyter(): + identities.append("jupyter") + if environment.is_jupyter_bigquery_plugin_installed(): + identities.append(environment.BIGQUERY_JUPYTER_PLUGIN_NAME) + + return " ".join(identities) + + +def create_bq_client(*, project: str, bigquery_api_endpoint: str, location: str): + """Creates a BigQuery client. + + Args: + project: Project to use for API calls, None to obtain the project from the context. + bigquery_api_endpoint: BigQuery client endpoint. + location: Cloud region to use for API calls. + + Returns: + google.cloud.bigquery.client.Client: The BigQuery client. 
+ """ + bigquery_client_options = copy.deepcopy(context.bigquery_client_options) + if bigquery_api_endpoint: + if isinstance(bigquery_client_options, dict): + bigquery_client_options["api_endpoint"] = bigquery_api_endpoint + else: + bigquery_client_options.api_endpoint = bigquery_api_endpoint + + bq_client = bigquery.Client( + project=project or context.project, + credentials=context.credentials, + default_query_job_config=context.default_query_job_config, + client_info=client_info.ClientInfo(user_agent=_get_user_agent()), + client_options=bigquery_client_options, + location=location, + ) + if context._connection: + bq_client._connection = context._connection + + return bq_client diff --git a/bigquery_magics/graph_server.py b/bigquery_magics/graph_server.py index 2cf58bc..65c301c 100644 --- a/bigquery_magics/graph_server.py +++ b/bigquery_magics/graph_server.py @@ -19,6 +19,10 @@ import threading from typing import Any, Dict, List +from google.cloud import bigquery + +from bigquery_magics import core + def execute_node_expansion(params, request): return {"error": "Node expansion not yet implemented"} @@ -54,7 +58,7 @@ def _stringify_properties(d: Any) -> Any: return _stringify_value(d) -def convert_graph_data(query_results: Dict[str, Dict[str, str]]): +def _convert_graph_data(query_results: Dict[str, Dict[str, str]]): """ Converts graph data to the form expected by the visualization framework. 
@@ -143,6 +147,24 @@ def convert_graph_data(query_results: Dict[str, Dict[str, str]]): return {"error": getattr(e, "message", str(e))} +def convert_graph_params(params: Dict[str, Any]): + query_results = None + if "query_result" in params: + query_results = params["query_result"] + else: + bq_client = core.create_bq_client( + project=params["args"]["project"], + bigquery_api_endpoint=params["args"]["bigquery_api_endpoint"], + location=params["args"]["location"], + ) + + table_ref = bigquery.TableReference.from_api_repr(params["destination_table"]) + query_results = json.loads( + bq_client.list_rows(table_ref).to_dataframe().to_json() + ) + return _convert_graph_data(query_results=query_results) + + class GraphServer: """ Http server invoked by Javascript to obtain the query results for visualization. @@ -251,7 +273,9 @@ def handle_post_ping(self): def handle_post_query(self): data = self.parse_post_data() - response = convert_graph_data(query_results=json.loads(data["params"])) + params = json.loads(data["params"]) + + response = convert_graph_params(params) self.do_data_response(response) def handle_post_node_expansion(self): diff --git a/tests/unit/bigquery/test_bigquery.py b/tests/unit/bigquery/test_bigquery.py index 6393d6d..38bed90 100644 --- a/tests/unit/bigquery/test_bigquery.py +++ b/tests/unit/bigquery/test_bigquery.py @@ -563,6 +563,7 @@ def test_bigquery_graph_json_json_result(monkeypatch): # Set up the context with monkeypatch so that it's reset for subsequent # tests. monkeypatch.setattr(bigquery_magics.context, "_credentials", mock_credentials) + monkeypatch.setattr(bigquery_magics.context, "_project", PROJECT_ID) # Mock out the BigQuery Storage API. 
bqstorage_mock = mock.create_autospec(bigquery_storage.BigQueryReadClient) @@ -597,6 +598,9 @@ def test_bigquery_graph_json_json_result(monkeypatch): google.cloud.bigquery.job.QueryJob, instance=True ) query_job_mock.to_dataframe.return_value = result + query_job_mock.configuration.destination.project = PROJECT_ID + query_job_mock.configuration.destination.dataset_id = DATASET_ID + query_job_mock.configuration.destination.table_id = TABLE_ID with run_query_patch as run_query_mock, ( bqstorage_client_patch @@ -630,6 +634,7 @@ def test_bigquery_graph_json_result(monkeypatch): # Set up the context with monkeypatch so that it's reset for subsequent # tests. monkeypatch.setattr(bigquery_magics.context, "_credentials", mock_credentials) + monkeypatch.setattr(bigquery_magics.context, "_project", PROJECT_ID) # Mock out the BigQuery Storage API. bqstorage_mock = mock.create_autospec(bigquery_storage.BigQueryReadClient) @@ -661,6 +666,9 @@ def test_bigquery_graph_json_result(monkeypatch): google.cloud.bigquery.job.QueryJob, instance=True ) query_job_mock.to_dataframe.return_value = result + query_job_mock.configuration.destination.project = PROJECT_ID + query_job_mock.configuration.destination.dataset_id = DATASET_ID + query_job_mock.configuration.destination.table_id = TABLE_ID with run_query_patch as run_query_mock, ( bqstorage_client_patch @@ -690,6 +698,12 @@ def test_bigquery_graph_json_result(monkeypatch): "mUZpbkdyYXBoLlBlcnNvbgB4kQQ=" in html_content ) # identifier in 3rd row of query result + # Verify that args are present in the HTML. + assert '\\"args\\": {' in html_content + assert '\\"bigquery_api_endpoint\\": null' in html_content + assert '\\"project\\": null' in html_content + assert '\\"location\\": null' in html_content + # Make sure we can run a second graph query, after the graph server is already running. 
try: return_value = ip.run_cell_magic("bigquery", "--graph", sql) @@ -714,12 +728,198 @@ def test_bigquery_graph_json_result(monkeypatch): "mUZpbkdyYXBoLlBlcnNvbgB4kQQ=" in html_content ) # identifier in 3rd row of query result + # Verify that args are present in the HTML. + assert '\\"args\\": {' in html_content + assert '\\"bigquery_api_endpoint\\": null' in html_content + assert '\\"project\\": null' in html_content + assert '\\"location\\": null' in html_content + assert bqstorage_mock.called # BQ storage client was used assert isinstance(return_value, pandas.DataFrame) assert len(return_value) == len(result) # verify row count assert list(return_value) == list(result) # verify column names +@pytest.mark.skipif( + graph_visualization is None or bigquery_storage is None, + reason="Requires `spanner-graph-notebook` and `google-cloud-bigquery-storage`", +) +def test_bigquery_graph_size_exceeds_max(monkeypatch): + globalipapp.start_ipython() + ip = globalipapp.get_ipython() + ip.extension_manager.load_extension("bigquery_magics") + mock_credentials = mock.create_autospec( + google.auth.credentials.Credentials, instance=True + ) + monkeypatch.setattr(bigquery_magics.context, "_credentials", mock_credentials) + monkeypatch.setattr(bigquery_magics.context, "_project", PROJECT_ID) + + # Set threshold to a very small value to trigger the error. 
+ monkeypatch.setattr(magics, "MAX_GRAPH_VISUALIZATION_SIZE", 5) + + bqstorage_mock = mock.create_autospec(bigquery_storage.BigQueryReadClient) + bqstorage_instance_mock = mock.create_autospec( + bigquery_storage.BigQueryReadClient, instance=True + ) + bqstorage_instance_mock._transport = mock.Mock() + bqstorage_mock.return_value = bqstorage_instance_mock + bqstorage_client_patch = mock.patch( + "google.cloud.bigquery_storage.BigQueryReadClient", bqstorage_mock + ) + + sql = "SELECT graph_json FROM t" + result = pandas.DataFrame(['{"id": 1}'], columns=["graph_json"]) + run_query_patch = mock.patch("bigquery_magics.bigquery._run_query", autospec=True) + display_patch = mock.patch("IPython.display.display", autospec=True) + query_job_mock = mock.create_autospec( + google.cloud.bigquery.job.QueryJob, instance=True + ) + query_job_mock.to_dataframe.return_value = result + query_job_mock.configuration.destination.project = PROJECT_ID + query_job_mock.configuration.destination.dataset_id = DATASET_ID + query_job_mock.configuration.destination.table_id = TABLE_ID + + with run_query_patch as run_query_mock, ( + bqstorage_client_patch + ), display_patch as display_mock: + run_query_mock.return_value = query_job_mock + + ip.run_cell_magic("bigquery", "--graph", sql) + + # Should display error message + assert display_mock.called + html_content = display_mock.call_args[0][0].data + assert ( + "Error: The query result is too large for graph visualization." 
+ in html_content + ) + + +@pytest.mark.skipif( + graph_visualization is None or bigquery_storage is None, + reason="Requires `spanner-graph-notebook` and `google-cloud-bigquery-storage`", +) +def test_bigquery_graph_size_exceeds_query_result_max(monkeypatch): + globalipapp.start_ipython() + ip = globalipapp.get_ipython() + ip.extension_manager.load_extension("bigquery_magics") + mock_credentials = mock.create_autospec( + google.auth.credentials.Credentials, instance=True + ) + monkeypatch.setattr(bigquery_magics.context, "_credentials", mock_credentials) + monkeypatch.setattr(bigquery_magics.context, "_project", PROJECT_ID) + + # Set threshold to a very small value, but larger than the other one. + # We want estimated_size > MAX_GRAPH_VISUALIZATION_QUERY_RESULT_SIZE + # and estimated_size <= MAX_GRAPH_VISUALIZATION_SIZE + monkeypatch.setattr(magics, "MAX_GRAPH_VISUALIZATION_SIZE", 1000000) + monkeypatch.setattr(magics, "MAX_GRAPH_VISUALIZATION_QUERY_RESULT_SIZE", 5) + + bqstorage_mock = mock.create_autospec(bigquery_storage.BigQueryReadClient) + bqstorage_instance_mock = mock.create_autospec( + bigquery_storage.BigQueryReadClient, instance=True + ) + bqstorage_instance_mock._transport = mock.Mock() + bqstorage_mock.return_value = bqstorage_instance_mock + bqstorage_client_patch = mock.patch( + "google.cloud.bigquery_storage.BigQueryReadClient", bqstorage_mock + ) + + sql = "SELECT graph_json FROM t" + result = pandas.DataFrame(['{"id": 1977323800}'], columns=["graph_json"]) + run_query_patch = mock.patch("bigquery_magics.bigquery._run_query", autospec=True) + display_patch = mock.patch("IPython.display.display", autospec=True) + query_job_mock = mock.create_autospec( + google.cloud.bigquery.job.QueryJob, instance=True + ) + query_job_mock.to_dataframe.return_value = result + query_job_mock.configuration.destination.project = PROJECT_ID + query_job_mock.configuration.destination.dataset_id = DATASET_ID + query_job_mock.configuration.destination.table_id = TABLE_ID 
+ + with run_query_patch as run_query_mock, ( + bqstorage_client_patch + ), display_patch as display_mock: + run_query_mock.return_value = query_job_mock + + ip.run_cell_magic("bigquery", "--graph", sql) + + # Should display visualization but without query_result embedded. + assert display_mock.called + html_content = display_mock.call_args[0][0].data + assert "