Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions udapi/core/coref.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def span(self, new_span):
self.words = span_to_nodes(self._head.root, new_span)


@functools.total_ordering
class CorefCluster(object):
"""Class for representing all mentions of a given entity."""
__slots__ = ['_cluster_id', '_mentions', 'cluster_type', 'split_ante']
Expand All @@ -109,6 +110,20 @@ def __init__(self, cluster_id, cluster_type=None):
self.cluster_type = cluster_type
self.split_ante = []

def __lt__(self, other):
"""Does this CorefCluster precedes (word-order wise) the `other` cluster?

This method defines a total ordering of all clusters
by the first mention of each cluster (see `CorefMention.__lt__`).
Only if one of the clusters has no mentions (which should not happen normally),
the ordering is defined by the `cluster_id` (lexicographically).
If cluster IDs are not important, it is recommended to use block
`corefud.IndexClusters` to re-name cluster IDs in accordance with this cluster ordering.
"""
if not self.mentions or not other.mentions:
return self._cluster_id < other._cluster_id
return self.mentions[0] < other.mentions[0]

@property
def cluster_id(self):
return self._cluster_id
Expand Down Expand Up @@ -299,7 +314,7 @@ def load_coref_from_misc(doc):
index += 1
index_str = f"[{index}]"
cluster_id = node.misc["ClusterId" + index_str]
doc._coref_clusters = clusters
doc._coref_clusters = {k: clusters[k] for k in sorted(clusters)}


def store_coref_to_misc(doc):
Expand All @@ -310,7 +325,7 @@ def store_coref_to_misc(doc):
for key in list(node.misc):
if any(re.match(attr + r'(\[\d+\])?$', key) for attr in attrs):
del node.misc[key]
for cluster in doc._coref_clusters.values():
for cluster in sorted(doc._coref_clusters.values()):
for mention in cluster.mentions:
head = mention.head
if head.misc["ClusterId"]:
Expand Down