Add progress on SpanPredictor component

This isn't working. There is a CUDA error in the torch code during
initialization and it's not clear why.
Paul O'Leary McCann 2022-03-19 19:39:49 +09:00
parent a098849112
commit 2190cbc0e6
3 changed files with 260 additions and 17 deletions


@@ -14,7 +14,7 @@ from ..extract_spans import extract_spans
 import torch
 from thinc.util import xp2torch, torch2xp
-from .coref_util import add_dummy
+from .coref_util import add_dummy, get_sentence_ids


 @registry.architectures("spacy.Coref.v1")
 def build_wl_coref_model(
@@ -74,6 +74,33 @@ def build_wl_coref_model(
     # and just return words as spans.
     return coref_model


+@registry.architectures("spacy.SpanPredictor.v1")
+def build_span_predictor(
+    tok2vec: Model[List[Doc], List[Floats2d]],
+    hidden_size: int = 1024,
+    dist_emb_size: int = 64,
+):
+    # TODO fix this
+    try:
+        dim = tok2vec.get_dim("nO")
+    except ValueError:
+        # happens with transformer listener
+        dim = 768
+    with Model.define_operators({">>": chain, "&": tuplify}):
+        # TODO fix device - should be automatic
+        device = "cuda:0"
+        span_predictor = PyTorchWrapper(
+            SpanPredictor(hidden_size, dist_emb_size, device),
+            convert_inputs=convert_span_predictor_inputs
+        )
+        # TODO use proper parameter for prefix
+        head_info = build_get_head_metadata("coref_head_clusters")
+        model = (tok2vec & head_info) >> span_predictor
+    return model
+
+
 def convert_coref_scorer_inputs(
     model: Model,
     X: List[Floats2d],
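A note on the operator composition in build_span_predictor above: with the operators defined, `a & b` is thinc's tuplify, which runs both layers on the same input and tuples their outputs, and `>>` is chain, so the wrapped torch model receives (tok2vec output, head metadata) as one tuple. A minimal toy check of the tuplify semantics, using thinc's no-op layers (illustrative, not part of this commit):

from thinc.api import noop, tuplify

# tuplify(f, g) calls f and g on the same input and returns (f(X), g(X));
# with no-op layers the composed model simply pairs the input with itself
model = tuplify(noop(), noop())
ys, backprop = model((1, 2, 3), is_train=False)
assert ys == ((1, 2, 3), (1, 2, 3))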
@@ -84,6 +111,7 @@ def convert_coref_scorer_inputs(
     # TODO real batching
     X = X[0]
     word_features = xp2torch(X, requires_grad=is_train)
+
     def backprop(args: ArgsKwargs) -> List[Floats2d]:
         # convert to xp and wrap in list
@@ -116,10 +144,15 @@ def convert_span_predictor_inputs(
     X: Tuple[Ints1d, Floats2d, Ints1d],
     is_train: bool
 ):
-    sent_id = xp2torch(X[0], requires_grad=False)
-    word_features = xp2torch(X[1], requires_grad=False)
-    head_ids = xp2torch(X[2], requires_grad=False)
-    argskwargs = ArgsKwargs(args=(sent_id, word_features, head_ids), kwargs={})
+    tok2vec, (sent_ids, head_ids) = X
+    # Normally we should use the input is_train, but for these two it's not relevant
+    sent_ids = xp2torch(sent_ids[0], requires_grad=False)
+    head_ids = xp2torch(head_ids[0], requires_grad=False)
+    word_features = xp2torch(tok2vec[0], requires_grad=is_train)
+
+    argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
+    # TODO actually support backprop
     return argskwargs, lambda dX: []

 # TODO This probably belongs in the component, not the model.
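The `lambda dX: []` above means no gradient flows back to tok2vec yet, matching the TODO. For reference, a hedged sketch of the PyTorchWrapper convert_inputs contract with a real backprop callback, modeled on convert_coref_scorer_inputs; the function name is illustrative:

from thinc.api import ArgsKwargs, xp2torch, torch2xp

def convert_inputs_sketch(model, X, is_train):
    # X is assumed to be a single Floats2d (the fake batching used here)
    word_features = xp2torch(X, requires_grad=is_train)

    def backprop(args: ArgsKwargs):
        # torch gradients arrive in args.args; convert back to the
        # framework's array type so upstream layers can consume them
        return torch2xp(args.args[0])

    return ArgsKwargs(args=(word_features,), kwargs={}), backprop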
@@ -189,6 +222,36 @@ def _clusterize(
         clusters.append(sorted(cluster))
     return sorted(clusters)


+def build_get_head_metadata(prefix):
+    # TODO this name is awful, fix it
+    model = Model("HeadDataProvider", attrs={"prefix": prefix}, forward=head_data_forward)
+    return model
+
+
+def head_data_forward(model, docs, is_train):
+    """A layer to generate the extra data needed for the span predictor.
+    """
+    sent_ids = []
+    head_ids = []
+    prefix = model.attrs["prefix"]
+    for doc in docs:
+        sids = model.ops.asarray2i(get_sentence_ids(doc))
+        sent_ids.append(sids)
+        heads = []
+        for key, sg in doc.spans.items():
+            if not key.startswith(prefix):
+                continue
+            for span in sg:
+                # TODO warn if spans are more than one token
+                heads.append(span[0].i)
+        heads = model.ops.asarray2i(heads)
+        head_ids.append(heads)
+    # each of these is a list with one entry per doc
+    # backprop is just a placeholder
+    # TODO it would probably be better to have a list of tuples than two lists of arrays
+    return (sent_ids, head_ids), lambda x: []
+
+
 class CorefScorer(torch.nn.Module):
     """Combines all coref modules together to find coreferent spans.
@@ -492,6 +555,7 @@ class SpanPredictor(torch.nn.Module):
         emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127
         # Obtain "same sentence" boolean mask, [n_heads, n_words]
         sent_id = torch.tensor(sent_id, device=words.device)
+        heads_ids = heads_ids.long()
         same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0))
         # To save memory, only pass candidates from one sentence for each head
@@ -506,7 +570,7 @@ class SpanPredictor(torch.nn.Module):
         ), dim=1)
         lengths = same_sent.sum(dim=1)
-        padding_mask = torch.arange(0, lengths.max(), device=words.device).unsqueeze(0)
+        padding_mask = torch.arange(0, lengths.max().item(), device=words.device).unsqueeze(0)
         padding_mask = (padding_mask < lengths.unsqueeze(1))  # [n_heads, max_sent_len]
         # [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
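The `.item()` change above keeps torch.arange from being handed a 0-dim tensor as its bound, plausibly related to the CUDA error the commit message mentions. The padding-mask idiom itself is easy to verify in isolation (standalone torch, not part of the diff):

import torch

# one row per head; a row is True for positions inside that head's sentence
lengths = torch.tensor([2, 4, 3])
padding_mask = torch.arange(0, int(lengths.max().item())).unsqueeze(0)
padding_mask = padding_mask < lengths.unsqueeze(1)  # [n_heads, max_sent_len]
assert padding_mask.tolist() == [
    [True, True, False, False],
    [True, True, True, True],
    [True, True, True, False],
]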


@@ -39,6 +39,15 @@ def add_dummy(tensor: torch.Tensor, eps: bool = False):
     output = torch.cat((dummy, tensor), dim=1)
     return output


+def get_sentence_ids(doc):
+    out = []
+    sent_id = -1
+    for tok in doc:
+        if tok.is_sent_start:
+            sent_id += 1
+        out.append(sent_id)
+    return out
+
+
 def doc2clusters(doc: Doc, prefix=DEFAULT_CLUSTER_PREFIX) -> MentionClusters:
     """Given a doc, give the mention clusters.


@@ -1,7 +1,7 @@
 from typing import Iterable, Tuple, Optional, Dict, Callable, Any, List
 import warnings
-from thinc.types import Floats2d, Ints2d
+from thinc.types import Floats2d, Floats3d, Ints2d
 from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
 from thinc.api import set_dropout_rate
 from itertools import islice
@@ -84,6 +84,7 @@ def make_coref(
     )

+
 class CoreferenceResolver(TrainablePipe):
     """Pipeline component for coreference resolution.
@@ -208,7 +209,7 @@ class CoreferenceResolver(TrainablePipe):
         total_loss = 0
         for eg in examples:
-            # TODO does this even work?
+            # TODO check this causes no issues (in practice it runs)
             preds, backprop = self.model.begin_update([eg.predicted])
             score_matrix, mention_idx = preds
@@ -384,6 +385,52 @@ class CoreferenceResolver(TrainablePipe):
         return out


+default_span_predictor_config = """
+[model]
+@architectures = "spacy.SpanPredictor.v1"
+hidden_size = 1024
+dist_emb_size = 64
+
+[model.tok2vec]
+@architectures = "spacy.Tok2Vec.v2"
+
+[model.tok2vec.embed]
+@architectures = "spacy.MultiHashEmbed.v1"
+width = 64
+rows = [2000, 2000, 1000, 1000, 1000, 1000]
+attrs = ["ORTH", "LOWER", "PREFIX", "SUFFIX", "SHAPE", "ID"]
+include_static_vectors = false
+
+[model.tok2vec.encode]
+@architectures = "spacy.MaxoutWindowEncoder.v2"
+width = ${model.tok2vec.embed.width}
+window_size = 1
+maxout_pieces = 3
+depth = 2
+"""
+DEFAULT_SPAN_PREDICTOR_MODEL = Config().from_str(default_span_predictor_config)["model"]
+
+
+@Language.factory(
+    "span_predictor",
+    assigns=["doc.spans"],
+    requires=["doc.spans"],
+    default_config={
+        "model": DEFAULT_SPAN_PREDICTOR_MODEL,
+        "input_prefix": "coref_head_clusters",
+        "output_prefix": "coref_clusters",
+    },
+    default_score_weights={"span_predictor_f": 1.0, "span_predictor_p": None, "span_predictor_r": None},
+)
+def make_span_predictor(
+    nlp: Language,
+    name: str,
+    model,
+    input_prefix: str = "coref_head_clusters",
+    output_prefix: str = "coref_clusters",
+) -> "SpanPredictor":
+    """Create a SpanPredictor component."""
+    return SpanPredictor(nlp.vocab, model, name, input_prefix=input_prefix, output_prefix=output_prefix)
+

 class SpanPredictor(TrainablePipe):
     """Pipeline component to resolve one-token spans to full spans.
@@ -407,11 +454,41 @@ class SpanPredictor(TrainablePipe):
         self.cfg = {}

-    def predict(self, docs: Iterable[Doc]):
-        ...
+    def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
+        # for now pretend there's just one doc
+        out = []
+        for doc in docs:
+            # TODO check shape here
+            span_scores = self.model.predict(doc)
+            span_scores = span_scores[0]
+            # the information about clustering has to come from the input docs
+            # first let's convert the scores to a list of span idxs
+            start_scores = span_scores[:, :, 0]
+            end_scores = span_scores[:, :, 1]
+            starts = start_scores.argmax(axis=1)
+            ends = end_scores.argmax(axis=1)
+            # TODO check start < end
+
+            # get the old clusters (shape will be preserved)
+            clusters = doc2clusters(doc, self.input_prefix)
+            cidx = 0
+            out_clusters = []
+            for cluster in clusters:
+                ncluster = []
+                for mention in cluster:
+                    ncluster.append((starts[cidx], ends[cidx]))
+                    cidx += 1
+                out_clusters.append(ncluster)
+            out.append(out_clusters)
+        return out

     def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None:
-        ...
+        for doc, clusters in zip(docs, clusters_by_doc):
+            for ii, cluster in enumerate(clusters):
+                spans = [doc[mm[0]:mm[1]] for mm in cluster]
+                doc.spans[f"{self.output_prefix}_{ii}"] = spans

     def update(
         self,
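The decoding step in predict above reduces to two argmaxes over the token axis of a [n_heads, n_tokens, 2] score tensor. A toy numpy check of just that step:

import numpy

span_scores = numpy.zeros((1, 5, 2), dtype="f")  # [n_heads, n_tokens, 2]
span_scores[0, 1, 0] = 1.0  # start scores peak at token 1
span_scores[0, 3, 1] = 1.0  # end scores peak at token 3
starts = span_scores[:, :, 0].argmax(axis=1)
ends = span_scores[:, :, 1].argmax(axis=1)
assert (starts[0], ends[0]) == (1, 3)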
@@ -421,7 +498,33 @@ class SpanPredictor(TrainablePipe):
         sgd: Optional[Optimizer] = None,
         losses: Optional[Dict[str, float]] = None,
     ) -> Dict[str, float]:
-        ...
+        """Learn from a batch of documents and gold-standard information,
+        updating the pipe's model. Delegates to predict and get_loss.
+        """
+        if losses is None:
+            losses = {}
+        losses.setdefault(self.name, 0.0)
+        validate_examples(examples, "SpanPredictor.update")
+        if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
+            # Handle cases where there are no tokens in any docs.
+            return losses
+        set_dropout_rate(self.model, drop)
+        total_loss = 0
+        for eg in examples:
+            preds, backprop = self.model.begin_update([eg.predicted])
+            score_matrix, mention_idx = preds
+            loss, d_scores = self.get_loss([eg], score_matrix, mention_idx)
+            total_loss += loss
+            # TODO check shape here
+            backprop((d_scores, mention_idx))
+        if sgd is not None:
+            self.finish_update(sgd)
+        losses[self.name] += total_loss
+        return losses

     def rehearse(
         self,
@@ -431,7 +534,12 @@ class SpanPredictor(TrainablePipe):
         sgd: Optional[Optimizer] = None,
         losses: Optional[Dict[str, float]] = None,
     ) -> Dict[str, float]:
-        ...
+        # TODO this should be added later
+        raise NotImplementedError(
+            Errors.E931.format(
+                parent="SpanPredictor", method="rehearse", name=self.name
+            )
+        )

     def add_label(self, label: str) -> int:
         """Technically this method should be implemented from TrainablePipe,
@@ -446,9 +554,39 @@ class SpanPredictor(TrainablePipe):
     def get_loss(
         self,
         examples: Iterable[Example],
-        # TODO add necessary args
+        span_scores: Floats3d,
     ):
-        ...
+        ops = self.model.ops
+        # NOTE This is doing fake batching, and should always get a list of one example
+        assert len(examples) == 1, "Only fake batching is supported."
+        # starts and ends are gold starts and ends (Ints1d)
+        # span_scores is a Floats3d. What are the axes? mention x token x start/end
+        for eg in examples:
+            # get gold data
+            gold = doc2clusters(eg.reference, self.output_prefix)
+            # flatten the gold data
+            starts = []
+            ends = []
+            for cluster in gold:
+                for mention in cluster:
+                    starts.append(mention[0])
+                    ends.append(mention[1])
+
+            start_scores = span_scores[:, :, 0]
+            end_scores = span_scores[:, :, 1]
+            n_classes = start_scores.shape[1]
+            start_probs = ops.softmax(start_scores, axis=1)
+            end_probs = ops.softmax(end_scores, axis=1)
+            start_targets = to_categorical(starts, n_classes)
+            end_targets = to_categorical(ends, n_classes)
+            start_grads = (start_probs - start_targets)
+            end_grads = (end_probs - end_targets)
+            grads = ops.xp.stack((start_grads, end_grads), axis=2)
+            loss = float((grads ** 2).sum())
+        return loss, grads

     def initialize(
         self,
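get_loss above leans on the standard softmax cross-entropy identity: the gradient of the loss with respect to the scores is softmax(scores) minus the one-hot target, and the scalar reported as the loss is just the squared norm of that gradient. A numeric sanity check of the gradient part (standalone, using thinc's ops):

import numpy
from thinc.api import NumpyOps, to_categorical

ops = NumpyOps()
scores = numpy.asarray([[2.0, 1.0, 0.5]], dtype="f")  # one mention, 3 tokens
probs = ops.softmax(scores, axis=1)
targets = to_categorical(numpy.asarray([0]), 3)  # gold start at token 0
grads = probs - targets
# negative only at the gold index, so the update pushes that score up
assert grads[0, 0] < 0 and bool((grads[0, 1:] > 0).all())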
@@ -461,6 +599,12 @@ class SpanPredictor(TrainablePipe):
         X = []
         Y = []
         for ex in islice(get_examples(), 2):
+            if not ex.predicted.spans:
+                # set placeholder for shape inference
+                doc = ex.predicted
+                assert len(doc) > 2, "Coreference requires at least two tokens"
+                doc.spans[f"{self.input_prefix}_0"] = [doc[0:1], doc[1:2]]
             X.append(ex.predicted)
             Y.append(ex.reference)
@@ -468,5 +612,31 @@ class SpanPredictor(TrainablePipe):
         self.model.initialize(X=X, Y=Y)

     def score(self, examples, **kwargs):
-        # TODO this will overlap significantly with coref, maybe factor into function
-        ...
+        """Score a batch of examples."""
+        # TODO This is basically the same as the main coref component - factor out?
+        scores = []
+        for metric in (b_cubed, muc, ceafe):
+            evaluator = Evaluator(metric)
+            for ex in examples:
+                # XXX this is the only different part
+                p_clusters = doc2clusters(ex.predicted, self.output_prefix)
+                g_clusters = doc2clusters(ex.reference, self.output_prefix)
+                cluster_info = get_cluster_info(p_clusters, g_clusters)
+                evaluator.update(cluster_info)
+            score = {
+                "coref_f": evaluator.get_f1(),
+                "coref_p": evaluator.get_precision(),
+                "coref_r": evaluator.get_recall(),
+            }
+            scores.append(score)
+        out = {}
+        for field in ("f", "p", "r"):
+            fname = f"coref_{field}"
+            out[fname] = mean([ss[fname] for ss in scores])
+        return out