diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py
index 049f0efae..382d7a98b 100644
--- a/spacy/ml/models/coref.py
+++ b/spacy/ml/models/coref.py
@@ -14,7 +14,7 @@ from ..extract_spans import extract_spans
 
 import torch
 from thinc.util import xp2torch, torch2xp
-from .coref_util import add_dummy
+from .coref_util import add_dummy, get_sentence_ids
 
 @registry.architectures("spacy.Coref.v1")
 def build_wl_coref_model(
@@ -74,6 +74,41 @@ def build_wl_coref_model(
     # and just return words as spans.
     return coref_model
 
+@registry.architectures("spacy.SpanPredictor.v1")
+def build_span_predictor(
+    tok2vec: Model[List[Doc], List[Floats2d]],
+    hidden_size: int = 1024,
+    dist_emb_size: int = 64,
+    ):
+    """Build a model that resolves head tokens to full mention spans.
+
+    tok2vec: Layer producing one feature matrix per Doc.
+    hidden_size: First argument passed to the torch SpanPredictor.
+    dist_emb_size: Second argument passed to the torch SpanPredictor.
+    RETURNS: (tok2vec & head metadata) chained into the wrapped SpanPredictor.
+    """
+    # TODO fix this: dim is looked up but not yet passed on to SpanPredictor
+    try:
+        dim = tok2vec.get_dim("nO")
+    except ValueError:
+        # happens with transformer listener
+        dim = 768
+
+    with Model.define_operators({">>": chain, "&": tuplify}):
+        # Fall back to CPU so the model can still be built without CUDA.
+        # TODO fix device - should be automatic
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        span_predictor = PyTorchWrapper(
+            SpanPredictor(hidden_size, dist_emb_size, device),
+            convert_inputs=convert_span_predictor_inputs
+        )
+        # TODO use proper parameter for prefix
+        head_info = build_get_head_metadata("coref_head_clusters")
+        model = (tok2vec & head_info) >> span_predictor
+
+    return model
+
+
 def convert_coref_scorer_inputs(
     model: Model,
     X: List[Floats2d],
@@ -84,6 +119,7 @@ def convert_coref_scorer_inputs(
     # TODO real batching
     X = X[0]
 
+
     word_features = xp2torch(X, requires_grad=is_train)
     def backprop(args: ArgsKwargs) -> List[Floats2d]:
         # convert to xp and wrap in list
@@ -116,10 +152,17 @@ def convert_span_predictor_inputs(
-    X: Tuple[Ints1d, Floats2d, Ints1d],
+    X: Tuple[List[Floats2d], Tuple[List[Ints1d], List[Ints1d]]],
     is_train: bool
 ):
-    sent_id = xp2torch(X[0], requires_grad=False)
-    word_features = xp2torch(X[1], requires_grad=False)
-    head_ids = xp2torch(X[2], requires_grad=False)
-    argskwargs = ArgsKwargs(args=(sent_id, word_features, head_ids), kwargs={})
+    # X is the tuplify output: (tok2vec features, (sentence ids, head ids)),
+    # each a list with one entry per doc. Only the first doc is used for now.
+    tok2vec, (sent_ids, head_ids) = X
+    # Normally we should use the input is_train, but for these two it's not relevant
+    sent_ids = xp2torch(sent_ids[0], requires_grad=False)
+    head_ids = xp2torch(head_ids[0], requires_grad=False)
+
+    word_features = xp2torch(tok2vec[0], requires_grad=is_train)
+
+    argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
+    # TODO actually support backprop
     return argskwargs, lambda dX: []
 
 # TODO This probably belongs in the component, not the model.
@@ -189,6 +232,43 @@ def _clusterize(
         clusters.append(sorted(cluster))
     return sorted(clusters)
 
+def build_get_head_metadata(prefix):
+    """Create a layer that collects sentence ids and head token ids per Doc.
+
+    prefix: Only span groups whose key starts with this prefix are read.
+    """
+    # TODO this name is awful, fix it
+    model = Model("HeadDataProvider", attrs={"prefix": prefix}, forward=head_data_forward)
+    return model
+
+def head_data_forward(model, docs, is_train):
+    """A layer to generate the extra data needed for the span predictor.
+
+    RETURNS: ((sent_ids, head_ids), backprop), where each of the two lists
+        has one integer array per doc and backprop is a no-op placeholder.
+    """
+    sent_ids = []
+    head_ids = []
+    prefix = model.attrs["prefix"]
+
+    for doc in docs:
+        sids = model.ops.asarray2i(get_sentence_ids(doc))
+        sent_ids.append(sids)
+        heads = []
+        for key, sg in doc.spans.items():
+            if not key.startswith(prefix):
+                continue
+            for span in sg:
+                # TODO warn if spans are more than one token
+                heads.append(span[0].i)
+        heads = model.ops.asarray2i(heads)
+        head_ids.append(heads)
+
+    # each of these is a list with one entry per doc
+    # backprop is just a placeholder
+    # TODO it would probably be better to have a list of tuples than two lists of arrays
+    return (sent_ids, head_ids), lambda x: []
+
 class CorefScorer(torch.nn.Module):
     """Combines all coref modules together to find coreferent spans.
 
@@ -492,6 +572,7 @@ class SpanPredictor(torch.nn.Module):
         emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127
         # Obtain "same sentence" boolean mask, [n_heads, n_words]
         sent_id = torch.tensor(sent_id, device=words.device)
+        heads_ids = heads_ids.long()
         same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0))
 
         # To save memory, only pass candidates from one sentence for each head
@@ -506,7 +587,7 @@ class SpanPredictor(torch.nn.Module):
         ), dim=1)
 
         lengths = same_sent.sum(dim=1)
-        padding_mask = torch.arange(0, lengths.max(), device=words.device).unsqueeze(0)
+        padding_mask = torch.arange(0, lengths.max().item(), device=words.device).unsqueeze(0)
         padding_mask = (padding_mask < lengths.unsqueeze(1))  # [n_heads, max_sent_len]
 
         # [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
diff --git a/spacy/ml/models/coref_util.py b/spacy/ml/models/coref_util.py
index c75314fa6..e8de1e0ac 100644
--- a/spacy/ml/models/coref_util.py
+++ b/spacy/ml/models/coref_util.py
@@ -39,6 +39,19 @@ def add_dummy(tensor: torch.Tensor, eps: bool = False):
     output = torch.cat((dummy, tensor), dim=1)
     return output
 
+def get_sentence_ids(doc):
+    """Return one sentence index per token, counting up at each sentence start.
+
+    Tokens that come before the first marked sentence start get id -1.
+    """
+    out = []
+    sent_id = -1
+    for tok in doc:
+        if tok.is_sent_start:
+            sent_id += 1
+        out.append(sent_id)
+    return out
+
 def doc2clusters(doc: Doc, prefix=DEFAULT_CLUSTER_PREFIX) -> MentionClusters:
     """Given a doc, give the mention clusters.
 
diff --git a/spacy/pipeline/coref.py b/spacy/pipeline/coref.py
index 4b1483e3c..54e9d8cfd 100644
--- a/spacy/pipeline/coref.py
+++ b/spacy/pipeline/coref.py
@@ -1,7 +1,7 @@
 from typing import Iterable, Tuple, Optional, Dict, Callable, Any, List
 import warnings
 
-from thinc.types import Floats2d, Ints2d
+from thinc.types import Floats2d, Floats3d, Ints2d
 from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
-from thinc.api import set_dropout_rate
+from thinc.api import set_dropout_rate, to_categorical
 from itertools import islice
@@ -84,6 +84,7 @@ def make_coref(
     )
 
 
+
 class CoreferenceResolver(TrainablePipe):
     """Pipeline component for coreference resolution.
 
@@ -208,7 +209,7 @@ class CoreferenceResolver(TrainablePipe):
         total_loss = 0
 
         for eg in examples:
-            # TODO does this even work?
+            # TODO check this causes no issues (in practice it runs)
             preds, backprop = self.model.begin_update([eg.predicted])
             score_matrix, mention_idx = preds
 
@@ -384,6 +385,56 @@ class CoreferenceResolver(TrainablePipe):
         return out
 
 
+default_span_predictor_config = """
+[model]
+@architectures = "spacy.SpanPredictor.v1"
+hidden_size = 1024
+dist_emb_size = 64
+
+[model.tok2vec]
+@architectures = "spacy.Tok2Vec.v2"
+
+[model.tok2vec.embed]
+@architectures = "spacy.MultiHashEmbed.v1"
+width = 64
+rows = [2000, 2000, 1000, 1000, 1000, 1000]
+attrs = ["ORTH", "LOWER", "PREFIX", "SUFFIX", "SHAPE", "ID"]
+include_static_vectors = false
+
+[model.tok2vec.encode]
+@architectures = "spacy.MaxoutWindowEncoder.v2"
+width = ${model.tok2vec.embed.width}
+window_size = 1
+maxout_pieces = 3
+depth = 2
+"""
+DEFAULT_SPAN_PREDICTOR_MODEL = Config().from_str(default_span_predictor_config)["model"]
+
+@Language.factory(
+    "span_predictor",
+    assigns=["doc.spans"],
+    requires=["doc.spans"],
+    default_config={
+        "model": DEFAULT_SPAN_PREDICTOR_MODEL,
+        "input_prefix": "coref_head_clusters",
+        "output_prefix": "coref_clusters",
+    },
+    default_score_weights={"span_predictor_f": 1.0, "span_predictor_p": None, "span_predictor_r": None},
+    )
+def make_span_predictor(
+    nlp: Language,
+    name: str,
+    model,
+    input_prefix: str = "coref_head_clusters",
+    output_prefix: str = "coref_clusters",
+) -> "SpanPredictor":
+    """Create a SpanPredictor component.
+
+    input_prefix: Prefix of the doc.spans keys holding the head clusters.
+    output_prefix: Prefix of the doc.spans keys the full spans are saved to.
+    """
+    return SpanPredictor(nlp.vocab, model, name, input_prefix=input_prefix, output_prefix=output_prefix)
+
 class SpanPredictor(TrainablePipe):
     """Pipeline component to resolve one-token spans to full spans.
 
@@ -407,11 +458,52 @@ class SpanPredictor(TrainablePipe):
 
         self.cfg = {}
 
-    def predict(self, docs: Iterable[Doc]):
-        ...
+    def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
+        """Predict a full span for every head mention in each doc.
+
+        docs: The documents to predict over.
+        RETURNS: For each doc, its clusters as lists of (start, end) pairs.
+        """
+        # for now pretend there's just one doc
+
+        out = []
+        for doc in docs:
+            # The model takes a batch (list) of docs, not a bare Doc.
+            # TODO check shape here
+            span_scores = self.model.predict([doc])
+            # scores are used as mention x token x start/end, as in get_loss
+            # the information about clustering has to come from the input docs
+            # first let's convert the scores to a list of span idxs
+            start_scores = span_scores[:, :, 0]
+            end_scores = span_scores[:, :, 1]
+            starts = start_scores.argmax(axis=1)
+            ends = end_scores.argmax(axis=1)
+
+            # TODO check start < end
+
+            # get the old clusters (shape will be preserved)
+            clusters = doc2clusters(doc, self.input_prefix)
+            cidx = 0
+            out_clusters = []
+            for cluster in clusters:
+                ncluster = []
+                for mention in cluster:
+                    # plain ints, so the pairs can be used to slice the Doc
+                    ncluster.append( (int(starts[cidx]), int(ends[cidx])) )
+                    cidx += 1
+                out_clusters.append(ncluster)
+            out.append(out_clusters)
+        return out
 
     def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None:
-        ...
+        """Save the predicted clusters as span groups on each doc.
+
+        Cluster ii is stored under the key f"{output_prefix}_{ii}".
+        """
+        for doc, clusters in zip(docs, clusters_by_doc):
+            for ii, cluster in enumerate(clusters):
+                spans = [doc[mm[0]:mm[1]] for mm in cluster]
+                doc.spans[f"{self.output_prefix}_{ii}"] = spans
 
     def update(
         self,
@@ -421,7 +513,40 @@ class SpanPredictor(TrainablePipe):
         sgd: Optional[Optimizer] = None,
         losses: Optional[Dict[str, float]] = None,
     ) -> Dict[str, float]:
-        ...
+        """Learn from a batch of documents and gold-standard information,
+        updating the pipe's model. Delegates to the model and get_loss.
+
+        examples: A batch of Example objects.
+        drop: The dropout rate.
+        sgd: The optimizer; if None, weights are not updated here.
+        losses: Optional record of the loss during training, keyed by pipe name.
+        RETURNS: The updated losses dictionary.
+        """
+        if losses is None:
+            losses = {}
+        losses.setdefault(self.name, 0.0)
+        validate_examples(examples, "SpanPredictor.update")
+        if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
+            # Handle cases where there are no tokens in any docs.
+            return losses
+        set_dropout_rate(self.model, drop)
+
+        total_loss = 0
+
+        for eg in examples:
+            span_scores, backprop = self.model.begin_update([eg.predicted])
+
+            # get_loss only accepts the examples and the span scores; passing
+            # a third argument (as the coref component does) is a TypeError.
+            loss, d_scores = self.get_loss([eg], span_scores)
+            total_loss += loss
+            # TODO check shape here
+            backprop(d_scores)
+
+        if sgd is not None:
+            self.finish_update(sgd)
+        losses[self.name] += total_loss
+        return losses
 
     def rehearse(
         self,
@@ -431,7 +556,12 @@ class SpanPredictor(TrainablePipe):
         sgd: Optional[Optimizer] = None,
         losses: Optional[Dict[str, float]] = None,
     ) -> Dict[str, float]:
-        ...
+        # TODO this should be added later
+        raise NotImplementedError(
+            Errors.E931.format(
+                parent="SpanPredictor", method="rehearse", name=self.name
+            )
+        )
 
     def add_label(self, label: str) -> int:
         """Technically this method should be implemented from TrainablePipe,
@@ -446,9 +576,48 @@ class SpanPredictor(TrainablePipe):
     def get_loss(
         self,
         examples: Iterable[Example],
-        # TODO add necessary args
+        span_scores: Floats3d,
     ):
-        ...
+        """Find the loss and gradient of loss for one example.
+
+        examples: A list holding exactly one Example (fake batching).
+        span_scores: Scores with axes mention x token x start/end.
+        RETURNS: (loss, gradient of the loss w.r.t. span_scores).
+        """
+        ops = self.model.ops
+
+        # NOTE This is doing fake batching, and should always get a list of one example
+        assert len(examples) == 1, "Only fake batching is supported."
+        # starts and ends are gold starts and ends (Ints1d)
+        # span_scores is a Floats3d. What are the axes? mention x token x start/end
+
+        for eg in examples:
+
+            # get gold data
+            gold = doc2clusters(eg.reference, self.output_prefix)
+            # flatten the gold data
+            starts = []
+            ends = []
+            for cluster in gold:
+                for mention in cluster:
+                    starts.append(mention[0])
+                    ends.append(mention[1])
+            # convert with the model's ops so the targets match the score arrays
+            starts = ops.asarray1i(starts)
+            ends = ops.asarray1i(ends)
+
+            start_scores = span_scores[:, :, 0]
+            end_scores = span_scores[:, :, 1]
+            n_classes = start_scores.shape[1]
+            start_probs = ops.softmax(start_scores, axis=1)
+            end_probs = ops.softmax(end_scores, axis=1)
+            start_targets = to_categorical(starts, n_classes)
+            end_targets = to_categorical(ends, n_classes)
+            start_grads = (start_probs - start_targets)
+            end_grads = (end_probs - end_targets)
+            grads = ops.xp.stack((start_grads, end_grads), axis=2)
+            loss = float((grads ** 2).sum())
+        return loss, grads
 
     def initialize(
         self,
@@ -461,6 +630,12 @@ class SpanPredictor(TrainablePipe):
         X = []
         Y = []
         for ex in islice(get_examples(), 2):
+
+            if not ex.predicted.spans:
+                # set placeholder for shape inference
+                doc = ex.predicted
+                assert len(doc) >= 2, "Coreference requires at least two tokens"
+                doc.spans[f"{self.input_prefix}_0"] = [doc[0:1], doc[1:2]]
             X.append(ex.predicted)
             Y.append(ex.reference)
 
@@ -468,5 +643,35 @@ class SpanPredictor(TrainablePipe):
         self.model.initialize(X=X, Y=Y)
 
     def score(self, examples, **kwargs):
-        # TODO this will overlap significantly with coref, maybe factor into function
-        ...
+        """Score a batch of examples.
+
+        RETURNS: Mean F1, precision and recall over the b_cubed, muc and
+            ceafe metrics, keyed to match the factory's default_score_weights.
+        """
+        # TODO This is basically the same as the main coref component - factor out?
+
+        scores = []
+        for metric in (b_cubed, muc, ceafe):
+            evaluator = Evaluator(metric)
+
+            for ex in examples:
+                # XXX this is the only different part
+                p_clusters = doc2clusters(ex.predicted, self.output_prefix)
+                g_clusters = doc2clusters(ex.reference, self.output_prefix)
+
+                cluster_info = get_cluster_info(p_clusters, g_clusters)
+
+                evaluator.update(cluster_info)
+
+            score = {
+                "span_predictor_f": evaluator.get_f1(),
+                "span_predictor_p": evaluator.get_precision(),
+                "span_predictor_r": evaluator.get_recall(),
+            }
+            scores.append(score)
+
+        out = {}
+        for field in ("f", "p", "r"):
+            fname = f"span_predictor_{field}"
+            out[fname] = mean([ss[fname] for ss in scores])
+        return out