Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def join(
ridx = ensure_platform_int(ridx)
return join_index, lidx, ridx

return cast(F, join)
return cast("F", join)


def _new_Index(cls, d):
Expand Down Expand Up @@ -859,7 +859,7 @@ def _engine(
elif self._engine_type is libindex.ObjectEngine:
return libindex.ExtensionEngine(target_values)

target_values = cast(np.ndarray, target_values)
target_values = cast("np.ndarray", target_values)
# to avoid a reference cycle, bind `target_values` to a local variable, so
# `self` is not passed into the lambda.
if target_values.dtype == bool:
Expand Down Expand Up @@ -1469,7 +1469,7 @@ def _get_level_names(self) -> range | Sequence[Hashable]:
def _mpl_repr(self) -> np.ndarray:
# how to represent ourselves to matplotlib
if isinstance(self.dtype, np.dtype) and self.dtype.kind != "M":
return cast(np.ndarray, self.values)
return cast("np.ndarray", self.values)
return self.astype(object, copy=False)._values

_default_na_rep = "NaN"
Expand Down Expand Up @@ -4451,7 +4451,7 @@ def _join_empty(
ridx: np.ndarray | None

if len(other):
how = cast(JoinHow, {"left": "right", "right": "left"}.get(how, how))
how = cast("JoinHow", {"left": "right", "right": "left"}.get(how, how))
join_index, ridx, lidx = other._join_empty(self, how, sort)
elif how in ["left", "outer"]:
if sort and not self.is_monotonic_increasing:
Expand Down Expand Up @@ -4730,7 +4730,7 @@ def _get_leaf_sorter(labels: list[np.ndarray]) -> npt.NDArray[np.intp]:

if keep_order: # just drop missing values. o.w. keep order
left_indexer = np.arange(len(left), dtype=np.intp)
left_indexer = cast(np.ndarray, left_indexer)
left_indexer = cast("np.ndarray", left_indexer)
mask = new_lev_codes != -1
if not mask.all():
new_codes = [lab[mask] for lab in new_codes]
Expand Down Expand Up @@ -5486,7 +5486,7 @@ def equals(self, other: Any) -> bool:
if not isinstance(other, type(self)):
return False

earr = cast(ExtensionArray, self._data)
earr = cast("ExtensionArray", self._data)
return earr.equals(other._data)

if isinstance(other.dtype, ExtensionDtype):
Expand Down Expand Up @@ -5787,7 +5787,7 @@ def sort_values(
items=self, ascending=ascending, na_position=na_position, key=key
)
else:
idx = cast(Index, ensure_key_mapped(self, key))
idx = cast("Index", ensure_key_mapped(self, key))
_as = idx.argsort(na_position=na_position)
if not ascending:
_as = _as[::-1]
Expand Down Expand Up @@ -7576,7 +7576,7 @@ def ensure_index(index_like: Axes, copy: bool = False) -> Index:
# check in clean_index_list
index_like = list(index_like)

if len(index_like) and lib.is_all_arraylike(index_like):
if index_like and lib.is_all_arraylike(index_like):
from pandas.core.indexes.multi import MultiIndex

return MultiIndex.from_arrays(index_like)
Expand Down
42 changes: 25 additions & 17 deletions pandas/io/formats/style_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,11 @@ class StylerRenderer:
Base class to process rendering a Styler with a specified jinja2 template.
"""

loader = jinja2.PackageLoader("pandas", "io/formats/templates")
import os

loader = jinja2.FileSystemLoader(
os.path.join(os.path.dirname(__file__), "templates")
)
env = jinja2.Environment(loader=loader, trim_blocks=True)
template_html = env.get_template("html.tpl")
template_html_table = env.get_template("html_table.tpl")
Expand Down Expand Up @@ -834,10 +838,7 @@ def _generate_body_row(

data_element = _element(
"td",
(
f"{self.css['data']} {self.css['row']}{r} "
f"{self.css['col']}{c}{cls}"
),
(f"{self.css['data']} {self.css['row']}{r} {self.css['col']}{c}{cls}"),
value,
data_element_visible,
attributes="",
Expand Down Expand Up @@ -956,7 +957,7 @@ def concatenated_visible_rows(obj):
idx_len = d["index_lengths"].get((lvl, r), None)
if idx_len is not None: # i.e. not a sparsified entry
d["clines"][rn + idx_len].append(
f"\\cline{{{lvln+1}-{len(visible_index_levels)+data_len}}}"
f"\\cline{{{lvln + 1}-{len(visible_index_levels) + data_len}}}"
)

def format(
Expand Down Expand Up @@ -1211,7 +1212,7 @@ def format(
data = self.data.loc[subset]

if not isinstance(formatter, dict):
formatter = {col: formatter for col in data.columns}
formatter = dict.fromkeys(data.columns, formatter)

cis = self.columns.get_indexer_for(data.columns)
ris = self.index.get_indexer_for(data.index)
Expand Down Expand Up @@ -1397,7 +1398,7 @@ def format_index(
return self # clear the formatter / revert to default and avoid looping

if not isinstance(formatter, dict):
formatter = {level: formatter for level in levels_}
formatter = dict.fromkeys(levels_, formatter)
else:
formatter = {
obj._get_level_number(level): formatter_
Expand Down Expand Up @@ -1540,7 +1541,7 @@ def relabel_index(

>>> df = pd.DataFrame({"samples": np.random.rand(10)})
>>> styler = df.loc[np.random.randint(0, 10, 3)].style
>>> styler.relabel_index([f"sample{i+1} ({{}})" for i in range(3)])
>>> styler.relabel_index([f"sample{i + 1} ({{}})" for i in range(3)])
... # doctest: +SKIP
samples
sample1 (5) 0.315811
Expand Down Expand Up @@ -1694,7 +1695,7 @@ def format_index_names(
return self # clear the formatter / revert to default and avoid looping

if not isinstance(formatter, dict):
formatter = {level: formatter for level in levels_}
formatter = dict.fromkeys(levels_, formatter)
else:
formatter = {
obj._get_level_number(level): formatter_
Expand Down Expand Up @@ -1814,26 +1815,30 @@ def _get_level_lengths(
levels = index._format_flat(include_name=False)

if hidden_elements is None:
hidden_elements = []
hidden_elements_set = set()
else:
hidden_elements_set = set(hidden_elements)

lengths = {}
if not isinstance(index, MultiIndex):
for i, value in enumerate(levels):
if i not in hidden_elements:
if i not in hidden_elements_set:
lengths[(0, i)] = 1
return lengths

for i, lvl in enumerate(levels):
visible_row_count = 0 # used to break loop due to display trimming
last_label = None
for j, row in enumerate(lvl):
if visible_row_count > max_index:
break
if not sparsify:
# then lengths will always equal 1 since no aggregation.
if j not in hidden_elements:
if j not in hidden_elements_set:
lengths[(i, j)] = 1
visible_row_count += 1
elif (row is not lib.no_default) and (j not in hidden_elements):
elif (row is not lib.no_default) and (j not in hidden_elements_set):
# this element has not been sparsified so must be the start of section
# this element has not been sparsified so must be the start of section
last_label = j
lengths[(i, last_label)] = 1
Expand All @@ -1843,12 +1848,15 @@ def _get_level_lengths(
# later elements are visible
last_label = j
lengths[(i, last_label)] = 0
elif j not in hidden_elements:
elif j not in hidden_elements_set:
# then element must be part of sparsified section and is visible
# then element must be part of sparsified section and is visible
visible_row_count += 1
if visible_row_count > max_index:
break # do not add a length since the render trim limit reached
if lengths[(i, last_label)] == 0:
# Use get to check if lengths exists and is 0; avoid KeyError
if lengths.get((i, last_label), None) == 0:
# if previous iteration was first-of-section but hidden then offset
# if previous iteration was first-of-section but hidden then offset
last_label = j
lengths[(i, last_label)] = 1
Expand Down Expand Up @@ -2503,7 +2511,7 @@ def color(value, user_arg, command, comm_arg):
if value[0] == "#" and len(value) == 7: # color is hex code
return command, f"[HTML]{{{value[1:].upper()}}}{arg}"
if value[0] == "#" and len(value) == 4: # color is short hex code
val = f"{value[1].upper()*2}{value[2].upper()*2}{value[3].upper()*2}"
val = f"{value[1].upper() * 2}{value[2].upper() * 2}{value[3].upper() * 2}"
return command, f"[HTML]{{{val}}}{arg}"
elif value[:3] == "rgb": # color is rgb or rgba
r = re.findall("(?<=\\()[0-9\\s%]+(?=,)", value)[0].strip()
Expand Down