Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 76 additions & 42 deletions bigquery_magics/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@
import IPython # type: ignore
from IPython.core import magic_arguments # type: ignore
from IPython.core.getipython import get_ipython
from google.api_core import client_info
from google.api_core.exceptions import NotFound
from google.cloud import bigquery
from google.cloud.bigquery import exceptions
Expand All @@ -126,13 +125,12 @@
from google.cloud.bigquery.job import QueryJobConfig
import pandas

from bigquery_magics import environment
from bigquery_magics import line_arg_parser as lap
import bigquery_magics._versions_helpers
import bigquery_magics.config
import bigquery_magics.graph_server as graph_server
from bigquery_magics import core
import bigquery_magics.pyformat
import bigquery_magics.version

try:
from google.cloud import bigquery_storage # type: ignore
Expand All @@ -147,24 +145,6 @@
# Module-level alias for the shared magics configuration object
# (bigquery_magics.config.context), so code in this module can refer to it
# simply as ``context``.
context = bigquery_magics.config.context


def _get_user_agent():
    """Assemble the user-agent string passed to ``ClientInfo``.

    The string always begins with the IPython version and the
    bigquery-magics version. At most one host environment is then
    appended: ``vscode`` is checked first, and ``jupyter`` is only
    considered when the VS Code check is false. For whichever host
    matched, the corresponding extension/plugin name constant from
    ``environment`` is added when its "installed" check succeeds.

    Returns:
        str: All identity tokens joined by single spaces.
    """
    base_tokens = [
        f"ipython-{IPython.__version__}",
        f"bigquery-magics/{bigquery_magics.version.__version__}",
    ]

    if environment.is_vscode():
        host_tokens = ["vscode"]
        if environment.is_vscode_google_cloud_code_extension_installed():
            host_tokens.append(environment.GOOGLE_CLOUD_CODE_EXTENSION_NAME)
    elif environment.is_jupyter():
        host_tokens = ["jupyter"]
        if environment.is_jupyter_bigquery_plugin_installed():
            host_tokens.append(environment.BIGQUERY_JUPYTER_PLUGIN_NAME)
    else:
        # Neither host was detected: only the base identities are reported.
        host_tokens = []

    return " ".join(base_tokens + host_tokens)


