File tree

7 files changed

+222
-11
lines changed

7 files changed

+222
-11
lines changed
Original file line numberDiff line numberDiff line change
@@ -2054,13 +2054,14 @@ def to_sql_query(
20542054
idx_labels,
20552055
)
20562056

2057-
def cached(self, *, optimize_offsets=False, force: bool = False) -> None:
2057+
# Three strategies,
2058+
def cached(self, *, force: bool = False, session_aware: bool = False) -> None:
20582059
"""Write the block to a session table."""
20592060
# use a heuristic for whether something needs to be cached
20602061
if (not force) and self.session._is_trivially_executable(self.expr):
20612062
return
2062-
if optimize_offsets:
2063-
self.session._cache_with_offsets(self.expr)
2063+
elif session_aware:
2064+
self.session._session_aware_caching(self.expr)
20642065
else:
20652066
self.session._cache_with_cluster_cols(
20662067
self.expr, cluster_cols=self.index_columns
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Sequence
16+
17+
import bigframes.core.expression as ex
18+
import bigframes.operations as ops
19+
20+
# Classes of the comparison operator objects exported by ``bigframes.operations``.
# Kept as a tuple so it can be handed straight to ``isinstance``.
COMPARISON_OP_TYPES = tuple(
    map(
        type,
        (
            ops.eq_op,
            ops.eq_null_match_op,
            ops.ne_op,
            ops.gt_op,
            ops.ge_op,
            ops.lt_op,
            ops.le_op,
        ),
    )
)
32+
33+
34+
def cluster_cols_for_predicate(predicate: ex.Expression) -> Sequence[str]:
    """Collect candidate cluster column ids referenced by a filter predicate.

    * A bare variable yields its own id.
    * A comparison is delegated to ``cluster_cols_for_comparison``.
    * An inversion recurses into its single operand.
    * ``and``/``or`` combine both sides: left candidates first, followed by
      right candidates that do not already appear on the left.

    Every other expression, constants included, yields no candidates.
    """
    if isinstance(predicate, ex.UnboundVariableExpression):
        return [predicate.id]
    if not isinstance(predicate, ex.OpExpression):
        # Constant
        return []

    operator = predicate.op
    operands = predicate.inputs
    if isinstance(operator, COMPARISON_OP_TYPES):
        return cluster_cols_for_comparison(operands[0], operands[1])
    if isinstance(operator, type(ops.invert_op)):
        return cluster_cols_for_predicate(operands[0])
    if isinstance(operator, (type(ops.and_op), type(ops.or_op))):
        from_left = cluster_cols_for_predicate(operands[0])
        from_right = cluster_cols_for_predicate(operands[1])
        merged = list(from_left)
        merged.extend(col for col in from_right if col not in from_left)
        return merged
    return []
53+
54+
55+
def cluster_cols_for_comparison(
    left_ex: ex.Expression, right_ex: ex.Expression
) -> Sequence[str]:
    """Return the column id compared against a constant, if there is one.

    When one operand is constant and the opposite operand is a plain variable,
    that variable's id is returned as the only candidate. In every other case
    the result is empty.
    """
    # The operand opposite the constant is the only possible candidate.
    if left_ex.is_const:
        candidate = right_ex
    elif right_ex.is_const:
        candidate = left_ex
    else:
        return []

    if isinstance(candidate, ex.UnboundVariableExpression):
        return [candidate.id]
    return []
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import functools
1717
import itertools
18-
from typing import Callable, Dict, Optional
18+
from typing import Callable, Dict, Optional, Sequence
1919

2020
import bigframes.core.nodes as nodes
2121

@@ -91,6 +91,30 @@ def _node_counts_inner(
9191
)
9292

9393

94+
def count_nodes(forest: Sequence[nodes.BigFrameNode]) -> dict[nodes.BigFrameNode, int]:
    """Count how often each node occurs across all trees in ``forest``.

    A node is counted once for every position it occupies: a subtree that is
    reachable through two different parents, or that appears under two roots,
    contributes twice. Nodes must be hashable and expose ``child_nodes``.

    Args:
        forest: Root nodes of the trees to inspect.

    Returns:
        A plain dict mapping every node found to its number of occurrences.
        An empty forest gives an empty dict.
    """
    import collections

    @functools.cache
    def _subtree_counts(
        subtree: nodes.BigFrameNode,
    ) -> Dict[nodes.BigFrameNode, int]:
        """Occurrences of every node within ``subtree``, the subtree root included."""
        tally: collections.Counter = collections.Counter()
        for child in subtree.child_nodes:
            tally.update(_subtree_counts(child))
        tally[subtree] += 1
        # Cached results are only ever read (copied into other counters),
        # never mutated after this point.
        return tally

    totals: collections.Counter = collections.Counter()
    for root in forest:
        totals.update(_subtree_counts(root))
    # Hand back a plain dict so a missing key behaves like a normal dict lookup.
    return dict(totals)
116+
117+
94118
def replace_nodes(
95119
root: nodes.BigFrameNode,
96120
replacements: dict[nodes.BigFrameNode, nodes.BigFrameNode],
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ def order_preserving(self) -> bool:
6060
"""Whether the row operation preserves total ordering. Can be pruned from ordering expressions."""
6161
return False
6262

63+
@property
def pruning_compatible(self) -> bool:
    """Whether the operation preserves locality of its input values.

    NOTE(review): the original docstring was cut off mid-word
    ("...preserves locality o"); the sentence above is a best-guess
    completion and should be confirmed by the author.

    This implementation always returns ``False``.
    """
    return False
67+
6368

6469
@dataclasses.dataclass(frozen=True)
6570
class NaryOp(ScalarOp):
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,39 @@ def head(self, n: int = 5) -> Series:
617617
def tail(self, n: int = 5) -> Series:
618618
return typing.cast(Series, self.iloc[-n:])
619619

620+
def peek(self, n: int = 5, *, force: bool = True) -> pandas.Series:
    """
    Preview n arbitrary elements from the series. No guarantees about row selection or ordering.
    ``Series.peek(force=False)`` will always be very fast, but will not succeed if data requires
    full data scanning. Using ``force=True`` will always succeed, but may perform queries.
    Query results will be cached so that future steps will benefit from these queries.

    Args:
        n (int, default 5):
            The number of rows to select from the series. Which N rows are returned is non-deterministic.
        force (bool, default True):
            If the data cannot be peeked efficiently, the series will instead be fully materialized as part
            of the operation if ``force=True``. If ``force=False``, the operation will throw a ValueError.

    Returns:
        pandas.Series: A pandas Series with n rows.

    Raises:
        ValueError: If force=False and data cannot be efficiently peeked.
    """
    # First try the cheap path; the block returns None when it cannot peek.
    maybe_result = self._block.try_peek(n)
    if maybe_result is None:
        if not force:
            raise ValueError(
                "Cannot peek efficiently when data has aggregates, joins or window functions applied. Use force=True to fully compute dataframe."
            )
        # Materialize the series, then peek again against the cached data.
        self._cached()
        maybe_result = self._block.try_peek(n, force=True)
        assert maybe_result is not None

    # The block hands back a frame; collapse its column axis into a Series
    # and restore this series' name.
    as_series = maybe_result.squeeze(axis=1)
    as_series.name = self.name
    return as_series
652+
620653
def nlargest(self, n: int = 5, keep: str = "first") -> Series:
621654
if keep not in ("first", "last", "all"):
622655
raise ValueError("'keep must be one of 'first', 'last', or 'all'")
@@ -1400,7 +1433,7 @@ def apply(
14001433

14011434
# return Series with materialized result so that any error in the remote
14021435
# function is caught early
1403-
materialized_series = result_series._cached()
1436+
materialized_series = result_series._cached(session_aware=False)
14041437
return materialized_series
14051438

14061439
def combine(
@@ -1775,10 +1808,11 @@ def cache(self):
17751808
Returns:
17761809
Series: Self
17771810
"""
1778-
return self._cached(force=True)
1811+
# Do not use session-aware caching if user-requested
1812+
return self._cached(force=True, session_aware=False)
17791813

1780-
def _cached(self, *, force: bool = True, session_aware: bool = True) -> Series:
    """Cache the underlying block and hand back this same series.

    Both keyword flags are forwarded unchanged to ``self._block.cached``.
    """
    block = self._block
    block.cached(force=force, session_aware=session_aware)
    return self
17831817

17841818
def _optimize_query_complexity(self):
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
from __future__ import annotations
1818

19-
import collections.abc
2019
import copy
2120
import datetime
2221
import logging
@@ -81,10 +80,12 @@
8180
import bigframes.core as core
8281
import bigframes.core.blocks as blocks
8382
import bigframes.core.compile
83+
import bigframes.core.expression as ex
8484
import bigframes.core.guid
8585
import bigframes.core.nodes as nodes
8686
from bigframes.core.ordering import IntegerEncoding
8787
import bigframes.core.ordering as order
88+
import bigframes.core.pruning
8889
import bigframes.core.tree_properties as traversals
8990
import bigframes.core.tree_properties as tree_properties
9091
import bigframes.core.utils as utils
@@ -326,13 +327,15 @@ def session_id(self):
326327
@property
def objects(
    self,
) -> Tuple[
    Union[
        bigframes.core.indexes.Index, bigframes.series.Series, dataframe.DataFrame
    ],
    ...,
]:
    """Strong references to every tracked object that is still alive.

    Dead weak references are dropped from ``self._objects`` as a side effect.
    The returned tuple holds strong references: do not keep it around
    needlessly, as that prevents garbage collection of the objects.
    """
    # Dereference each weak reference exactly once, so an object collected
    # part-way through cannot show up as ``None`` in the result.
    resolved = [(ref, ref()) for ref in self._objects]
    alive = [(ref, obj) for ref, obj in resolved if obj is not None]
    self._objects = [ref for ref, _ in alive]
    return tuple(obj for _, obj in alive)  # type: ignore
336339

337340
@property
338341
def _project(self):
@@ -1913,6 +1916,51 @@ def _cache_with_offsets(self, array_value: core.ArrayValue):
19131916
).node
19141917
self._cached_executions[array_value.node] = cached_replacement
19151918

1919+
def _session_aware_caching(self, array_value: core.ArrayValue) -> None:
    """Cache ``array_value`` or a more widely shared node beneath it.

    Starting at the root node of ``array_value``, walks down through filter
    and projection nodes. Whenever the child reached occurs more often across
    the session's live objects than the current candidate, it becomes the new
    caching target. Filter predicates met on the way down are collected, and
    re-bound with ``bind_all_variables`` each time a projection is crossed.
    The first collected predicate that yields a cluster column candidate
    decides how the target is cached: with that cluster column if one was
    found, otherwise with offsets.

    Args:
        array_value: The expression for which caching was requested.
    """
    # Occurrence count of every node across the whole session.
    node_counts = traversals.count_nodes(
        [obj._block.expr.node for obj in self.objects]
    )
    de_cachable_types = (nodes.FilterNode, nodes.ProjectionNode)
    caching_target = array_value.node
    caching_target_count = node_counts.get(caching_target, 0)

    cur_node = array_value.node

    # Filters accumulated while traversing downwards.
    filters: list[ex.Expression] = []
    cluster_col: Optional[str] = None
    while isinstance(cur_node, de_cachable_types):
        if isinstance(cur_node, nodes.FilterNode):
            filters.append(cur_node.predicate)
        if isinstance(cur_node, nodes.ProjectionNode):
            bindings = {name: expr for expr, name in cur_node.assignments}
            filters = [i.bind_all_variables(bindings) for i in filters]

        cur_node = cur_node.child
        cur_node_count = node_counts.get(cur_node, 0)
        if cur_node_count > caching_target_count:
            caching_target, caching_target_count = cur_node, cur_node_count
            # A cluster column chosen for a previous target does not carry over.
            cluster_col = None
            # Just pick the first cluster-compatible predicate.
            for predicate in filters:
                # Cluster cols only consider the target object and not other
                # session objects.
                cluster_cols = bigframes.core.pruning.cluster_cols_for_predicate(
                    predicate
                )
                if cluster_cols:
                    cluster_col = cluster_cols[0]
                    # Stop at the first match; carrying on would let a later
                    # predicate overwrite the choice.
                    break

    if cluster_col:
        self._cache_with_cluster_cols(
            core.ArrayValue(caching_target), [cluster_col]
        )
    else:
        self._cache_with_offsets(core.ArrayValue(caching_target))
1963+
19161964
def _simplify_with_caching(self, array_value: core.ArrayValue):
19171965
"""Attempts to handle the complexity by caching duplicated subtrees and breaking the query into pieces."""
19181966
# Apply existing caching first
Original file line numberDiff line numberDiff line change
@@ -1936,6 +1936,41 @@ def test_head_then_series_operation(scalars_dfs):
19361936
)
19371937

19381938

1939+
def test_series_peek(scalars_dfs):
    """Peeking a plain column with ``force=False`` returns rows that match pandas."""
    bf_df, pd_df = scalars_dfs
    peeked = bf_df["float64_col"].peek(n=3, force=False)
    # Which rows come back is arbitrary, so align pandas to the peeked index.
    expected = pd_df["float64_col"].reindex_like(peeked)
    pd.testing.assert_series_equal(peeked, expected)
1946+
1947+
1948+
def test_series_peek_filtered(scalars_dfs):
    """Peeking a row-filtered column with ``force=False`` matches pandas."""
    bf_df, pd_df = scalars_dfs
    bf_filtered = bf_df[bf_df.int64_col > 0]["float64_col"]
    peeked = bf_filtered.peek(n=3, force=False)
    pd_filtered = pd_df[pd_df.int64_col > 0]["float64_col"]
    # Which rows come back is arbitrary, so align pandas to the peeked index.
    expected = pd_filtered.reindex_like(peeked)
    pd.testing.assert_series_equal(peeked, expected)
1958+
1959+
1960+
def test_series_peek_force(scalars_dfs):
    """Peeking with ``force=True`` after a cumsum and a filter matches pandas."""
    bf_df, pd_df = scalars_dfs
    columns = ["int64_col", "int64_too"]

    bf_cumsum = bf_df[columns].cumsum()
    bf_series = bf_cumsum[bf_cumsum.int64_col > 0]["int64_too"]
    peeked = bf_series.peek(n=3, force=True)

    pd_cumsum = pd_df[columns].cumsum()
    pd_series = pd_cumsum[pd_cumsum.int64_col > 0]["int64_too"]
    # Which rows come back is arbitrary, so align pandas to the peeked index.
    expected = pd_series.reindex_like(peeked)
    pd.testing.assert_series_equal(peeked, expected)
1972+
1973+
19391974
def test_shift(scalars_df_index, scalars_pandas_df_index):
19401975
col_name = "int64_col"
19411976
bf_result = scalars_df_index[col_name].shift().to_pandas()

0 commit comments

Comments
 (0)