Skip to content

Commit 8409da3

Browse files
committed
[SPARK-49382][PS] Make frame box plot properly render the fliers/outliers
### What changes were proposed in this pull request? fliers/outliers was ignored in the initial implementation #36317 ### Why are the changes needed? feature parity for Pandas and Series box plot ### Does this PR introduce _any_ user-facing change? ``` import pyspark.pandas as ps df = ps.DataFrame([[5.1, 3.5, 0], [4.9, 3.0, 0], [7.0, 3.2, 1], [6.4, 3.2, 1], [5.9, 3.0, 2], [100, 200, 300]], columns=['length', 'width', 'species']) df.boxplot() ``` `df.length.plot.box()` ![image](https://github.com/user-attachments/assets/43da563c-5f68-4305-ad27-a4f04815dfd1) before: `df.boxplot()` ![image](https://github.com/user-attachments/assets/e25c2760-c12a-4801-a730-3987a020f889) after: `df.boxplot()` ![image](https://github.com/user-attachments/assets/c19f13b1-b9e4-423e-bcec-0c47c1c8df32) ### How was this patch tested? CI and manually check ### Was this patch authored or co-authored using generative AI tooling? No Closes #47866 from zhengruifeng/plot_hist_fly. Authored-by: Ruifeng Zheng <ruifengz@apache.org> Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
1 parent f596079 commit 8409da3

File tree

4 files changed

+57
-1
lines changed

4 files changed

+57
-1
lines changed

python/pyspark/pandas/plot/core.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
from pyspark.sql import functions as F, Column
2828
from pyspark.sql.types import DoubleType
29+
from pyspark.pandas.spark import functions as SF
2930
from pyspark.pandas.missing import unsupported_function
3031
from pyspark.pandas.config import get_option
3132
from pyspark.pandas.utils import name_like_string
@@ -437,6 +438,37 @@ def get_fliers(colname, outliers, min_val):
437438

438439
return fliers
439440

441+
@staticmethod
442+
def get_multicol_fliers(colnames, multicol_outliers, multicol_whiskers):
443+
scols = []
444+
extract_colnames = []
445+
for i, colname in enumerate(colnames):
446+
formated_colname = "`{}`".format(colname)
447+
outlier_colname = "__{}_outlier".format(colname)
448+
min_val = multicol_whiskers[colname]["min"]
449+
pair_col = F.struct(
450+
F.abs(F.col(formated_colname) - F.lit(min_val)).alias("ord"),
451+
F.col(formated_colname).alias("val"),
452+
)
453+
scols.append(
454+
SF.collect_top_k(
455+
F.when(F.col(outlier_colname), pair_col)
456+
.otherwise(F.lit(None))
457+
.alias(f"pair_{i}"),
458+
1001,
459+
False,
460+
).alias(f"top_{i}")
461+
)
462+
extract_colnames.append(f"top_{i}.val")
463+
464+
results = multicol_outliers.select(scols).select(extract_colnames).first()
465+
466+
fliers = {}
467+
for i, colname in enumerate(colnames):
468+
fliers[colname] = results[i]
469+
470+
return fliers
471+
440472

441473
class KdePlotBase(NumericPlotBase):
442474
@staticmethod

python/pyspark/pandas/plot/plotly.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,11 +199,19 @@ def plot_box(data: Union["ps.DataFrame", "ps.Series"], **kwargs):
199199
# Computes min and max values of non-outliers - the whiskers
200200
whiskers = BoxPlotBase.calc_multicol_whiskers(numeric_column_names, outliers)
201201

202+
fliers = None
203+
if boxpoints:
204+
fliers = BoxPlotBase.get_multicol_fliers(numeric_column_names, outliers, whiskers)
205+
202206
i = 0
203207
for colname in numeric_column_names:
204208
col_stats = multicol_stats[colname]
205209
col_whiskers = whiskers[colname]
206210

211+
col_fliers = None
212+
if fliers is not None and colname in fliers and len(fliers[colname]) > 0:
213+
col_fliers = [fliers[colname]]
214+
207215
fig.add_trace(
208216
go.Box(
209217
x=[i],
@@ -214,7 +222,7 @@ def plot_box(data: Union["ps.DataFrame", "ps.Series"], **kwargs):
214222
mean=[col_stats["mean"]],
215223
lowerfence=[col_whiskers["min"]],
216224
upperfence=[col_whiskers["max"]],
217-
y=None, # todo: support y=fliers
225+
y=col_fliers,
218226
boxpoints=boxpoints,
219227
notched=notched,
220228
**kwargs,

python/pyspark/pandas/spark/functions.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,19 @@ def null_index(col: Column) -> Column:
174174
return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc))
175175

176176

177+
def collect_top_k(col: Column, num: int, reverse: bool) -> Column:
178+
if is_remote():
179+
from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns
180+
181+
return _invoke_function_over_columns("collect_top_k", col, F.lit(num), F.lit(reverse))
182+
183+
else:
184+
from pyspark import SparkContext
185+
186+
sc = SparkContext._active_spark_context
187+
return Column(sc._jvm.PythonSQLUtils.collect_top_k(col._jc, num, reverse))
188+
189+
177190
def make_interval(unit: str, e: Union[Column, int, float]) -> Column:
178191
unit_mapping = {
179192
"YEAR": "years",

sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,9 @@ private[sql] object PythonSQLUtils extends Logging {
149149

150150
def nullIndex(e: Column): Column = Column.internalFn("null_index", e)
151151

152+
def collect_top_k(e: Column, num: Int, reverse: Boolean): Column =
153+
Column.internalFn("collect_top_k", e, lit(num), lit(reverse))
154+
152155
def pandasProduct(e: Column, ignoreNA: Boolean): Column =
153156
Column.internalFn("pandas_product", e, lit(ignoreNA))
154157

0 commit comments

Comments
 (0)