def _handle_error(error, destination_var=None):
"""Process a query execution error.

Expand Down Expand Up @@ -565,23 +545,9 @@ def _query_with_pandas(query: str, params: List[Any], args: Any):


def _create_clients(args: Any) -> Tuple[bigquery.Client, Any]:
bigquery_client_options = copy.deepcopy(context.bigquery_client_options)
if args.bigquery_api_endpoint:
if isinstance(bigquery_client_options, dict):
bigquery_client_options["api_endpoint"] = args.bigquery_api_endpoint
else:
bigquery_client_options.api_endpoint = args.bigquery_api_endpoint

bq_client = bigquery.Client(
project=args.project or context.project,
credentials=context.credentials,
default_query_job_config=context.default_query_job_config,
client_info=client_info.ClientInfo(user_agent=_get_user_agent()),
client_options=bigquery_client_options,
location=args.location,
bq_client = core.create_bq_client(
args.project, args.bigquery_api_endpoint, args.location
)
if context._connection:
bq_client._connection = context._connection

# Check and instantiate bq storage client
if args.use_bqstorage_api is not None:
Expand Down Expand Up @@ -634,7 +600,7 @@ def _handle_result(result, args):

def _colab_query_callback(query: str, params: str):
return IPython.core.display.JSON(
graph_server.convert_graph_data(query_results=json.loads(params))
graph_server.convert_graph_params(json.loads(params))
)


Expand Down Expand Up @@ -663,7 +629,49 @@ def _colab_node_expansion_callback(request: dict, params_str: str):
singleton_server_thread: threading.Thread = None


def _add_graph_widget(query_result):
# Upper bound on the estimated JSON size of a query result (as returned by
# _estimate_json_size). Results estimated to be larger than this are not
# visualized; an error message is displayed instead.
MAX_GRAPH_VISUALIZATION_SIZE = 5000000
# Results whose estimated JSON size is below this threshold are embedded
# directly in the visualization params under "query_result". At or above it,
# only the destination table reference and client args are passed along.
MAX_GRAPH_VISUALIZATION_QUERY_RESULT_SIZE = 100000


def _estimate_json_size(df: pandas.DataFrame) -> int:
"""Approximates the length of df.to_json(orient='records')
without materializing the string.
"""
num_rows, num_cols = df.shape
if num_rows == 0:
return 2 # "[]"

# 1. Key overhead: "column_name": (repeated for every row)
# Includes quotes, colon, and comma separator per field
key_overhead = sum(len(f'"{col}":') + 1 for col in df.columns) * num_rows

# 2. Row structural overhead: { } per row and [ ] for the list
# Plus commas between rows (num_rows - 1)
structural_overhead = (2 * num_rows) + 2 + (num_rows - 1)

# 3. Value lengths
total_val_len = 0
for col in df.columns:
series = df[col]

if pandas.api.types.is_bool_dtype(series):
# true (4) or false (5)
total_val_len += series.map({True: 4, False: 5}).sum()
elif pandas.api.types.is_numeric_dtype(series):
# Numeric values (no quotes). Sample for average length to save memory.
sample_size = min(len(series), 1000)
avg_len = series.sample(sample_size).astype(str).str.len().mean()
total_val_len += avg_len * num_rows
else:
# Strings/Objects: "value" + quotes (2) + rough escaping factor
# .str.len() is relatively memory-efficient
val_chars = series.astype(str).str.len().sum()
total_val_len += val_chars + (2 * num_rows)

return int(key_overhead + structural_overhead + total_val_len)


def _add_graph_widget(query_result: pandas.DataFrame, query_job: Any, args: Any):
try:
from spanner_graphs.graph_visualization import generate_visualization_html
except ImportError as err:
Expand Down Expand Up @@ -700,10 +708,36 @@ def _add_graph_widget(query_result):
port = graph_server.graph_server.port

# Create html to invoke the graph server
args_dict = {
"bigquery_api_endpoint": args.bigquery_api_endpoint,
"project": args.project,
"location": args.location,
}

estimated_size = _estimate_json_size(query_result)
if estimated_size > MAX_GRAPH_VISUALIZATION_SIZE:
IPython.display.display(
IPython.core.display.HTML(
"<big><b>Error:</b> The query result is too large for graph visualization.</big>"
)
)
return

table_dict = {
"projectId": query_job.configuration.destination.project,
"datasetId": query_job.configuration.destination.dataset_id,
"tableId": query_job.configuration.destination.table_id,
}

params_dict = {"destination_table": table_dict, "args": args_dict}
if estimated_size < MAX_GRAPH_VISUALIZATION_QUERY_RESULT_SIZE:
params_dict["query_result"] = json.loads(query_result.to_json())

params_str = json.dumps(params_dict)
html_content = generate_visualization_html(
query="placeholder query",
port=port,
params=query_result.to_json().replace("\\", "\\\\").replace('"', '\\"'),
params=params_str.replace("\\", "\\\\").replace('"', '\\"'),
)
html_content = html_content.replace(
'"graph_visualization.Query"', '"bigquery.graph_visualization.Query"'
Expand Down Expand Up @@ -819,7 +853,7 @@ def _make_bq_query(
result = result.to_dataframe(**dataframe_kwargs)

if args.graph and _supports_graph_widget(result):
_add_graph_widget(result)
_add_graph_widget(result, query_job, args)
return _handle_result(result, args)


Expand Down Expand Up @@ -913,7 +947,7 @@ def _make_bqstorage_client(client, client_options):

return client._ensure_bqstorage_client(
client_options=client_options,
client_info=gapic_client_info.ClientInfo(user_agent=_get_user_agent()),
client_info=gapic_client_info.ClientInfo(user_agent=core._get_user_agent()),
)


Expand Down
73 changes: 73 additions & 0 deletions bigquery_magics/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright 2024 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from google.api_core import client_info
from google.cloud import bigquery
import IPython # type: ignore
from bigquery_magics import environment
import bigquery_magics.config
import bigquery_magics.version

# Module-level alias for the shared magics configuration object
# (bigquery_magics.config.context). create_bq_client reads the project,
# credentials, default job config, client options and connection from it.
context = bigquery_magics.config.context


def _get_user_agent():
    """Assemble the user-agent string passed to ``ClientInfo``.

    The string always begins with the IPython version and the
    bigquery-magics version. At most one host environment is then
    appended: ``vscode`` is checked first, and ``jupyter`` is only
    considered when the VS Code check is false. For whichever host
    matched, the corresponding extension/plugin name constant from
    ``environment`` is added when its "installed" check succeeds.

    Returns:
        str: All identity tokens joined by single spaces.
    """
    base_tokens = [
        f"ipython-{IPython.__version__}",
        f"bigquery-magics/{bigquery_magics.version.__version__}",
    ]

    if environment.is_vscode():
        host_tokens = ["vscode"]
        if environment.is_vscode_google_cloud_code_extension_installed():
            host_tokens.append(environment.GOOGLE_CLOUD_CODE_EXTENSION_NAME)
    elif environment.is_jupyter():
        host_tokens = ["jupyter"]
        if environment.is_jupyter_bigquery_plugin_installed():
            host_tokens.append(environment.BIGQUERY_JUPYTER_PLUGIN_NAME)
    else:
        # Neither host was detected: only the base identities are reported.
        host_tokens = []

    return " ".join(base_tokens + host_tokens)


def create_bq_client(project: str, bigquery_api_endpoint: str, location: str):
    """Build a BigQuery client configured from the shared magics context.

    Args:
        project: Project to use for API calls. A falsy value (for example
            ``None``) falls back to ``context.project``.
        bigquery_api_endpoint: BigQuery endpoint override. When truthy it is
            written to ``api_endpoint`` on a deep copy of
            ``context.bigquery_client_options``, so the context's own
            options object is not modified.
        location: Cloud region to use for API calls; passed to the client
            unchanged.

    Returns:
        google.cloud.bigquery.client.Client: The BigQuery client.
    """
    # Work on a copy so a per-call endpoint never leaks into the context.
    client_options = copy.deepcopy(context.bigquery_client_options)
    if bigquery_api_endpoint:
        # The options may be either a plain dict or an options object.
        if not isinstance(client_options, dict):
            client_options.api_endpoint = bigquery_api_endpoint
        else:
            client_options["api_endpoint"] = bigquery_api_endpoint

    client_kwargs = {
        "project": project or context.project,
        "credentials": context.credentials,
        "default_query_job_config": context.default_query_job_config,
        "client_info": client_info.ClientInfo(user_agent=_get_user_agent()),
        "client_options": client_options,
        "location": location,
    }
    client = bigquery.Client(**client_kwargs)

    # When the context carries a connection, it replaces the client's own.
    if context._connection:
        client._connection = context._connection

    return client
28 changes: 26 additions & 2 deletions bigquery_magics/graph_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
import threading
from typing import Any, Dict, List

from google.cloud import bigquery

from bigquery_magics import core


def execute_node_expansion(params, request):
return {"error": "Node expansion not yet implemented"}
Expand Down Expand Up @@ -54,7 +58,7 @@ def _stringify_properties(d: Any) -> Any:
return _stringify_value(d)


def convert_graph_data(query_results: Dict[str, Dict[str, str]]):
def _convert_graph_data(query_results: Dict[str, Dict[str, str]]):
"""
Converts graph data to the form expected by the visualization framework.

