mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-19 20:52:23 +03:00
Add progress on SpanPredictor component
This isn't working. There is a CUDA error in the torch code during initialization and it's not clear why.
This commit is contained in:
parent
a098849112
commit
2190cbc0e6
|
@ -14,7 +14,7 @@ from ..extract_spans import extract_spans
|
||||||
import torch
|
import torch
|
||||||
from thinc.util import xp2torch, torch2xp
|
from thinc.util import xp2torch, torch2xp
|
||||||
|
|
||||||
from .coref_util import add_dummy
|
from .coref_util import add_dummy, get_sentence_ids
|
||||||
|
|
||||||
@registry.architectures("spacy.Coref.v1")
|
@registry.architectures("spacy.Coref.v1")
|
||||||
def build_wl_coref_model(
|
def build_wl_coref_model(
|
||||||
|
@ -74,6 +74,33 @@ def build_wl_coref_model(
|
||||||
# and just return words as spans.
|
# and just return words as spans.
|
||||||
return coref_model
|
return coref_model
|
||||||
|
|
||||||
|
@registry.architectures("spacy.SpanPredictor.v1")
def build_span_predictor(
    tok2vec: Model[List[Doc], List[Floats2d]],
    hidden_size: int = 1024,
    dist_emb_size: int = 64,
):
    """Build the span predictor model.

    The returned model runs `tok2vec` and a head-metadata layer side by side
    on the input docs and feeds both results into a PyTorch-wrapped
    `SpanPredictor` network.

    tok2vec: Layer producing one array of token vectors per doc.
    hidden_size: Passed as the first argument to the torch `SpanPredictor`.
    dist_emb_size: Passed as the second argument to the torch `SpanPredictor`.
    RETURNS: The combined model.
    """
    # TODO fix this
    # NOTE(review): `dim` is computed but not yet passed on to the network.
    try:
        dim = tok2vec.get_dim("nO")
    except ValueError:
        # happens with transformer listener
        dim = 768

    # TODO fix device - should be automatic
    # Only ask for the GPU when torch can actually see one; otherwise
    # constructing the network on "cuda:0" fails on CPU-only machines.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    with Model.define_operators({">>": chain, "&": tuplify}):
        span_predictor = PyTorchWrapper(
            SpanPredictor(hidden_size, dist_emb_size, device),
            convert_inputs=convert_span_predictor_inputs,
        )
        # TODO use proper parameter for prefix
        head_info = build_get_head_metadata("coref_head_clusters")
        model = (tok2vec & head_info) >> span_predictor

    return model
|
||||||
|
|
||||||
|
|
||||||
def convert_coref_scorer_inputs(
|
def convert_coref_scorer_inputs(
|
||||||
model: Model,
|
model: Model,
|
||||||
X: List[Floats2d],
|
X: List[Floats2d],
|
||||||
|
@ -84,6 +111,7 @@ def convert_coref_scorer_inputs(
|
||||||
# TODO real batching
|
# TODO real batching
|
||||||
X = X[0]
|
X = X[0]
|
||||||
|
|
||||||
|
|
||||||
word_features = xp2torch(X, requires_grad=is_train)
|
word_features = xp2torch(X, requires_grad=is_train)
|
||||||
def backprop(args: ArgsKwargs) -> List[Floats2d]:
|
def backprop(args: ArgsKwargs) -> List[Floats2d]:
|
||||||
# convert to xp and wrap in list
|
# convert to xp and wrap in list
|
||||||
|
@ -116,10 +144,15 @@ def convert_span_predictor_inputs(
|
||||||
X: Tuple[Ints1d, Floats2d, Ints1d],
|
X: Tuple[Ints1d, Floats2d, Ints1d],
|
||||||
is_train: bool
|
is_train: bool
|
||||||
):
|
):
|
||||||
sent_id = xp2torch(X[0], requires_grad=False)
|
tok2vec, (sent_ids, head_ids) = X
|
||||||
word_features = xp2torch(X[1], requires_grad=False)
|
# Normally we should use the input is_train, but for these two it's not relevant
|
||||||
head_ids = xp2torch(X[2], requires_grad=False)
|
sent_ids = xp2torch(sent_ids[0], requires_grad=False)
|
||||||
argskwargs = ArgsKwargs(args=(sent_id, word_features, head_ids), kwargs={})
|
head_ids = xp2torch(head_ids[0], requires_grad=False)
|
||||||
|
|
||||||
|
word_features = xp2torch(tok2vec[0], requires_grad=is_train)
|
||||||
|
|
||||||
|
argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
|
||||||
|
# TODO actually support backprop
|
||||||
return argskwargs, lambda dX: []
|
return argskwargs, lambda dX: []
|
||||||
|
|
||||||
# TODO This probably belongs in the component, not the model.
|
# TODO This probably belongs in the component, not the model.
|
||||||
|
@ -189,6 +222,36 @@ def _clusterize(
|
||||||
clusters.append(sorted(cluster))
|
clusters.append(sorted(cluster))
|
||||||
return sorted(clusters)
|
return sorted(clusters)
|
||||||
|
|
||||||
|
def build_get_head_metadata(prefix):
    """Create the "HeadDataProvider" layer.

    The layer's forward pass is `head_data_forward`; `prefix` is stored in the
    layer's attrs so the forward function can look it up.
    """
    # TODO this name is awful, fix it
    return Model(
        "HeadDataProvider",
        forward=head_data_forward,
        attrs={"prefix": prefix},
    )
|
||||||
|
|
||||||
|
def head_data_forward(model, docs, is_train):
    """A layer to generate the extra data needed for the span predictor.

    For every doc this collects the per-token sentence ids and the index of
    the first token of every span stored under a `doc.spans` key that starts
    with the configured prefix.

    RETURNS: `((sent_ids, head_ids), backprop)`, where both lists hold one
        integer array per doc and `backprop` is a placeholder returning `[]`.
    """
    prefix = model.attrs["prefix"]
    to_ints = model.ops.asarray2i

    all_sent_ids = []
    all_head_ids = []
    for doc in docs:
        all_sent_ids.append(to_ints(get_sentence_ids(doc)))

        doc_heads = []
        for key, group in doc.spans.items():
            if key.startswith(prefix):
                # TODO warn if spans are more than one token
                doc_heads.extend(span[0].i for span in group)
        all_head_ids.append(to_ints(doc_heads))

    # each of these is a list with one entry per doc
    # backprop is just a placeholder
    # TODO it would probably be better to have a list of tuples than two lists of arrays
    return (all_sent_ids, all_head_ids), lambda x: []
|
||||||
|
|
||||||
|
|
||||||
class CorefScorer(torch.nn.Module):
|
class CorefScorer(torch.nn.Module):
|
||||||
"""Combines all coref modules together to find coreferent spans.
|
"""Combines all coref modules together to find coreferent spans.
|
||||||
|
@ -492,6 +555,7 @@ class SpanPredictor(torch.nn.Module):
|
||||||
emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127
|
emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127
|
||||||
# Obtain "same sentence" boolean mask, [n_heads, n_words]
|
# Obtain "same sentence" boolean mask, [n_heads, n_words]
|
||||||
sent_id = torch.tensor(sent_id, device=words.device)
|
sent_id = torch.tensor(sent_id, device=words.device)
|
||||||
|
heads_ids = heads_ids.long()
|
||||||
same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0))
|
same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0))
|
||||||
|
|
||||||
# To save memory, only pass candidates from one sentence for each head
|
# To save memory, only pass candidates from one sentence for each head
|
||||||
|
@ -506,7 +570,7 @@ class SpanPredictor(torch.nn.Module):
|
||||||
), dim=1)
|
), dim=1)
|
||||||
|
|
||||||
lengths = same_sent.sum(dim=1)
|
lengths = same_sent.sum(dim=1)
|
||||||
padding_mask = torch.arange(0, lengths.max(), device=words.device).unsqueeze(0)
|
padding_mask = torch.arange(0, lengths.max().item(), device=words.device).unsqueeze(0)
|
||||||
padding_mask = (padding_mask < lengths.unsqueeze(1)) # [n_heads, max_sent_len]
|
padding_mask = (padding_mask < lengths.unsqueeze(1)) # [n_heads, max_sent_len]
|
||||||
|
|
||||||
# [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
|
# [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
|
||||||
|
|
|
@ -39,6 +39,15 @@ def add_dummy(tensor: torch.Tensor, eps: bool = False):
|
||||||
output = torch.cat((dummy, tensor), dim=1)
|
output = torch.cat((dummy, tensor), dim=1)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def get_sentence_ids(doc):
    """Return one sentence index per token of `doc`.

    The counter starts at -1 and is bumped every time a token with a truthy
    `is_sent_start` is seen, so tokens of the first sentence get 0, the
    second sentence 1, and so on. Tokens that come before any sentence start
    keep the initial value of -1.
    """
    ids = []
    current = -1
    for token in doc:
        current += 1 if token.is_sent_start else 0
        ids.append(current)
    return ids
