Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 31 additions & 29 deletions udapi/core/coref.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,42 +180,44 @@ def create_coref_cluster(head, cluster_id=None, cluster_type=None, **kwargs):

def load_coref_from_misc(doc):
    """Build coreference clusters from the MISC attributes of all nodes in `doc`.

    Every node of every tree (taken from `tree.descendants_and_empty`) is
    visited exactly once.  A node may carry several mentions: the first one is
    stored either without a suffix (`ClusterId`) or with the suffix `[1]`
    (`ClusterId[1]`), the following ones with `[2]`, `[3]`, ...  The same
    suffix is used for the accompanying `MentionSpan`, `ClusterType`,
    `Bridging` and `SplitAnte` attributes.

    For each mention a `CorefMention` is created on the node and attached to
    the `CorefCluster` with the given id (created on first use).  The resulting
    `{cluster_id: cluster}` dict is stored in `doc._coref_clusters`; nothing is
    returned.
    """
    clusters = {}
    # A single traversal only: iterating `doc.nodes` in addition to the trees
    # would visit every non-empty node twice and create duplicate mentions.
    for tree in doc.trees:
        for node in tree.descendants_and_empty:
            misc = node.misc
            index, index_str = 0, ""
            cluster_id = misc["ClusterId"]
            if not cluster_id:
                # No unsuffixed attribute, so try the indexed variant.
                index, index_str = 1, "[1]"
                cluster_id = misc["ClusterId[1]"]
            while cluster_id:
                cluster = clusters.get(cluster_id)
                if cluster is None:
                    cluster = CorefCluster(cluster_id)
                    clusters[cluster_id] = cluster
                mention = CorefMention(node, cluster)
                span = misc["MentionSpan" + index_str]
                if span:
                    mention.span = span
                cluster_type = misc["ClusterType" + index_str]
                # NOTE(review): this assumes `misc[...]` yields None for a
                # missing key; if it yields "" instead, this branch always
                # runs -- confirm against the MISC container implementation.
                if cluster_type is not None:
                    if cluster.cluster_type is not None and cluster_type != cluster.cluster_type:
                        logging.warning(f"cluster_type mismatch in {node}: {cluster.cluster_type} != {cluster_type}")
                    cluster.cluster_type = cluster_type
                # TODO deserialize Bridging and SplitAnte
                # For now the raw MISC strings are stored as they are.
                mention._bridging = misc["Bridging" + index_str]
                cluster._split_ante = misc["SplitAnte" + index_str]
                # Move on to the next mention stored on this node, if any.
                index += 1
                index_str = f"[{index}]"
                cluster_id = misc["ClusterId" + index_str]
    doc._coref_clusters = clusters


def store_coref_to_misc(doc):
if not doc._coref_clusters:
return
attrs = ("ClusterId", "MentionSpan", "ClusterType", "Bridging", "SplitAnte")
for node in doc.nodes:
for key in list(node.misc):
if any(re.match(attr + r'(\[\d+\])?$', key) for attr in attrs):
del node.misc[key]
for tree in doc.trees:
for node in tree.descendants_and_empty:
for key in list(node.misc):
if any(re.match(attr + r'(\[\d+\])?$', key) for attr in attrs):
del node.misc[key]
for cluster in doc._coref_clusters.values():
for mention in cluster.mentions:
head = mention.head
Expand Down