Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 37 additions & 20 deletions src/bratutils/agreement.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,10 @@ def __init__(self, a):
:type: str
"""
self.text = None
self.start_idx = None
self.end_idx = None
self.frag = None # fragment, usually one pair of indexes, but annotation can be
# composed of several fragments (aka discontinuated annotations)
self.start_idx = None # start index of first fragment
self.end_idx = None # end index of last fragment
self.tag_name = None
self.partial_match = None

Expand All @@ -276,18 +278,23 @@ def __init__(self, a):
self.border_status = False
self.border_match = None

self.text, self.tag_name, self.start_idx, self.end_idx = \
self.text, self.tag_name, self.start_idx, self.end_idx, self.frag = \
self._parse_annotation(a)

@staticmethod
def _parse_annotation(a):
items = a.split("\t")
text = items[2].strip("\n").strip(" ")
subitems = items[1].split(" ")
tag_name = subitems[0]
start_idx = int(subitems[1])
end_idx = int(subitems[2])
return text, tag_name, start_idx, end_idx
tag_name = subitems.pop(0)
subitems = " ".join(subitems).split(";")
frag = []
for idx in subitems:
start_idx, end_idx = idx.split(" ")
frag.append((int(start_idx), int(end_idx)))
start_idx = frag[0][0]
end_idx = frag[len(frag)-1][1]
return text, tag_name, start_idx, end_idx, frag

def reset_markers(self):
"""Resets the comparison marker attributes to default values. The
Expand Down Expand Up @@ -370,8 +377,9 @@ def coincides_with(self, parallel_ann):
:return: True if objects coincide
:rtype: bool
"""
return (self.start_idx == parallel_ann.start_idx and
self.end_idx == parallel_ann.end_idx)

return self.frag == parallel_ann.frag


def contains_ann(self, other_ann):
"""Checks if this object's annotation contains another object's
Expand All @@ -381,8 +389,15 @@ def contains_ann(self, other_ann):
:return: True if this annotaion contains the other annotation
:rtype: bool
"""
return (other_ann.start_idx >= self.start_idx and
other_ann.end_idx <= self.end_idx)
contained_fragments = [False] * len(other_ann.frag)
for i in range(len(other_ann.frag)):
for j in range(len(self.frag)):
if other_ann.frag[i][0] >= self.frag[j][0] and \
other_ann.frag[i][1] <= self.frag[j][1]:
contained_fragments[i] = True
# return True if all fragments of other_ann are contained in self
return contained_fragments == [True] * len(other_ann.frag)


def is_contained_by(self, parallel_ann):
"""Checks if this annotation is contained by a parallel annotation.
Expand All @@ -391,8 +406,8 @@ def is_contained_by(self, parallel_ann):
:return: True if contained in `parallel_ann`
:rtype: bool
"""
return (parallel_ann.start_idx <= self.start_idx and
parallel_ann.end_idx >= self.end_idx)
return self.contains_ann(parallel_ann)


def is_partial_to(self, parallel_ann):
"""Returns `True` if the annotation is a partial match to the parallel
Expand All @@ -404,8 +419,9 @@ def is_partial_to(self, parallel_ann):
:param parallel_ann:
:return:
"""
return (self.start_idx > parallel_ann.start_idx and
self.end_idx == parallel_ann.end_idx and
# TODO really dive into frag (for now, we check start of first fragment and end of last fragment)
return (self.start_idx > parallel_ann.end_idx and
self.start_idx == parallel_ann.end_idx and
self.tag_name == parallel_ann.tag_name)

def get_same_anns(self, parallel_anns):
Expand Down Expand Up @@ -548,15 +564,16 @@ def __eq__(self, ann):
return (self.text == ann.text and
self.start_idx == ann.start_idx and
self.end_idx == ann.end_idx and
self.frag == ann.frag and
self.tag_name == ann.tag_name)

def __str__(self):
atts = [self.tag_name, str(self.start_idx), str(self.end_idx),
atts = [self.tag_name, str(self.start_idx), str(self.end_idx), str(self.frag),
self.text]
return " ".join(atts)

def __repr__(self):
atts = [self.tag_name, str(self.start_idx), str(self.end_idx),
atts = [self.tag_name, str(self.start_idx), str(self.end_idx), str(self.frag),
self.text]
return " ".join(atts)

Expand Down Expand Up @@ -640,13 +657,13 @@ def __init__(self, fp=None, ann_list=None):
self.basename = ""
if fp:
self.basename = os.path.basename(fp)
with open(fp) as doc:
with open(fp, encoding='utf-8') as doc:
for line in doc:
if not line.startswith("#"):
if not line.startswith("#") and not line.startswith("A"): # ignoring Attributes
self.tags.append(Annotation(line))
elif ann_list:
for line in ann_list:
if not line.startswith("#"):
if not line.startswith("#") and not line.startswith("A"):
self.tags.append(Annotation(line))
else:
self.tags = []
Expand Down