Skip to content

Commit 6f29521

Browse files
feat: return opensearch aggregation top hits (#3059)
* return opensearch aggregation hits * add _aggregation_name --------- Co-authored-by: jaidisido <jaidisido@gmail.com>
1 parent e697f65 commit 6f29521

File tree

2 files changed

+62
-7
lines changed

2 files changed

+62
-7
lines changed

awswrangler/opensearch/_read.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,24 @@ def _hit_to_row(hit: Mapping[str, Any]) -> Mapping[str, Any]:
4141
return row
4242

4343

44-
def _search_response_to_documents(response: Mapping[str, Any]) -> list[Mapping[str, Any]]:
45-
return [_hit_to_row(hit) for hit in response.get("hits", {}).get("hits", [])]
46-
47-
48-
def _search_response_to_df(response: Mapping[str, Any] | Any) -> pd.DataFrame:
49-
return pd.DataFrame(_search_response_to_documents(response))
44+
def _search_response_to_documents(
45+
response: Mapping[str, Any], aggregations: list[str] | None = None
46+
) -> list[Mapping[str, Any]]:
47+
hits = response.get("hits", {}).get("hits", [])
48+
if not hits and aggregations:
49+
hits = [
50+
dict(aggregation_hit, _aggregation_name=aggregation_name)
51+
for aggregation_name in aggregations
52+
for aggregation_hit in response.get("aggregations", {})
53+
.get(aggregation_name, {})
54+
.get("hits", {})
55+
.get("hits", [])
56+
]
57+
return [_hit_to_row(hit) for hit in hits]
58+
59+
60+
def _search_response_to_df(response: Mapping[str, Any] | Any, aggregations: list[str] | None = None) -> pd.DataFrame:
61+
return pd.DataFrame(_search_response_to_documents(response=response, aggregations=aggregations))
5062

5163

5264
@_utils.check_optional_dependency(opensearchpy, "opensearchpy")
@@ -128,8 +140,16 @@ def search(
128140
documents = [_hit_to_row(doc) for doc in documents_generator]
129141
df = pd.DataFrame(documents)
130142
else:
143+
aggregations = (
144+
list(search_body.get("aggregations", {}).keys() or search_body.get("aggs", {}).keys())
145+
if search_body
146+
else None
147+
)
131148
response = client.search(index=index, body=search_body, filter_path=filter_path, **kwargs)
132-
df = _search_response_to_df(response)
149+
df = _search_response_to_df(
150+
response=response,
151+
aggregations=aggregations,
152+
)
133153
return df
134154

135155

tests/unit/test_opensearch.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,41 @@ def test_search_scroll(client):
424424
wr.opensearch.delete_index(client, index)
425425

426426

427+
def test_search_aggregation(client):
428+
index = f"test_search_agg_{_get_unique_suffix()}"
429+
kwargs = {} if _is_serverless(client) else {"refresh": "wait_for"}
430+
try:
431+
wr.opensearch.index_documents(
432+
client,
433+
documents=inspections_documents,
434+
index=index,
435+
id_keys=["inspection_id"],
436+
**kwargs,
437+
)
438+
if _is_serverless(client):
439+
# The refresh interval for OpenSearch Serverless is between 10 and 30 seconds
440+
# depending on the size of the request.
441+
time.sleep(30)
442+
df = wr.opensearch.search(
443+
client,
444+
index=index,
445+
search_body={
446+
"aggregations": {
447+
"latest_inspections": {"top_hits": {"sort": [{"inspection_date": {"order": "asc"}}], "size": 1}},
448+
"lowest_inspection_score": {
449+
"top_hits": {"sort": [{"inspection_score": {"order": "asc"}}], "size": 1}
450+
},
451+
}
452+
},
453+
filter_path=["aggregations"],
454+
)
455+
assert df.shape[0] == 2
456+
assert len(df.loc[df["_aggregation_name"] == "latest_inspections"]) == 1
457+
assert len(df.loc[df["_aggregation_name"] == "lowest_inspection_score"]) == 1
458+
finally:
459+
wr.opensearch.delete_index(client, index)
460+
461+
427462
@pytest.mark.parametrize("fetch_size", [None, 1000, 10000])
428463
@pytest.mark.parametrize("fetch_size_param_name", ["size", "fetch_size"])
429464
def test_search_sql(client, fetch_size, fetch_size_param_name):

0 commit comments

Comments
 (0)