mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 04:02:20 +03:00
Remove old TODOs
This commit is contained in:
parent
da9c379355
commit
c4de3e51a2
|
@ -35,7 +35,6 @@ def build_span_predictor(
|
||||||
),
|
),
|
||||||
convert_inputs=convert_span_predictor_inputs,
|
convert_inputs=convert_span_predictor_inputs,
|
||||||
)
|
)
|
||||||
# TODO use proper parameter for prefix
|
|
||||||
head_info = build_get_head_metadata(prefix)
|
head_info = build_get_head_metadata(prefix)
|
||||||
model = (tok2vec & head_info) >> span_predictor
|
model = (tok2vec & head_info) >> span_predictor
|
||||||
|
|
||||||
|
@ -96,7 +95,6 @@ def predict_span_clusters(
|
||||||
|
|
||||||
|
|
||||||
def build_get_head_metadata(prefix):
|
def build_get_head_metadata(prefix):
|
||||||
# TODO this name is awful, fix it
|
|
||||||
model = Model(
|
model = Model(
|
||||||
"HeadDataProvider", attrs={"prefix": prefix}, forward=head_data_forward
|
"HeadDataProvider", attrs={"prefix": prefix}, forward=head_data_forward
|
||||||
)
|
)
|
||||||
|
@ -142,7 +140,6 @@ class SpanPredictor(torch.nn.Module):
|
||||||
raise ValueError("max_distance has to be an even number")
|
raise ValueError("max_distance has to be an even number")
|
||||||
# input size = single token size
|
# input size = single token size
|
||||||
# 64 = probably distance emb size
|
# 64 = probably distance emb size
|
||||||
# TODO check that dist_emb_size use is correct
|
|
||||||
self.ffnn = torch.nn.Sequential(
|
self.ffnn = torch.nn.Sequential(
|
||||||
torch.nn.Linear(input_size * 2 + dist_emb_size, hidden_size),
|
torch.nn.Linear(input_size * 2 + dist_emb_size, hidden_size),
|
||||||
torch.nn.ReLU(),
|
torch.nn.ReLU(),
|
||||||
|
@ -159,7 +156,6 @@ class SpanPredictor(torch.nn.Module):
|
||||||
torch.nn.Conv1d(dist_emb_size, conv_channels, kernel_size, 1, 1),
|
torch.nn.Conv1d(dist_emb_size, conv_channels, kernel_size, 1, 1),
|
||||||
torch.nn.Conv1d(conv_channels, 2, kernel_size, 1, 1),
|
torch.nn.Conv1d(conv_channels, 2, kernel_size, 1, 1),
|
||||||
)
|
)
|
||||||
# TODO make embeddings size a parameter
|
|
||||||
self.max_distance = max_distance
|
self.max_distance = max_distance
|
||||||
# handle distances between +-((max_distance - 2) / 2)
|
# handle distances between +-((max_distance - 2) / 2)
|
||||||
self.emb = torch.nn.Embedding(max_distance, dist_emb_size)
|
self.emb = torch.nn.Embedding(max_distance, dist_emb_size)
|
||||||
|
|
|
@ -95,7 +95,7 @@ def make_coref(
|
||||||
class CoreferenceResolver(TrainablePipe):
|
class CoreferenceResolver(TrainablePipe):
|
||||||
"""Pipeline component for coreference resolution.
|
"""Pipeline component for coreference resolution.
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/coref (TODO)
|
DOCS: https://spacy.io/api/coref
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -118,8 +118,10 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
are stored in.
|
are stored in.
|
||||||
span_cluster_prefix (str): Prefix for the key in doc.spans to store the
|
span_cluster_prefix (str): Prefix for the key in doc.spans to store the
|
||||||
coref clusters in.
|
coref clusters in.
|
||||||
|
scorer (Optional[Callable]): The scoring method. Defaults to
|
||||||
|
Scorer.score_coref_clusters.
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/coref#init (TODO)
|
DOCS: https://spacy.io/api/coref#init
|
||||||
"""
|
"""
|
||||||
self.vocab = vocab
|
self.vocab = vocab
|
||||||
self.model = model
|
self.model = model
|
||||||
|
@ -133,11 +135,12 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
|
|
||||||
def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
|
def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
|
||||||
"""Apply the pipeline's model to a batch of docs, without modifying them.
|
"""Apply the pipeline's model to a batch of docs, without modifying them.
|
||||||
|
Return the list of predicted clusters.
|
||||||
|
|
||||||
docs (Iterable[Doc]): The documents to predict.
|
docs (Iterable[Doc]): The documents to predict.
|
||||||
RETURNS: The models prediction for each document.
|
RETURNS (List[MentionClusters]): The model's prediction for each document.
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/coref#predict (TODO)
|
DOCS: https://spacy.io/api/coref#predict
|
||||||
"""
|
"""
|
||||||
out = []
|
out = []
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
|
@ -163,7 +166,7 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
docs (Iterable[Doc]): The documents to modify.
|
docs (Iterable[Doc]): The documents to modify.
|
||||||
clusters: The span clusters, produced by CoreferenceResolver.predict.
|
clusters: The span clusters, produced by CoreferenceResolver.predict.
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/coref#set_annotations (TODO)
|
DOCS: https://spacy.io/api/coref#set_annotations
|
||||||
"""
|
"""
|
||||||
docs = list(docs)
|
docs = list(docs)
|
||||||
if len(docs) != len(clusters_by_doc):
|
if len(docs) != len(clusters_by_doc):
|
||||||
|
@ -204,7 +207,7 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
Updated using the component name as the key.
|
Updated using the component name as the key.
|
||||||
RETURNS (Dict[str, float]): The updated losses dictionary.
|
RETURNS (Dict[str, float]): The updated losses dictionary.
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/coref#update (TODO)
|
DOCS: https://spacy.io/api/coref#update
|
||||||
"""
|
"""
|
||||||
if losses is None:
|
if losses is None:
|
||||||
losses = {}
|
losses = {}
|
||||||
|
@ -218,12 +221,10 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
total_loss = 0
|
total_loss = 0
|
||||||
|
|
||||||
for eg in examples:
|
for eg in examples:
|
||||||
# TODO check this causes no issues (in practice it runs)
|
|
||||||
preds, backprop = self.model.begin_update([eg.predicted])
|
preds, backprop = self.model.begin_update([eg.predicted])
|
||||||
score_matrix, mention_idx = preds
|
score_matrix, mention_idx = preds
|
||||||
loss, d_scores = self.get_loss([eg], score_matrix, mention_idx)
|
loss, d_scores = self.get_loss([eg], score_matrix, mention_idx)
|
||||||
total_loss += loss
|
total_loss += loss
|
||||||
# TODO check shape here
|
|
||||||
backprop((d_scores, mention_idx))
|
backprop((d_scores, mention_idx))
|
||||||
|
|
||||||
if sgd is not None:
|
if sgd is not None:
|
||||||
|
@ -257,7 +258,7 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
scores: Scores representing the model's predictions.
|
scores: Scores representing the model's predictions.
|
||||||
RETURNS (Tuple[float, float]): The loss and the gradient.
|
RETURNS (Tuple[float, float]): The loss and the gradient.
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/coref#get_loss (TODO)
|
DOCS: https://spacy.io/api/coref#get_loss
|
||||||
"""
|
"""
|
||||||
ops = self.model.ops
|
ops = self.model.ops
|
||||||
xp = ops.xp
|
xp = ops.xp
|
||||||
|
@ -270,9 +271,8 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
clusters = get_clusters_from_doc(example.reference)
|
clusters = get_clusters_from_doc(example.reference)
|
||||||
span_idxs = create_head_span_idxs(ops, len(example.predicted))
|
span_idxs = create_head_span_idxs(ops, len(example.predicted))
|
||||||
gscores = create_gold_scores(span_idxs, clusters)
|
gscores = create_gold_scores(span_idxs, clusters)
|
||||||
# TODO fix type here. This is bools but asarray2f wants ints.
|
# Note on type here: gscores holds bools, but asarray2f wants ints.
|
||||||
gscores = ops.asarray2f(gscores) # type: ignore
|
gscores = ops.asarray2f(gscores) # type: ignore
|
||||||
# top_gscores = xp.take_along_axis(gscores, cidx, axis=1)
|
|
||||||
top_gscores = xp.take_along_axis(gscores, mention_idx, axis=1)
|
top_gscores = xp.take_along_axis(gscores, mention_idx, axis=1)
|
||||||
# now add the placeholder
|
# now add the placeholder
|
||||||
gold_placeholder = ~top_gscores.any(axis=1).T
|
gold_placeholder = ~top_gscores.any(axis=1).T
|
||||||
|
@ -304,7 +304,7 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
returns a representative sample of gold-standard Example objects.
|
returns a representative sample of gold-standard Example objects.
|
||||||
nlp (Language): The current nlp object the component is part of.
|
nlp (Language): The current nlp object the component is part of.
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/coref#initialize (TODO)
|
DOCS: https://spacy.io/api/coref#initialize
|
||||||
"""
|
"""
|
||||||
validate_get_examples(get_examples, "CoreferenceResolver.initialize")
|
validate_get_examples(get_examples, "CoreferenceResolver.initialize")
|
||||||
|
|
||||||
|
|
|
@ -383,7 +383,7 @@ class EntityLinker(TrainablePipe):
|
||||||
no prediction.
|
no prediction.
|
||||||
|
|
||||||
docs (Iterable[Doc]): The documents to predict.
|
docs (Iterable[Doc]): The documents to predict.
|
||||||
RETURNS (List[str]): The models prediction for each document.
|
RETURNS (List[str]): The model's prediction for each document.
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/entitylinker#predict
|
DOCS: https://spacy.io/api/entitylinker#predict
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue
Block a user