Remove old TODOs

This commit is contained in:
Paul O'Leary McCann 2022-07-06 17:23:41 +09:00
parent da9c379355
commit c4de3e51a2
3 changed files with 13 additions and 17 deletions

View File

@@ -35,7 +35,6 @@ def build_span_predictor(
),
convert_inputs=convert_span_predictor_inputs,
)
# TODO use proper parameter for prefix
head_info = build_get_head_metadata(prefix)
model = (tok2vec & head_info) >> span_predictor
@@ -96,7 +95,6 @@ def predict_span_clusters(
def build_get_head_metadata(prefix):
# TODO this name is awful, fix it
model = Model(
"HeadDataProvider", attrs={"prefix": prefix}, forward=head_data_forward
)
@@ -142,7 +140,6 @@ class SpanPredictor(torch.nn.Module):
raise ValueError("max_distance has to be an even number")
# input size = single token size
# 64 = probably distance emb size
# TODO check that dist_emb_size use is correct
self.ffnn = torch.nn.Sequential(
torch.nn.Linear(input_size * 2 + dist_emb_size, hidden_size),
torch.nn.ReLU(),
@@ -159,7 +156,6 @@ class SpanPredictor(torch.nn.Module):
torch.nn.Conv1d(dist_emb_size, conv_channels, kernel_size, 1, 1),
torch.nn.Conv1d(conv_channels, 2, kernel_size, 1, 1),
)
# TODO make embeddings size a parameter
self.max_distance = max_distance
# handle distances between +-(max_distance - 2 / 2)
self.emb = torch.nn.Embedding(max_distance, dist_emb_size)

View File

@@ -95,7 +95,7 @@ def make_coref(
class CoreferenceResolver(TrainablePipe):
"""Pipeline component for coreference resolution.
DOCS: https://spacy.io/api/coref (TODO)
DOCS: https://spacy.io/api/coref
"""
def __init__(
@@ -118,8 +118,10 @@ class CoreferenceResolver(TrainablePipe):
are stored in.
span_cluster_prefix (str): Prefix for the key in doc.spans to store the
coref clusters in.
scorer (Optional[Callable]): The scoring method. Defaults to
Scorer.score_coref_clusters.
DOCS: https://spacy.io/api/coref#init (TODO)
DOCS: https://spacy.io/api/coref#init
"""
self.vocab = vocab
self.model = model
@@ -133,11 +135,12 @@ class CoreferenceResolver(TrainablePipe):
def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
"""Apply the pipeline's model to a batch of docs, without modifying them.
Return the list of predicted clusters.
docs (Iterable[Doc]): The documents to predict.
RETURNS: The models prediction for each document.
RETURNS (List[MentionClusters]): The model's prediction for each document.
DOCS: https://spacy.io/api/coref#predict (TODO)
DOCS: https://spacy.io/api/coref#predict
"""
out = []
for doc in docs:
@@ -163,7 +166,7 @@ class CoreferenceResolver(TrainablePipe):
docs (Iterable[Doc]): The documents to modify.
clusters: The span clusters, produced by CoreferenceResolver.predict.
DOCS: https://spacy.io/api/coref#set_annotations (TODO)
DOCS: https://spacy.io/api/coref#set_annotations
"""
docs = list(docs)
if len(docs) != len(clusters_by_doc):
@@ -204,7 +207,7 @@ class CoreferenceResolver(TrainablePipe):
Updated using the component name as the key.
RETURNS (Dict[str, float]): The updated losses dictionary.
DOCS: https://spacy.io/api/coref#update (TODO)
DOCS: https://spacy.io/api/coref#update
"""
if losses is None:
losses = {}
@@ -218,12 +221,10 @@ class CoreferenceResolver(TrainablePipe):
total_loss = 0
for eg in examples:
# TODO check this causes no issues (in practice it runs)
preds, backprop = self.model.begin_update([eg.predicted])
score_matrix, mention_idx = preds
loss, d_scores = self.get_loss([eg], score_matrix, mention_idx)
total_loss += loss
# TODO check shape here
backprop((d_scores, mention_idx))
if sgd is not None:
@@ -257,7 +258,7 @@ class CoreferenceResolver(TrainablePipe):
scores: Scores representing the model's predictions.
RETURNS (Tuple[float, float]): The loss and the gradient.
DOCS: https://spacy.io/api/coref#get_loss (TODO)
DOCS: https://spacy.io/api/coref#get_loss
"""
ops = self.model.ops
xp = ops.xp
@@ -270,9 +271,8 @@ class CoreferenceResolver(TrainablePipe):
clusters = get_clusters_from_doc(example.reference)
span_idxs = create_head_span_idxs(ops, len(example.predicted))
gscores = create_gold_scores(span_idxs, clusters)
# TODO fix type here. This is bools but asarray2f wants ints.
# Note on type here. This is bools but asarray2f wants ints.
gscores = ops.asarray2f(gscores) # type: ignore
# top_gscores = xp.take_along_axis(gscores, cidx, axis=1)
top_gscores = xp.take_along_axis(gscores, mention_idx, axis=1)
# now add the placeholder
gold_placeholder = ~top_gscores.any(axis=1).T
@@ -304,7 +304,7 @@ class CoreferenceResolver(TrainablePipe):
returns a representative sample of gold-standard Example objects.
nlp (Language): The current nlp object the component is part of.
DOCS: https://spacy.io/api/coref#initialize (TODO)
DOCS: https://spacy.io/api/coref#initialize
"""
validate_get_examples(get_examples, "CoreferenceResolver.initialize")

View File

@@ -383,7 +383,7 @@ class EntityLinker(TrainablePipe):
no prediction.
docs (Iterable[Doc]): The documents to predict.
RETURNS (List[str]): The models prediction for each document.
RETURNS (List[str]): The model's prediction for each document.
DOCS: https://spacy.io/api/entitylinker#predict
"""