Remove unused functions

This commit is contained in:
Paul O'Leary McCann 2022-03-16 14:38:11 +09:00
parent d0ae2590db
commit 5650853c0f

View File

@ -56,21 +56,6 @@ def doc2clusters(doc: Doc, prefix=DEFAULT_CLUSTER_PREFIX) -> MentionClusters:
return out return out
def topk(xp, arr, k, axis=1):
"""Given an array and a k value, give the top values and idxs for each row."""
part = xp.argpartition(arr, -k, axis=axis)
idxs = xp.flip(part)[:, :k]
vals = xp.take_along_axis(arr, idxs, axis=axis)
sidxs = xp.argsort(-vals, axis=axis)
# map these idxs back to the original
oidxs = xp.take_along_axis(idxs, sidxs, axis=axis)
svals = xp.take_along_axis(vals, sidxs, axis=axis)
return svals, oidxs
# from model.py, refactored to be non-member # from model.py, refactored to be non-member
def get_predicted_antecedents(xp, antecedent_idx, antecedent_scores): def get_predicted_antecedents(xp, antecedent_idx, antecedent_scores):
"""Get the ID of the antecedent for each span. -1 if no antecedent.""" """Get the ID of the antecedent for each span. -1 if no antecedent."""
@ -124,55 +109,6 @@ def get_predicted_clusters(
return predicted_clusters return predicted_clusters
def get_sentence_map(doc: Doc):
"""For the given span, return a list of sentence indexes."""
if doc.has_annotation("SENT_START"):
si = 0
out = []
for sent in doc.sents:
for _ in sent:
out.append(si)
si += 1
return out
else:
# If there are no sents then just return dummy values.
# Shouldn't happen in general training, but typical in init.
return [0] * len(doc)
def get_candidate_mentions(
doc: Doc, max_span_width: int = 20
) -> Tuple[List[int], List[int]]:
"""Given a Doc, return candidate mentions.
This isn't a trainable layer, it just returns raw candidates.
"""
# XXX Note that in coref-hoi the indexes are designed so you actually want [i:j+1], but here
# we're using [i:j], which is more natural.
sentence_map = get_sentence_map(doc)
begins = []
ends = []
for tok in doc:
si = sentence_map[tok.i] # sentence index
for ii in range(1, max_span_width):
ei = tok.i + ii # end index
# Note: this matches slice syntax, so the token index is one less
if ei > len(doc) or sentence_map[ei - 1] != si:
break
begins.append(tok.i)
ends.append(ei)
return (begins, ends)
@registry.misc("spacy.CorefCandidateGenerator.v1")
def create_mention_generator() -> Any:
return get_candidate_mentions
def select_non_crossing_spans( def select_non_crossing_spans(
idxs: List[int], starts: List[int], ends: List[int], limit: int idxs: List[int], starts: List[int], ends: List[int], limit: int