Skip to content

Commit 0a63def

Browse files
feat: nested class for temporal annotations support
1 parent 58b30f7 commit 0a63def

File tree

1 file changed

+175
-63
lines changed
  • libs/labelbox/src/labelbox/data/serialization/ndjson

1 file changed

+175
-63
lines changed

libs/labelbox/src/labelbox/data/serialization/ndjson/label.py

Lines changed: 175 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -169,77 +169,189 @@ def _create_video_annotations(
169169
def _create_audio_annotations(
    cls, label: Label
) -> Generator[BaseModel, None, None]:
    """Create NDJSON audio annotations with nested classifications (v3-like),
    while preserving v2 behavior for non-nested cases.

    Strategy:
    - Group audio annotations by classification (feature_schema_id or name)
    - Identify root groups (those not fully contained, frame-wise, by
      annotations of another group)
    - For each root group, build answer entries grouped by value, each
      carrying its list of ``{"start", "end"}`` frames
    - Recursively attach nested classifications by time containment

    Args:
        label: Label whose annotations are scanned for
            ``AudioClassificationAnnotation`` instances.

    Yields:
        One pydantic model per root classification group, shaped as
        ``{"name": ..., "answer": [...], "dataRow": {"globalKey": ...}}``.
    """
    # 1) Collect all audio annotations grouped by classification key.
    #    Use feature_schema_id when present, otherwise fall back to name.
    audio_by_group: Dict[str, List[AudioClassificationAnnotation]] = defaultdict(list)
    for annot in label.annotations:
        if isinstance(annot, AudioClassificationAnnotation):
            audio_by_group[annot.feature_schema_id or annot.name].append(annot)

    if not audio_by_group:
        return

    def group_display_name(
        group_key: str, anns: List[AudioClassificationAnnotation]
    ) -> str:
        """User-facing name for a group: first non-empty annotation name,
        falling back to the group key (which may be a schema id)."""
        for a in anns:
            if a.name:
                return a.name
        return group_key

    def is_group_nested(group_key: str) -> bool:
        """True when EVERY annotation of the group is contained, frame-wise,
        by some annotation of another group."""
        for ann in audio_by_group[group_key]:
            contained = False
            for other_key, other_anns in audio_by_group.items():
                if other_key == group_key:
                    continue
                for parent in other_anns:
                    # Open-ended ranges (end_frame is None) never contain
                    # and are never contained.
                    if parent.start_frame <= ann.start_frame and (
                        parent.end_frame is not None
                        and ann.end_frame is not None
                        and parent.end_frame >= ann.end_frame
                    ):
                        contained = True
                        break
                if contained:
                    break
            if not contained:
                # Any uncontained annotation makes the whole group a root.
                return False
        return True

    def group_by_value(
        annotations: List[AudioClassificationAnnotation],
    ) -> List[Dict[str, Any]]:
        """Bucket annotations by logical value and emit one answer entry per
        value, with every occurrence's frame range collected into one list."""
        value_buckets: Dict[str, List[AudioClassificationAnnotation]] = defaultdict(list)
        for ann in annotations:
            # Compute the grouping key depending on classification type.
            if hasattr(ann.value, "answer"):
                if isinstance(ann.value.answer, list):
                    # Checklist: stable key from the selected option names
                    key = str(sorted(opt.name for opt in ann.value.answer))
                elif hasattr(ann.value.answer, "name"):
                    # Radio: the option name
                    key = ann.value.answer.name
                else:
                    # Text: the string value itself
                    key = ann.value.answer
            else:
                key = str(ann.value)
            value_buckets[key].append(ann)

        entries: List[Dict[str, Any]] = []
        for bucket in value_buckets.values():
            first = bucket[0]
            frames = [{"start": a.start_frame, "end": a.end_frame} for a in bucket]

            if hasattr(first.value, "answer"):
                if isinstance(first.value.answer, list):
                    # Checklist: one entry per selected option, sharing frames
                    for opt_name in sorted(o.name for o in first.value.answer):
                        entries.append({"name": opt_name, "frames": frames})
                elif hasattr(first.value.answer, "name"):
                    # Radio
                    entries.append(
                        {"name": first.value.answer.name, "frames": frames}
                    )
                else:
                    # Text
                    entries.append(
                        {"value": first.value.answer, "frames": frames}
                    )
            else:
                # Fix: values without an `.answer` attribute previously fell
                # into the Text branch and dereferenced `first.value.answer`,
                # raising AttributeError. Emit the stringified value (the
                # same form used as the grouping key) instead.
                entries.append({"value": str(first.value), "frames": frames})
        return entries

    def ann_within_frames(
        ann: AudioClassificationAnnotation, frames: List[Dict[str, int]]
    ) -> bool:
        """True when ann's [start_frame, end_frame] lies inside any of the
        given frame ranges (open-ended ranges never match)."""
        for fr in frames:
            if fr["start"] <= ann.start_frame and (
                ann.end_frame is not None
                and fr["end"] is not None
                and fr["end"] >= ann.end_frame
            ):
                return True
        return False

    def build_nested_for_frames(
        parent_frames: List[Dict[str, int]],
        exclude_groups: frozenset,
    ) -> List[Dict[str, Any]]:
        """Recursively build the nested-classification list for one parent
        answer entry, considering only annotations whose time range falls
        inside ``parent_frames``.

        ``exclude_groups`` carries every group key already on the current
        nesting path. Fix: the previous version excluded only the immediate
        parent group, so two non-root groups spanning identical frame ranges
        nested each other forever (RecursionError); the cumulative set makes
        the recursion terminate while leaving acyclic cases unchanged.
        """
        nested: List[Dict[str, Any]] = []

        # All annotations inside the parent frames across non-excluded
        # groups; used below to keep only immediate children.
        all_contained: List[AudioClassificationAnnotation] = []
        for gk, ga in audio_by_group.items():
            if gk in exclude_groups:
                continue
            all_contained.extend(
                a for a in ga if ann_within_frames(a, parent_frames)
            )

        def strictly_contains(
            container: AudioClassificationAnnotation,
            inner: AudioClassificationAnnotation,
        ) -> bool:
            """True when container's range strictly contains inner's."""
            if container is inner:
                return False
            if container.end_frame is None or inner.end_frame is None:
                return False
            return (
                container.start_frame <= inner.start_frame
                and container.end_frame >= inner.end_frame
                and (
                    container.start_frame < inner.start_frame
                    or container.end_frame > inner.end_frame
                )
            )

        for group_key, anns in audio_by_group.items():
            if group_key in exclude_groups:
                continue
            # Root groups are emitted at top level; never duplicate them
            # inside another group.
            if group_key in root_group_keys:
                continue

            candidate_anns = [
                a for a in anns if ann_within_frames(a, parent_frames)
            ]
            if not candidate_anns:
                continue

            # Immediate children only: drop annotations for which a closer
            # (strictly containing) annotation exists within the parent.
            child_anns = [
                a
                for a in candidate_anns
                if not any(strictly_contains(b, a) for b in all_contained)
            ]
            if not child_anns:
                continue

            child_entries = group_by_value(child_anns)
            # Recurse per answer entry to attach deeper nesting levels.
            for entry in child_entries:
                child_nested = build_nested_for_frames(
                    entry.get("frames", []),
                    exclude_groups | {group_key},
                )
                if child_nested:
                    entry["classifications"] = child_nested

            nested.append(
                {
                    "name": group_display_name(group_key, anns),
                    "answer": child_entries,
                }
            )

        return nested

    # 2) Determine root groups (not fully contained by other groups).
    root_group_keys = [k for k in audio_by_group if not is_group_nested(k)]
    # Fix: when every group is "nested" (e.g. two groups spanning identical
    # frame ranges mutually contain each other) there were no roots and every
    # annotation was silently dropped; fall back to flat v2 behavior instead.
    if not root_group_keys:
        root_group_keys = list(audio_by_group)

    # 3) Emit one NDJSON object per root classification group.
    class AudioNDJSON(BaseModel):
        name: str
        answer: List[Dict[str, Any]]
        dataRow: Dict[str, str]

    for group_key in root_group_keys:
        anns = audio_by_group[group_key]
        top_entries = group_by_value(anns)

        # Attach nested classifications to each top-level answer entry.
        for entry in top_entries:
            children = build_nested_for_frames(
                entry.get("frames", []), frozenset({group_key})
            )
            if children:
                entry["classifications"] = children

        # NOTE(review): assumes label.data.global_key is a non-None str —
        # pydantic will reject None for Dict[str, str]; confirm upstream.
        yield AudioNDJSON(
            name=group_display_name(group_key, anns),
            answer=top_entries,
            dataRow={"globalKey": label.data.global_key},
        )
244356

245357

0 commit comments

Comments
 (0)