|
||||||
|
|
||||||
def doc2clusters(doc: Doc, prefix=DEFAULT_CLUSTER_PREFIX) -> MentionClusters:
|
def doc2clusters(doc: Doc, prefix=DEFAULT_CLUSTER_PREFIX) -> MentionClusters:
|
||||||
"""Given a doc, give the mention clusters.
|
"""Given a doc, give the mention clusters.
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from typing import Iterable, Tuple, Optional, Dict, Callable, Any, List
|
from typing import Iterable, Tuple, Optional, Dict, Callable, Any, List
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from thinc.types import Floats2d, Ints2d
|
from thinc.types import Floats2d, Floats3d, Ints2d
|
||||||
from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
|
from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
|
||||||
from thinc.api import set_dropout_rate
|
from thinc.api import set_dropout_rate
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
|
@ -84,6 +84,7 @@ def make_coref(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CoreferenceResolver(TrainablePipe):
|
class CoreferenceResolver(TrainablePipe):
|
||||||
"""Pipeline component for coreference resolution.
|
"""Pipeline component for coreference resolution.
|
||||||
|
|
||||||
|
@ -208,7 +209,7 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
total_loss = 0
|
total_loss = 0
|
||||||
|
|
||||||
for eg in examples:
|
for eg in examples:
|
||||||
# TODO does this even work?
|
# TODO check this causes no issues (in practice it runs)
|
||||||
preds, backprop = self.model.begin_update([eg.predicted])
|
preds, backprop = self.model.begin_update([eg.predicted])
|
||||||
score_matrix, mention_idx = preds
|
score_matrix, mention_idx = preds
|
||||||
|
|
||||||
|
@ -384,6 +385,52 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
# Default model configuration for the span predictor component: the span
# predictor architecture on top of a small hash-embedding tok2vec.
default_span_predictor_config = """
[model]
@architectures = "spacy.SpanPredictor.v1"
hidden_size = 1024
dist_emb_size = 64

[model.tok2vec]
@architectures = "spacy.Tok2Vec.v2"

[model.tok2vec.embed]
@architectures = "spacy.MultiHashEmbed.v1"
width = 64
rows = [2000, 2000, 1000, 1000, 1000, 1000]
attrs = ["ORTH", "LOWER", "PREFIX", "SUFFIX", "SHAPE", "ID"]
include_static_vectors = false

[model.tok2vec.encode]
@architectures = "spacy.MaxoutWindowEncoder.v2"
width = ${model.tok2vec.embed.width}
window_size = 1
maxout_pieces = 3
depth = 2
"""
DEFAULT_SPAN_PREDICTOR_MODEL = Config().from_str(default_span_predictor_config)["model"]


@Language.factory(
    "span_predictor",
    assigns=["doc.spans"],
    requires=["doc.spans"],
    default_config={
        "model": DEFAULT_SPAN_PREDICTOR_MODEL,
        "input_prefix": "coref_head_clusters",
        "output_prefix": "coref_clusters",
    },
    default_score_weights={
        "span_predictor_f": 1.0,
        "span_predictor_p": None,
        "span_predictor_r": None,
    },
)
def make_span_predictor(
    nlp: Language,
    name: str,
    model,
    input_prefix: str = "coref_head_clusters",
    output_prefix: str = "coref_clusters",
) -> "SpanPredictor":
    """Factory for the "span_predictor" pipeline component.

    nlp: The pipeline; its vocab is handed to the component.
    name: The component instance name.
    model: The model the component wraps.
    input_prefix: Prefix of the `doc.spans` keys the component reads.
    output_prefix: Prefix of the `doc.spans` keys the component writes.
    RETURNS: The constructed SpanPredictor component.
    """
    return SpanPredictor(
        nlp.vocab,
        model,
        name,
        input_prefix=input_prefix,
        output_prefix=output_prefix,
    )
|
||||||
|
|
||||||
class SpanPredictor(TrainablePipe):
|
class SpanPredictor(TrainablePipe):
|
||||||
"""Pipeline component to resolve one-token spans to full spans.
|
"""Pipeline component to resolve one-token spans to full spans.
|
||||||
|
|
||||||
|
@ -407,11 +454,41 @@ class SpanPredictor(TrainablePipe):
|
||||||
|
|
||||||
self.cfg = {}
|
self.cfg = {}
|
||||||
|
|
||||||
def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
    """Predict (start, end) token indices for every input mention.

    Each doc is run through the model on its own. The clusters stored under
    `self.input_prefix` provide the cluster structure; every mention in them
    is replaced, in order, by the highest-scoring start and end index.

    docs: The documents to predict over.
    RETURNS: One list of clusters per doc, each cluster being a list of
        `(start, end)` pairs.
    """
    # for now pretend there's just one doc
    out = []
    for doc in docs:
        # TODO check shape here
        # The model's layers operate on a batch (list) of docs, exactly as
        # `update` calls it, so a single doc has to be wrapped in a list.
        span_scores = self.model.predict([doc])
        span_scores = span_scores[0]
        # the information about clustering has to come from the input docs
        # first let's convert the scores to a list of span idxs
        # NOTE(review): assumes axes are mention x token x start/end - confirm
        start_scores = span_scores[:, :, 0]
        end_scores = span_scores[:, :, 1]
        starts = start_scores.argmax(axis=1)
        ends = end_scores.argmax(axis=1)

        # TODO check start < end

        # get the old clusters (shape will be preserved)
        clusters = doc2clusters(doc, self.input_prefix)
        cidx = 0
        out_clusters = []
        for cluster in clusters:
            ncluster = []
            for _ in cluster:
                ncluster.append((starts[cidx], ends[cidx]))
                cidx += 1
            out_clusters.append(ncluster)
        out.append(out_clusters)
    return out
|
||||||
|
|
||||||
def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None:
    """Write predicted clusters into `doc.spans`.

    docs: The documents to modify.
    clusters_by_doc: For each doc, a list of clusters; each cluster is a
        sequence of index pairs used to slice the doc. Cluster number `i`
        is stored under the key `"{output_prefix}_{i}"`.
    """
    for doc, clusters in zip(docs, clusters_by_doc):
        for idx, cluster in enumerate(clusters):
            mention_spans = []
            for bounds in cluster:
                mention_spans.append(doc[bounds[0]:bounds[1]])
            key = f"{self.output_prefix}_{idx}"
            doc.spans[key] = mention_spans
|
||||||
|
|
||||||
def update(
|
def update(
|
||||||
self,
|
self,
|
||||||
|
@ -421,7 +498,33 @@ class SpanPredictor(TrainablePipe):
|
||||||
sgd: Optional[Optimizer] = None,
|
sgd: Optional[Optimizer] = None,
|
||||||
losses: Optional[Dict[str, float]] = None,
|
losses: Optional[Dict[str, float]] = None,
|
||||||
) -> Dict[str, float]:
|
) -> Dict[str, float]:
|
||||||
...
|
"""Learn from a batch of documents and gold-standard information,
|
||||||
|
updating the pipe's model. Delegates to predict and get_loss.
|
||||||
|
"""
|
||||||
|
if losses is None:
|
||||||
|
losses = {}
|
||||||
|
losses.setdefault(self.name, 0.0)
|
||||||
|
validate_examples(examples, "SpanPredictor.update")
|
||||||
|
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
|
||||||
|
# Handle cases where there are no tokens in any docs.
|
||||||
|
return losses
|
||||||
|
set_dropout_rate(self.model, drop)
|
||||||
|
|
||||||
|
total_loss = 0
|
||||||
|
|
||||||
|
for eg in examples:
|
||||||
|
preds, backprop = self.model.begin_update([eg.predicted])
|
||||||
|
score_matrix, mention_idx = preds
|
||||||
|
|
||||||
|
loss, d_scores = self.get_loss([eg], score_matrix, mention_idx)
|
||||||
|
total_loss += loss
|
||||||
|
# TODO check shape here
|
||||||
|
backprop((d_scores, mention_idx))
|
||||||
|
|
||||||
|
if sgd is not None:
|
||||||
|
self.finish_update(sgd)
|
||||||
|
losses[self.name] += total_loss
|
||||||
|
return losses
|
||||||
|
|
||||||
def rehearse(
|
def rehearse(
|
||||||
self,
|
self,
|
||||||
|
@ -431,7 +534,12 @@ class SpanPredictor(TrainablePipe):
|
||||||
sgd: Optional[Optimizer] = None,
|
sgd: Optional[Optimizer] = None,
|
||||||
losses: Optional[Dict[str, float]] = None,
|
losses: Optional[Dict[str, float]] = None,
|
||||||
) -> Dict[str, float]:
|
) -> Dict[str, float]:
|
||||||
...
|
# TODO this should be added later
|
||||||
|
raise NotImplementedError(
|
||||||
|
Errors.E931.format(
|
||||||
|
parent="SpanPredictor", method="add_label", name=self.name
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def add_label(self, label: str) -> int:
|
def add_label(self, label: str) -> int:
|
||||||
"""Technically this method should be implemented from TrainablePipe,
|
"""Technically this method should be implemented from TrainablePipe,
|
||||||
|
@ -446,9 +554,39 @@ class SpanPredictor(TrainablePipe):
|
||||||
def get_loss(
    self,
    examples: Iterable[Example],
    span_scores: Floats3d,
):
    """Compute the loss and gradient of the span scores for one example.

    examples: Exactly one example (only fake batching is supported).
    span_scores: Scores from the model.
        NOTE(review): assumed to be mention x token x start/end - confirm.
    RETURNS: `(loss, grads)`, where `grads` stacks the start and end
        gradients (softmax probabilities minus one-hot gold) on the last
        axis and `loss` is the sum of the squared gradients.
    """
    ops = self.model.ops

    # Materialize so that len() and indexing work for any iterable.
    examples = list(examples)
    # NOTE This is doing fake batching, and should always get a list of one example
    assert len(examples) == 1, "Only fake batching is supported."
    eg = examples[0]

    # get gold data
    gold = doc2clusters(eg.reference, self.output_prefix)
    # flatten the gold data
    starts = []
    ends = []
    for cluster in gold:
        for mention in cluster:
            starts.append(mention[0])
            ends.append(mention[1])
    # Convert through ops so the one-hot targets are created on the same
    # backend as the scores; a plain Python list would give numpy targets.
    starts = ops.asarray1i(starts)
    ends = ops.asarray1i(ends)

    start_scores = span_scores[:, :, 0]
    end_scores = span_scores[:, :, 1]
    n_classes = start_scores.shape[1]
    start_probs = ops.softmax(start_scores, axis=1)
    end_probs = ops.softmax(end_scores, axis=1)
    start_targets = to_categorical(starts, n_classes)
    end_targets = to_categorical(ends, n_classes)
    start_grads = start_probs - start_targets
    end_grads = end_probs - end_targets
    grads = ops.xp.stack((start_grads, end_grads), axis=2)
    loss = float((grads ** 2).sum())
    return loss, grads
|
||||||
|
|
||||||
def initialize(
|
def initialize(
|
||||||
self,
|
self,
|
||||||
|
@ -461,6 +599,12 @@ class SpanPredictor(TrainablePipe):
|
||||||
X = []
|
X = []
|
||||||
Y = []
|
Y = []
|
||||||
for ex in islice(get_examples(), 2):
|
for ex in islice(get_examples(), 2):
|
||||||
|
|
||||||
|
if not ex.predicted.spans:
|
||||||
|
# set placeholder for shape inference
|
||||||
|
doc = ex.predicted
|
||||||
|
assert len(doc) > 2, "Coreference requires at least two tokens"
|
||||||
|
doc.spans[f"{self.input_prefix}_0"] = [doc[0:1], doc[1:2]]
|
||||||
X.append(ex.predicted)
|
X.append(ex.predicted)
|
||||||
Y.append(ex.reference)
|
Y.append(ex.reference)
|
||||||
|
|
||||||
|
@ -468,5 +612,31 @@ class SpanPredictor(TrainablePipe):
|
||||||
self.model.initialize(X=X, Y=Y)
|
self.model.initialize(X=X, Y=Y)
|
||||||
|
|
||||||
def score(self, examples, **kwargs):
    """Score a batch of examples.

    Precision, recall and F1 are computed with each of the three metrics
    (b_cubed, muc, ceafe), comparing the clusters stored under
    `self.output_prefix` in the predicted and reference docs, and then
    averaged over the metrics.

    examples: The examples to score.
    RETURNS: Dict with the keys "coref_f", "coref_p" and "coref_r".
    """
    # TODO This is basically the same as the main coref component - factor out?
    # NOTE(review): the factory's default_score_weights use "span_predictor_*"
    # keys while this returns "coref_*" keys - confirm which is intended.

    # The examples are walked once per metric, so a one-shot iterable must be
    # materialized or the second and third metrics would see no data.
    examples = list(examples)

    scores = []
    for metric in (b_cubed, muc, ceafe):
        evaluator = Evaluator(metric)

        for ex in examples:
            # XXX this is the only different part
            p_clusters = doc2clusters(ex.predicted, self.output_prefix)
            g_clusters = doc2clusters(ex.reference, self.output_prefix)

            cluster_info = get_cluster_info(p_clusters, g_clusters)

            evaluator.update(cluster_info)

        score = {
            "coref_f": evaluator.get_f1(),
            "coref_p": evaluator.get_precision(),
            "coref_r": evaluator.get_recall(),
        }
        scores.append(score)

    out = {}
    for field in ("f", "p", "r"):
        fname = f"coref_{field}"
        out[fname] = mean([ss[fname] for ss in scores])
    return out
|
||||||
|
|
Loading…
Reference in New Issue
Block a user