mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 20:22:25 +03:00
Split span predictor component into its own file
This runs. The imports in both of the split files should be audited closely to remove any now-unused entries.
This commit is contained in:
parent
117a9ef2bf
commit
f852c5cea4
|
@ -1,5 +1,6 @@
|
||||||
from .attributeruler import AttributeRuler
|
from .attributeruler import AttributeRuler
|
||||||
from .coref import CoreferenceResolver
|
from .coref import CoreferenceResolver
|
||||||
|
from .span_predictor import SpanPredictor
|
||||||
from .dep_parser import DependencyParser
|
from .dep_parser import DependencyParser
|
||||||
from .edit_tree_lemmatizer import EditTreeLemmatizer
|
from .edit_tree_lemmatizer import EditTreeLemmatizer
|
||||||
from .entity_linker import EntityLinker
|
from .entity_linker import EntityLinker
|
||||||
|
|
|
@ -56,7 +56,7 @@ window_size = 1
|
||||||
maxout_pieces = 3
|
maxout_pieces = 3
|
||||||
depth = 2
|
depth = 2
|
||||||
"""
|
"""
|
||||||
DEFAULT_MODEL = Config().from_str(default_config)["model"]
|
DEFAULT_COREF_MODEL = Config().from_str(default_config)["model"]
|
||||||
|
|
||||||
DEFAULT_CLUSTERS_PREFIX = "coref_clusters"
|
DEFAULT_CLUSTERS_PREFIX = "coref_clusters"
|
||||||
|
|
||||||
|
@ -66,7 +66,7 @@ DEFAULT_CLUSTERS_PREFIX = "coref_clusters"
|
||||||
assigns=["doc.spans"],
|
assigns=["doc.spans"],
|
||||||
requires=["doc.spans"],
|
requires=["doc.spans"],
|
||||||
default_config={
|
default_config={
|
||||||
"model": DEFAULT_MODEL,
|
"model": DEFAULT_COREF_MODEL,
|
||||||
"span_cluster_prefix": DEFAULT_CLUSTER_PREFIX,
|
"span_cluster_prefix": DEFAULT_CLUSTER_PREFIX,
|
||||||
},
|
},
|
||||||
default_score_weights={"coref_f": 1.0, "coref_p": None, "coref_r": None},
|
default_score_weights={"coref_f": 1.0, "coref_p": None, "coref_r": None},
|
||||||
|
@ -375,260 +375,3 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
return score
|
return score
|
||||||
|
|
||||||
|
|
||||||
default_span_predictor_config = """
|
|
||||||
[model]
|
|
||||||
@architectures = "spacy.SpanPredictor.v1"
|
|
||||||
hidden_size = 1024
|
|
||||||
dist_emb_size = 64
|
|
||||||
|
|
||||||
[model.tok2vec]
|
|
||||||
@architectures = "spacy.Tok2Vec.v2"
|
|
||||||
|
|
||||||
[model.tok2vec.embed]
|
|
||||||
@architectures = "spacy.MultiHashEmbed.v1"
|
|
||||||
width = 64
|
|
||||||
rows = [2000, 2000, 1000, 1000, 1000, 1000]
|
|
||||||
attrs = ["ORTH", "LOWER", "PREFIX", "SUFFIX", "SHAPE", "ID"]
|
|
||||||
include_static_vectors = false
|
|
||||||
|
|
||||||
[model.tok2vec.encode]
|
|
||||||
@architectures = "spacy.MaxoutWindowEncoder.v2"
|
|
||||||
width = ${model.tok2vec.embed.width}
|
|
||||||
window_size = 1
|
|
||||||
maxout_pieces = 3
|
|
||||||
depth = 2
|
|
||||||
"""
|
|
||||||
DEFAULT_SPAN_PREDICTOR_MODEL = Config().from_str(default_span_predictor_config)["model"]
|
|
||||||
|
|
||||||
@Language.factory(
|
|
||||||
"span_predictor",
|
|
||||||
assigns=["doc.spans"],
|
|
||||||
requires=["doc.spans"],
|
|
||||||
default_config={
|
|
||||||
"model": DEFAULT_SPAN_PREDICTOR_MODEL,
|
|
||||||
"input_prefix": "coref_head_clusters",
|
|
||||||
"output_prefix": "coref_clusters",
|
|
||||||
},
|
|
||||||
default_score_weights={"span_accuracy": 1.0},
|
|
||||||
)
|
|
||||||
def make_span_predictor(
|
|
||||||
nlp: Language,
|
|
||||||
name: str,
|
|
||||||
model,
|
|
||||||
input_prefix: str = "coref_head_clusters",
|
|
||||||
output_prefix: str = "coref_clusters",
|
|
||||||
) -> "SpanPredictor":
|
|
||||||
"""Create a SpanPredictor component."""
|
|
||||||
return SpanPredictor(nlp.vocab, model, name, input_prefix=input_prefix, output_prefix=output_prefix)
|
|
||||||
|
|
||||||
class SpanPredictor(TrainablePipe):
|
|
||||||
"""Pipeline component to resolve one-token spans to full spans.
|
|
||||||
|
|
||||||
Used in coreference resolution.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
vocab: Vocab,
|
|
||||||
model: Model,
|
|
||||||
name: str = "span_predictor",
|
|
||||||
*,
|
|
||||||
input_prefix: str = "coref_head_clusters",
|
|
||||||
output_prefix: str = "coref_clusters",
|
|
||||||
) -> None:
|
|
||||||
self.vocab = vocab
|
|
||||||
self.model = model
|
|
||||||
self.name = name
|
|
||||||
self.input_prefix = input_prefix
|
|
||||||
self.output_prefix = output_prefix
|
|
||||||
|
|
||||||
self.cfg = {}
|
|
||||||
|
|
||||||
def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
|
|
||||||
# for now pretend there's just one doc
|
|
||||||
|
|
||||||
out = []
|
|
||||||
for doc in docs:
|
|
||||||
# TODO check shape here
|
|
||||||
span_scores = self.model.predict([doc])
|
|
||||||
if span_scores.size:
|
|
||||||
# the information about clustering has to come from the input docs
|
|
||||||
# first let's convert the scores to a list of span idxs
|
|
||||||
start_scores = span_scores[:, :, 0]
|
|
||||||
end_scores = span_scores[:, :, 1]
|
|
||||||
starts = start_scores.argmax(axis=1)
|
|
||||||
ends = end_scores.argmax(axis=1)
|
|
||||||
|
|
||||||
# TODO check start < end
|
|
||||||
|
|
||||||
# get the old clusters (shape will be preserved)
|
|
||||||
clusters = doc2clusters(doc, self.input_prefix)
|
|
||||||
cidx = 0
|
|
||||||
out_clusters = []
|
|
||||||
for cluster in clusters:
|
|
||||||
ncluster = []
|
|
||||||
for mention in cluster:
|
|
||||||
ncluster.append((starts[cidx], ends[cidx]))
|
|
||||||
cidx += 1
|
|
||||||
out_clusters.append(ncluster)
|
|
||||||
else:
|
|
||||||
out_clusters = []
|
|
||||||
out.append(out_clusters)
|
|
||||||
return out
|
|
||||||
|
|
||||||
def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None:
|
|
||||||
for doc, clusters in zip(docs, clusters_by_doc):
|
|
||||||
for ii, cluster in enumerate(clusters):
|
|
||||||
spans = [doc[mm[0]:mm[1]] for mm in cluster]
|
|
||||||
doc.spans[f"{self.output_prefix}_{ii}"] = spans
|
|
||||||
|
|
||||||
def update(
|
|
||||||
self,
|
|
||||||
examples: Iterable[Example],
|
|
||||||
*,
|
|
||||||
drop: float = 0.0,
|
|
||||||
sgd: Optional[Optimizer] = None,
|
|
||||||
losses: Optional[Dict[str, float]] = None,
|
|
||||||
) -> Dict[str, float]:
|
|
||||||
"""Learn from a batch of documents and gold-standard information,
|
|
||||||
updating the pipe's model. Delegates to predict and get_loss.
|
|
||||||
"""
|
|
||||||
if losses is None:
|
|
||||||
losses = {}
|
|
||||||
losses.setdefault(self.name, 0.0)
|
|
||||||
validate_examples(examples, "SpanPredictor.update")
|
|
||||||
if not any(len(eg.reference) if eg.reference else 0 for eg in examples):
|
|
||||||
# Handle cases where there are no tokens in any docs.
|
|
||||||
return losses
|
|
||||||
set_dropout_rate(self.model, drop)
|
|
||||||
|
|
||||||
total_loss = 0
|
|
||||||
for eg in examples:
|
|
||||||
span_scores, backprop = self.model.begin_update([eg.predicted])
|
|
||||||
# FIXME, this only happens once in the first 1000 docs of OntoNotes
|
|
||||||
# and I'm not sure yet why.
|
|
||||||
if span_scores.size:
|
|
||||||
loss, d_scores = self.get_loss([eg], span_scores)
|
|
||||||
total_loss += loss
|
|
||||||
# TODO check shape here
|
|
||||||
backprop((d_scores))
|
|
||||||
|
|
||||||
if sgd is not None:
|
|
||||||
self.finish_update(sgd)
|
|
||||||
losses[self.name] += total_loss
|
|
||||||
return losses
|
|
||||||
|
|
||||||
def rehearse(
|
|
||||||
self,
|
|
||||||
examples: Iterable[Example],
|
|
||||||
*,
|
|
||||||
drop: float = 0.0,
|
|
||||||
sgd: Optional[Optimizer] = None,
|
|
||||||
losses: Optional[Dict[str, float]] = None,
|
|
||||||
) -> Dict[str, float]:
|
|
||||||
# TODO this should be added later
|
|
||||||
raise NotImplementedError(
|
|
||||||
Errors.E931.format(
|
|
||||||
parent="SpanPredictor", method="add_label", name=self.name
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def add_label(self, label: str) -> int:
|
|
||||||
"""Technically this method should be implemented from TrainablePipe,
|
|
||||||
but it is not relevant for this component.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError(
|
|
||||||
Errors.E931.format(
|
|
||||||
parent="SpanPredictor", method="add_label", name=self.name
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_loss(
|
|
||||||
self,
|
|
||||||
examples: Iterable[Example],
|
|
||||||
span_scores: Floats3d,
|
|
||||||
):
|
|
||||||
ops = self.model.ops
|
|
||||||
|
|
||||||
# NOTE This is doing fake batching, and should always get a list of one example
|
|
||||||
assert len(examples) == 1, "Only fake batching is supported."
|
|
||||||
# starts and ends are gold starts and ends (Ints1d)
|
|
||||||
# span_scores is a Floats3d. What are the axes? mention x token x start/end
|
|
||||||
for eg in examples:
|
|
||||||
starts = []
|
|
||||||
ends = []
|
|
||||||
for key, sg in eg.reference.spans.items():
|
|
||||||
if key.startswith(self.output_prefix):
|
|
||||||
for mention in sg:
|
|
||||||
starts.append(mention.start)
|
|
||||||
ends.append(mention.end)
|
|
||||||
|
|
||||||
starts = self.model.ops.xp.asarray(starts)
|
|
||||||
ends = self.model.ops.xp.asarray(ends)
|
|
||||||
start_scores = span_scores[:, :, 0]
|
|
||||||
end_scores = span_scores[:, :, 1]
|
|
||||||
n_classes = start_scores.shape[1]
|
|
||||||
start_probs = ops.softmax(start_scores, axis=1)
|
|
||||||
end_probs = ops.softmax(end_scores, axis=1)
|
|
||||||
start_targets = to_categorical(starts, n_classes)
|
|
||||||
end_targets = to_categorical(ends, n_classes)
|
|
||||||
start_grads = (start_probs - start_targets)
|
|
||||||
end_grads = (end_probs - end_targets)
|
|
||||||
grads = ops.xp.stack((start_grads, end_grads), axis=2)
|
|
||||||
loss = float((grads ** 2).sum())
|
|
||||||
return loss, grads
|
|
||||||
|
|
||||||
def initialize(
|
|
||||||
self,
|
|
||||||
get_examples: Callable[[], Iterable[Example]],
|
|
||||||
*,
|
|
||||||
nlp: Optional[Language] = None,
|
|
||||||
) -> None:
|
|
||||||
validate_get_examples(get_examples, "SpanPredictor.initialize")
|
|
||||||
|
|
||||||
X = []
|
|
||||||
Y = []
|
|
||||||
for ex in islice(get_examples(), 2):
|
|
||||||
|
|
||||||
if not ex.predicted.spans:
|
|
||||||
# set placeholder for shape inference
|
|
||||||
doc = ex.predicted
|
|
||||||
assert len(doc) > 2, "Coreference requires at least two tokens"
|
|
||||||
doc.spans[f"{self.input_prefix}_0"] = [doc[0:1], doc[1:2]]
|
|
||||||
X.append(ex.predicted)
|
|
||||||
Y.append(ex.reference)
|
|
||||||
|
|
||||||
assert len(X) > 0, Errors.E923.format(name=self.name)
|
|
||||||
self.model.initialize(X=X, Y=Y)
|
|
||||||
|
|
||||||
def score(self, examples, **kwargs):
|
|
||||||
"""
|
|
||||||
Evaluate on reconstructing the correct spans around
|
|
||||||
gold heads.
|
|
||||||
"""
|
|
||||||
scores = []
|
|
||||||
xp = self.model.ops.xp
|
|
||||||
for eg in examples:
|
|
||||||
starts = []
|
|
||||||
ends = []
|
|
||||||
pred_starts = []
|
|
||||||
pred_ends = []
|
|
||||||
ref = eg.reference
|
|
||||||
pred = eg.predicted
|
|
||||||
for key, gold_sg in ref.spans.items():
|
|
||||||
if key.startswith(self.output_prefix):
|
|
||||||
pred_sg = pred.spans[key]
|
|
||||||
for gold_mention, pred_mention in zip(gold_sg, pred_sg):
|
|
||||||
starts.append(gold_mention.start)
|
|
||||||
ends.append(gold_mention.end)
|
|
||||||
pred_starts.append(pred_mention.start)
|
|
||||||
pred_ends.append(pred_mention.end)
|
|
||||||
|
|
||||||
starts = xp.asarray(starts)
|
|
||||||
ends = xp.asarray(ends)
|
|
||||||
pred_starts = xp.asarray(pred_starts)
|
|
||||||
pred_ends = xp.asarray(pred_ends)
|
|
||||||
correct = (starts == pred_starts) * (ends == pred_ends)
|
|
||||||
accuracy = correct.mean()
|
|
||||||
scores.append(float(accuracy))
|
|
||||||
return {"span_accuracy": mean(scores)}
|
|
||||||
|
|
280
spacy/pipeline/span_predictor.py
Normal file
280
spacy/pipeline/span_predictor.py
Normal file
|
@ -0,0 +1,280 @@
|
||||||
|
from typing import Iterable, Tuple, Optional, Dict, Callable, Any, List
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from thinc.types import Floats2d, Floats3d, Ints2d
|
||||||
|
from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
|
||||||
|
from thinc.api import set_dropout_rate, to_categorical
|
||||||
|
from itertools import islice
|
||||||
|
from statistics import mean
|
||||||
|
|
||||||
|
from .trainable_pipe import TrainablePipe
|
||||||
|
from ..language import Language
|
||||||
|
from ..training import Example, validate_examples, validate_get_examples
|
||||||
|
from ..errors import Errors
|
||||||
|
from ..scorer import Scorer
|
||||||
|
from ..tokens import Doc
|
||||||
|
from ..vocab import Vocab
|
||||||
|
|
||||||
|
from ..ml.models.coref_util import (
|
||||||
|
MentionClusters,
|
||||||
|
DEFAULT_CLUSTER_PREFIX,
|
||||||
|
doc2clusters,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Default architecture for the span predictor, expressed as a spaCy/thinc
# config string. It is parsed once at import time; only the [model] section
# is kept as the component's default model.
default_span_predictor_config = """
[model]
@architectures = "spacy.SpanPredictor.v1"
hidden_size = 1024
dist_emb_size = 64

[model.tok2vec]
@architectures = "spacy.Tok2Vec.v2"

[model.tok2vec.embed]
@architectures = "spacy.MultiHashEmbed.v1"
width = 64
rows = [2000, 2000, 1000, 1000, 1000, 1000]
attrs = ["ORTH", "LOWER", "PREFIX", "SUFFIX", "SHAPE", "ID"]
include_static_vectors = false

[model.tok2vec.encode]
@architectures = "spacy.MaxoutWindowEncoder.v2"
width = ${model.tok2vec.embed.width}
window_size = 1
maxout_pieces = 3
depth = 2
"""

_parsed_default_config = Config().from_str(default_span_predictor_config)
DEFAULT_SPAN_PREDICTOR_MODEL = _parsed_default_config["model"]
|
||||||
|
|
||||||
|
@Language.factory(
    "span_predictor",
    assigns=["doc.spans"],
    requires=["doc.spans"],
    default_config={
        "model": DEFAULT_SPAN_PREDICTOR_MODEL,
        "input_prefix": "coref_head_clusters",
        "output_prefix": "coref_clusters",
    },
    default_score_weights={"span_accuracy": 1.0},
)
def make_span_predictor(
    nlp: Language,
    name: str,
    model,
    input_prefix: str = "coref_head_clusters",
    output_prefix: str = "coref_clusters",
) -> "SpanPredictor":
    """Create a SpanPredictor component.

    nlp: the pipeline the component is created for (supplies the vocab).
    name: component instance name.
    model: the span-scoring model (see DEFAULT_SPAN_PREDICTOR_MODEL).
    input_prefix: prefix of the span groups read as input head clusters.
    output_prefix: prefix of the span groups written as full-span clusters.
    """
    return SpanPredictor(
        nlp.vocab,
        model,
        name,
        input_prefix=input_prefix,
        output_prefix=output_prefix,
    )
|
||||||
|
|
||||||
|
class SpanPredictor(TrainablePipe):
    """Pipeline component to resolve one-token spans to full spans.

    Used in coreference resolution: reads head-only clusters from the span
    groups in ``doc.spans`` whose keys start with ``input_prefix`` and writes
    full-span clusters under keys starting with ``output_prefix``.
    """

    def __init__(
        self,
        vocab: Vocab,
        model: Model,
        name: str = "span_predictor",
        *,
        input_prefix: str = "coref_head_clusters",
        output_prefix: str = "coref_clusters",
    ) -> None:
        """Initialize the span predictor.

        vocab: the shared vocabulary.
        model: thinc model producing per-token start/end scores per mention.
        name: component name, used as the key in loss/score dicts.
        input_prefix: prefix of span groups read as input (head clusters).
        output_prefix: prefix of span groups written as output (full spans).
        """
        self.vocab = vocab
        self.model = model
        self.name = name
        self.input_prefix = input_prefix
        self.output_prefix = output_prefix
        # No serializable settings yet; kept for TrainablePipe compatibility.
        self.cfg = {}

    def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
        """Predict full-span clusters for each doc.

        Returns one list of clusters per doc; each cluster is a list of
        (start, end) token-index tuples, preserving the shape of the input
        head clusters read from ``self.input_prefix``.
        """
        # NOTE: fake batching — the model is run on one doc at a time.
        out = []
        for doc in docs:
            # TODO check shape here
            span_scores = self.model.predict([doc])
            if span_scores.size:
                # The clustering itself comes from the input doc; the model
                # only rescores each mention's start/end token positions.
                start_scores = span_scores[:, :, 0]
                end_scores = span_scores[:, :, 1]
                starts = start_scores.argmax(axis=1)
                ends = end_scores.argmax(axis=1)

                # TODO check start < end

                # Get the old clusters (their shape is preserved).
                clusters = doc2clusters(doc, self.input_prefix)
                cidx = 0
                out_clusters = []
                for cluster in clusters:
                    ncluster = []
                    for mention in cluster:
                        ncluster.append((starts[cidx], ends[cidx]))
                        cidx += 1
                    out_clusters.append(ncluster)
            else:
                out_clusters = []
            out.append(out_clusters)
        return out

    def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None:
        """Write predicted clusters into each doc's span groups.

        Each cluster ii becomes ``doc.spans[f"{output_prefix}_{ii}"]``.
        """
        for doc, clusters in zip(docs, clusters_by_doc):
            for ii, cluster in enumerate(clusters):
                spans = [doc[mm[0]:mm[1]] for mm in cluster]
                doc.spans[f"{self.output_prefix}_{ii}"] = spans

    def update(
        self,
        examples: Iterable[Example],
        *,
        drop: float = 0.0,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None,
    ) -> Dict[str, float]:
        """Learn from a batch of documents and gold-standard information,
        updating the pipe's model. Delegates to predict and get_loss.
        """
        if losses is None:
            losses = {}
        losses.setdefault(self.name, 0.0)
        validate_examples(examples, "SpanPredictor.update")
        if not any(len(eg.reference) if eg.reference else 0 for eg in examples):
            # Handle cases where there are no tokens in any docs.
            return losses
        set_dropout_rate(self.model, drop)

        total_loss = 0
        for eg in examples:
            # Fake batching: one example at a time (see get_loss).
            span_scores, backprop = self.model.begin_update([eg.predicted])
            # FIXME, this only happens once in the first 1000 docs of OntoNotes
            # and I'm not sure yet why.
            if span_scores.size:
                loss, d_scores = self.get_loss([eg], span_scores)
                total_loss += loss
                # TODO check shape here
                backprop(d_scores)

        if sgd is not None:
            self.finish_update(sgd)
        losses[self.name] += total_loss
        return losses

    def rehearse(
        self,
        examples: Iterable[Example],
        *,
        drop: float = 0.0,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None,
    ) -> Dict[str, float]:
        # TODO this should be added later
        # Fixed: the error previously named method="add_label" (copy-paste).
        raise NotImplementedError(
            Errors.E931.format(
                parent="SpanPredictor", method="rehearse", name=self.name
            )
        )

    def add_label(self, label: str) -> int:
        """Technically this method should be implemented from TrainablePipe,
        but it is not relevant for this component.
        """
        raise NotImplementedError(
            Errors.E931.format(
                parent="SpanPredictor", method="add_label", name=self.name
            )
        )

    def get_loss(
        self,
        examples: Iterable[Example],
        span_scores: Floats3d,
    ):
        """Compute the squared-gradient loss and gradients for one example.

        span_scores: Floats3d with axes (mention, token, start/end).
        Returns (loss, grads) where grads has the same shape as span_scores.
        """
        ops = self.model.ops

        # NOTE This is doing fake batching, and should always get a list of
        # one example.
        assert len(examples) == 1, "Only fake batching is supported."
        for eg in examples:
            # Collect gold start/end token indices for every mention in the
            # output span groups, in span-group iteration order.
            starts = []
            ends = []
            for key, sg in eg.reference.spans.items():
                if key.startswith(self.output_prefix):
                    for mention in sg:
                        starts.append(mention.start)
                        ends.append(mention.end)

            starts = ops.xp.asarray(starts)
            ends = ops.xp.asarray(ends)
            start_scores = span_scores[:, :, 0]
            end_scores = span_scores[:, :, 1]
            n_classes = start_scores.shape[1]
            start_probs = ops.softmax(start_scores, axis=1)
            end_probs = ops.softmax(end_scores, axis=1)
            start_targets = to_categorical(starts, n_classes)
            end_targets = to_categorical(ends, n_classes)
            # Gradient of softmax + cross-entropy: probs - one-hot targets.
            start_grads = start_probs - start_targets
            end_grads = end_probs - end_targets
            grads = ops.xp.stack((start_grads, end_grads), axis=2)
            loss = float((grads ** 2).sum())
            return loss, grads

    def initialize(
        self,
        get_examples: Callable[[], Iterable[Example]],
        *,
        nlp: Optional[Language] = None,
    ) -> None:
        """Initialize the model from a sample of examples.

        Uses at most two examples for shape inference; docs without span
        annotations get a two-token placeholder span group.
        """
        validate_get_examples(get_examples, "SpanPredictor.initialize")

        X = []
        Y = []
        for ex in islice(get_examples(), 2):
            if not ex.predicted.spans:
                # Set a placeholder for shape inference. The placeholder
                # spans doc[0:1] and doc[1:2] only need two tokens, so the
                # condition matches the message (was previously `> 2`).
                doc = ex.predicted
                assert len(doc) >= 2, "Coreference requires at least two tokens"
                doc.spans[f"{self.input_prefix}_0"] = [doc[0:1], doc[1:2]]
            X.append(ex.predicted)
            Y.append(ex.reference)

        assert len(X) > 0, Errors.E923.format(name=self.name)
        self.model.initialize(X=X, Y=Y)

    def score(self, examples, **kwargs):
        """
        Evaluate on reconstructing the correct spans around
        gold heads. Returns {"span_accuracy": float}.
        """
        scores = []
        xp = self.model.ops.xp
        for eg in examples:
            starts = []
            ends = []
            pred_starts = []
            pred_ends = []
            ref = eg.reference
            pred = eg.predicted
            for key, gold_sg in ref.spans.items():
                if key.startswith(self.output_prefix):
                    pred_sg = pred.spans[key]
                    for gold_mention, pred_mention in zip(gold_sg, pred_sg):
                        starts.append(gold_mention.start)
                        ends.append(gold_mention.end)
                        pred_starts.append(pred_mention.start)
                        pred_ends.append(pred_mention.end)

            if not starts:
                # No comparable spans in this example: skip it rather than
                # letting .mean() on an empty array produce NaN, which would
                # poison the overall average.
                continue
            starts = xp.asarray(starts)
            ends = xp.asarray(ends)
            pred_starts = xp.asarray(pred_starts)
            pred_ends = xp.asarray(pred_ends)
            correct = (starts == pred_starts) * (ends == pred_ends)
            accuracy = correct.mean()
            scores.append(float(accuracy))
        if not scores:
            # Nothing scorable at all: report 0.0 instead of letting
            # statistics.mean() raise StatisticsError on an empty list.
            return {"span_accuracy": 0.0}
        return {"span_accuracy": mean(scores)}
|
Loading…
Reference in New Issue
Block a user