Expand Down Expand Up @@ -143,6 +147,24 @@ def convert_graph_data(query_results: Dict[str, Dict[str, str]]):
return {"error": getattr(e, "message", str(e))}


def convert_graph_params(params: Dict[str, Any]):
    """Obtain query results described by ``params`` and convert them.

    Args:
        params: Visualization parameters. When it contains a
            ``"query_result"`` entry, that value is used as the query
            results directly. Otherwise ``params["args"]`` (keys
            ``"project"``, ``"bigquery_api_endpoint"``, ``"location"``) is
            used to create a BigQuery client, and the rows of the table
            described by ``params["destination_table"]`` are listed and
            loaded through a DataFrame's JSON representation.

    Returns:
        Whatever ``_convert_graph_data`` returns for the resolved results.
    """
    if "query_result" in params:
        rows = params["query_result"]
    else:
        # No inline results: read them back from the destination table.
        client_args = params["args"]
        bq_client = core.create_bq_client(
            client_args["project"],
            client_args["bigquery_api_endpoint"],
            client_args["location"],
        )

        table_ref = bigquery.TableReference.from_api_repr(params["destination_table"])
        frame = bq_client.list_rows(table_ref).to_dataframe()
        rows = json.loads(frame.to_json())

    return _convert_graph_data(query_results=rows)


class GraphServer:
"""
Http server invoked by Javascript to obtain the query results for visualization.
Expand Down Expand Up @@ -251,7 +273,9 @@ def handle_post_ping(self):

def handle_post_query(self):
data = self.parse_post_data()
response = convert_graph_data(query_results=json.loads(data["params"]))
params = json.loads(data["params"])

response = convert_graph_params(params)
self.do_data_response(response)

def handle_post_node_expansion(self):
Expand Down
Loading
Loading