diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py
index 0f1614ef5..0b533daf0 100644
--- a/spacy/ml/models/coref.py
+++ b/spacy/ml/models/coref.py
@@ -37,14 +37,11 @@ def build_wl_coref_model(
     except ValueError:
         # happens with transformer listener
         dim = 768
-
+
     with Model.define_operators({">>": chain}):
         # TODO chain tok2vec with these models
-        # TODO fix device - should be automatic
-        device = "cuda:0"
         coref_scorer = PyTorchWrapper(
             CorefScorer(
-                device,
                 dim,
                 embedding_size,
                 hidden_size,
@@ -56,8 +53,16 @@ def build_wl_coref_model(
             convert_inputs=convert_coref_scorer_inputs,
             convert_outputs=convert_coref_scorer_outputs
         )
-        coref_model = tok2vec >> coref_scorer
+        # XXX just ignore this until the coref scorer is integrated
+        # span_predictor = PyTorchWrapper(
+        #     SpanPredictor(
+        #         TODO this was hardcoded to 1024, check
+        #         hidden_size,
+        #         sp_embedding_size,
+        #     ),
+        #     convert_inputs=convert_span_predictor_inputs
+        # )
         # TODO combine models so output is uniform (just one forward pass)
         # It may be reasonable to have an option to disable span prediction,
         # and just return words as spans.
@@ -77,14 +82,14 @@ def build_span_predictor(
         dim = 768

     with Model.define_operators({">>": chain, "&": tuplify}):
-        # TODO fix device - should be automatic
-        device = "cuda:0"
         span_predictor = PyTorchWrapper(
-            SpanPredictor(dim, hidden_size, dist_emb_size, device),
+            SpanPredictor(dim, hidden_size, dist_emb_size),
             convert_inputs=convert_span_predictor_inputs
         )
         # TODO use proper parameter for prefix
-        head_info = build_get_head_metadata("coref_head_clusters")
+        head_info = build_get_head_metadata(
+            "coref_head_clusters"
+        )
         model = (tok2vec & head_info) >> span_predictor

     return model
@@ -99,13 +104,13 @@ def convert_coref_scorer_inputs(
     # just use the first
     # TODO real batching
     X = X[0]
-
-
     word_features = xp2torch(X, requires_grad=is_train)
+
     def backprop(args: ArgsKwargs) -> List[Floats2d]:
         # convert to xp and wrap in list
         gradients = torch2xp(args.args[0])
         return [gradients]
+
     return ArgsKwargs(args=(word_features, ), kwargs={}), backprop
@@ -128,6 +133,7 @@ def convert_coref_scorer_outputs(
     indices_xp = torch2xp(indices)
     return (scores_xp, indices_xp), convert_for_torch_backward

+
 def convert_span_predictor_inputs(
     model: Model,
     X: Tuple[Ints1d, Floats2d, Ints1d],
@@ -135,14 +141,23 @@ def convert_span_predictor_inputs(
 ):
     tok2vec, (sent_ids, head_ids) = X
     # Normally we should use the input is_train, but for these two it's not relevant
-    sent_ids = xp2torch(sent_ids[0], requires_grad=False)
-    head_ids = xp2torch(head_ids[0], requires_grad=False)
+
+    def backprop(args: ArgsKwargs) -> List[Floats2d]:
+        # convert to xp and wrap in list
+        gradients = torch2xp(args.args[1])
+        return [[gradients], None]

     word_features = xp2torch(tok2vec[0], requires_grad=is_train)
+    sent_ids = xp2torch(sent_ids[0], requires_grad=False)
+    if not head_ids[0].size:
+        head_ids = torch.empty(size=(0,))
+    else:
+        head_ids = xp2torch(head_ids[0], requires_grad=False)

     argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
     # TODO actually support backprop
-    return argskwargs, lambda dX: []
+    return argskwargs, backprop
+

 # TODO This probably belongs in the component, not the model.
 def predict_span_clusters(span_predictor: Model,
@@ -211,18 +226,21 @@ def _clusterize(
         clusters.append(sorted(cluster))
     return sorted(clusters)

+
 def build_get_head_metadata(prefix):
     # TODO this name is awful, fix it
-    model = Model("HeadDataProvider", attrs={"prefix": prefix}, forward=head_data_forward)
+    model = Model("HeadDataProvider",
+                  attrs={'prefix': prefix},
+                  forward=head_data_forward)
     return model

+
 def head_data_forward(model, docs, is_train):
     """A layer to generate the extra data needed for the span predictor.
     """
     sent_ids = []
     head_ids = []
     prefix = model.attrs["prefix"]
-
     for doc in docs:
         sids = model.ops.asarray2i(get_sentence_ids(doc))
         sent_ids.append(sids)
@@ -235,7 +253,6 @@ def head_data_forward(model, docs, is_train):
                     heads.append(span[0].i)
         heads = model.ops.asarray2i(heads)
         head_ids.append(heads)
-
     # each of these is a list with one entry per doc
     # backprop is just a placeholder
     # TODO it would probably be better to have a list of tuples than two lists of arrays
@@ -256,7 +273,6 @@ class CorefScorer(torch.nn.Module):
     """
     def __init__(
         self,
-        device: str,
         dim: int,  # tok2vec size
         dist_emb_size: int,
         hidden_size: int,
@@ -273,8 +289,7 @@ class CorefScorer(torch.nn.Module):
            epochs_trained (int): the number of epochs finished
                (useful for warm start)
         """
-        # device, dist_emb_size, hidden_size, n_layers, dropout_rate
-        self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate).to(device)
+        self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate)
         #TODO clean this up
         bert_emb = dim
         pair_emb = bert_emb * 3 + self.pw.shape
@@ -283,7 +298,7 @@ class CorefScorer(torch.nn.Module):
             hidden_size,
             n_layers,
             dropout_rate
-        ).to(device)
+        )
         self.lstm = torch.nn.LSTM(
             input_size=bert_emb,
             hidden_size=bert_emb,
@@ -294,7 +309,7 @@ class CorefScorer(torch.nn.Module):
             bert_emb,
             dropout_rate,
             roughk
-        ).to(device)
+        )
         self.batch_size = batch_size

     def forward(
@@ -443,7 +458,6 @@ class AnaphoricityScorer(torch.nn.Module):
         return out


-
 class RoughScorer(torch.nn.Module):
     """
     Is needed to give a rough estimate of the anaphoricity of two candidates,
@@ -474,7 +488,6 @@ class RoughScorer(torch.nn.Module):
         pair_mask = torch.arange(mentions.shape[0])
         pair_mask = pair_mask.unsqueeze(1) - pair_mask.unsqueeze(0)
         pair_mask = torch.log((pair_mask > 0).to(torch.float))
-        pair_mask = pair_mask.to(mentions.device)
         bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T)
         rough_scores = pair_mask + bilinear_scores
@@ -501,7 +514,7 @@ class RoughScorer(torch.nn.Module):


 class SpanPredictor(torch.nn.Module):
-    def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int, device):
+    def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int):
         super().__init__()
         # input size = single token size
         # 64 = probably distance emb size
@@ -517,7 +530,6 @@ class SpanPredictor(torch.nn.Module):
             # this use of dist_emb_size looks wrong but it was 64...?
             torch.nn.Linear(256, dist_emb_size),
         )
-        self.device = device
         self.conv = torch.nn.Sequential(
             torch.nn.Conv1d(64, 4, 3, 1, 1),
             torch.nn.Conv1d(4, 2, 3, 1, 1)
         )
@@ -541,17 +553,18 @@ class SpanPredictor(torch.nn.Module):
         Returns:
             torch.Tensor: span start/end scores, [n_heads, n_words, 2]
         """
+        # If we don't receive heads, return empty
+        if heads_ids.nelement() == 0:
+            return torch.empty(size=(0,))
         # Obtain distance embedding indices, [n_heads, n_words]
-        relative_positions = (heads_ids.unsqueeze(1) - torch.arange(words.shape[0], device=words.device).unsqueeze(0))
+        relative_positions = (heads_ids.unsqueeze(1) - torch.arange(words.shape[0]).unsqueeze(0))
         # make all valid distances positive
         emb_ids = relative_positions + 63
         # "too_far"
         emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127
         # Obtain "same sentence" boolean mask, [n_heads, n_words]
-        sent_id = torch.tensor(sent_id, device=words.device)
         heads_ids = heads_ids.long()
         same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0))
-
         # To save memory, only pass candidates from one sentence for each head
         # pair_matrix contains concatenated span_head_emb + candidate_emb + distance_emb
         # for each candidate among the words in the same sentence as span_head
@@ -562,23 +575,20 @@ class SpanPredictor(torch.nn.Module):
             words[cols],
             self.emb(emb_ids[rows, cols]),
         ), dim=1)
-
         lengths = same_sent.sum(dim=1)
-        padding_mask = torch.arange(0, lengths.max().item(), device=words.device).unsqueeze(0)
+        padding_mask = torch.arange(0, lengths.max().item()).unsqueeze(0)
         padding_mask = (padding_mask < lengths.unsqueeze(1))  # [n_heads, max_sent_len]
-
         # [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
         # This is necessary to allow the convolution layer to look at several
         # word scores
-        padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1], device=words.device)
+        padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1])
         padded_pairs[padding_mask] = pair_matrix

         res = self.ffnn(padded_pairs)  # [n_heads, n_candidates, last_layer_output]
         res = self.conv(res.permute(0, 2, 1)).permute(0, 2, 1)  # [n_heads, n_candidates, 2]

-        scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float('-inf'), device=words.device)
+        scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float('-inf'))
         scores[rows, cols] = res[padding_mask]
-
         # Make sure that start <= head <= end during inference
         if not self.training:
             valid_starts = torch.log((relative_positions >= 0).to(torch.float))
@@ -586,6 +596,7 @@ class SpanPredictor(torch.nn.Module):
             valid_positions = torch.stack((valid_starts, valid_ends), dim=2)
             return scores + valid_positions
         return scores
+

 class DistancePairwiseEncoder(torch.nn.Module):
     def __init__(self, embedding_size, dropout_rate):
@@ -595,17 +606,10 @@ class DistancePairwiseEncoder(torch.nn.Module):
         self.dropout = torch.nn.Dropout(dropout_rate)
         self.shape = emb_size

-    @property
-    def device(self) -> torch.device:
-        """ A workaround to get current device (which is assumed to be the
-        device of the first parameter of one of the submodules) """
-        return next(self.distance_emb.parameters()).device
-
-
     def forward(self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
                 top_indices: torch.Tensor
                 ) -> torch.Tensor:
-        word_ids = torch.arange(0, top_indices.size(0), device=self.device)
+        word_ids = torch.arange(0, top_indices.size(0))
         distance = (word_ids.unsqueeze(1) - word_ids[top_indices]
                     ).clamp_min_(min=1)
         log_distance = distance.to(torch.float).log2().floor_()
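The model changes above establish an empty-heads contract between the thinc converter and the torch module: convert_span_predictor_inputs hands over torch.empty(size=(0,)) when a doc has no head ids, and SpanPredictor.forward short-circuits on it instead of building distance embeddings. A minimal sketch of that contract, with hypothetical toy values (nothing below is part of the diff):

    import torch

    heads_ids = torch.empty(size=(0,))  # what the converter passes for a head-less doc

    # mirrors the `heads_ids.nelement() == 0` guard added to SpanPredictor.forward()
    if heads_ids.nelement() == 0:
        span_scores = torch.empty(size=(0,))  # callers test `span_scores.size` before using it
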
diff --git a/spacy/pipeline/coref.py b/spacy/pipeline/coref.py
index 54e9d8cfd..fc04d1a3e 100644
--- a/spacy/pipeline/coref.py
+++ b/spacy/pipeline/coref.py
@@ -3,7 +3,7 @@ import warnings

 from thinc.types import Floats2d, Floats3d, Ints2d
 from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
-from thinc.api import set_dropout_rate
+from thinc.api import set_dropout_rate, to_categorical
 from itertools import islice
 from statistics import mean
@@ -130,7 +130,6 @@ class CoreferenceResolver(TrainablePipe):

         DOCS: https://spacy.io/api/coref#predict (TODO)
         """
-        #print("DOCS", docs)
         out = []
         for doc in docs:
             scores, idxs = self.model.predict([doc])
@@ -212,7 +211,6 @@ class CoreferenceResolver(TrainablePipe):
             # TODO check this causes no issues (in practice it runs)
             preds, backprop = self.model.begin_update([eg.predicted])
             score_matrix, mention_idx = preds
-
             loss, d_scores = self.get_loss([eg], score_matrix, mention_idx)
             total_loss += loss
             # TODO check shape here
@@ -366,9 +364,7 @@ class CoreferenceResolver(TrainablePipe):
             for ex in examples:
                 p_clusters = doc2clusters(ex.predicted, self.span_cluster_prefix)
                 g_clusters = doc2clusters(ex.reference, self.span_cluster_prefix)
-
                 cluster_info = get_cluster_info(p_clusters, g_clusters)
-
                 evaluator.update(cluster_info)

             score = {
@@ -460,27 +456,29 @@ class SpanPredictor(TrainablePipe):
         out = []
         for doc in docs:
             # TODO check shape here
-            span_scores = self.model.predict(doc)
-            span_scores = span_scores[0]
-            # the information about clustering has to come from the input docs
-            # first let's convert the scores to a list of span idxs
-            start_scores = span_scores[:, :, 0]
-            end_scores = span_scores[:, :, 1]
-            starts = start_scores.argmax(axis=1)
-            ends = end_scores.argmax(axis=1)
+            span_scores = self.model.predict([doc])
+            if span_scores.size:
+                # the information about clustering has to come from the input docs
+                # first let's convert the scores to a list of span idxs
+                start_scores = span_scores[:, :, 0]
+                end_scores = span_scores[:, :, 1]
+                starts = start_scores.argmax(axis=1)
+                ends = end_scores.argmax(axis=1)

-            # TODO check start < end
+                # TODO check start < end

-            # get the old clusters (shape will be preserved)
-            clusters = doc2clusters(doc, self.input_prefix)
-            cidx = 0
-            out_clusters = []
-            for cluster in clusters:
-                ncluster = []
-                for mention in cluster:
-                    ncluster.append( (starts[cidx], ends[cidx]) )
-                    cidx += 1
-                out_clusters.append(ncluster)
+                # get the old clusters (shape will be preserved)
+                clusters = doc2clusters(doc, self.input_prefix)
+                cidx = 0
+                out_clusters = []
+                for cluster in clusters:
+                    ncluster = []
+                    for mention in cluster:
+                        ncluster.append((starts[cidx], ends[cidx]))
+                        cidx += 1
+                    out_clusters.append(ncluster)
+            else:
+                out_clusters = []
             out.append(out_clusters)
         return out
@@ -505,21 +503,21 @@ class SpanPredictor(TrainablePipe):
         losses = {}
         losses.setdefault(self.name, 0.0)
         validate_examples(examples, "SpanPredictor.update")
-        if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
+        if not any(len(eg.reference) if eg.reference else 0 for eg in examples):
             # Handle cases where there are no tokens in any docs.
             return losses
         set_dropout_rate(self.model, drop)

         total_loss = 0
-
         for eg in examples:
-            preds, backprop = self.model.begin_update([eg.predicted])
-            score_matrix, mention_idx = preds
-
-            loss, d_scores = self.get_loss([eg], score_matrix, mention_idx)
-            total_loss += loss
-            # TODO check shape here
-            backprop((d_scores, mention_idx))
+            span_scores, backprop = self.model.begin_update([eg.predicted])
+            # FIXME, this only happens once in the first 1000 docs of OntoNotes
+            # and I'm not sure yet why.
+            if span_scores.size:
+                loss, d_scores = self.get_loss([eg], span_scores)
+                total_loss += loss
+                # TODO check shape here
+                backprop((d_scores))

         if sgd is not None:
             self.finish_update(sgd)
@@ -562,19 +560,17 @@ class SpanPredictor(TrainablePipe):
         assert len(examples) == 1, "Only fake batching is supported."
         # starts and ends are gold starts and ends (Ints1d)
         # span_scores is a Floats3d. What are the axes? mention x token x start/end
-
         for eg in examples:
-
-            # get gold data
-            gold = doc2clusters(eg.reference, self.output_prefix)
-            # flatten the gold data
             starts = []
             ends = []
-            for cluster in gold:
-                for mention in cluster:
-                    starts.append(mention[0])
-                    ends.append(mention[1])
+            for key, sg in eg.reference.spans.items():
+                if key.startswith(self.output_prefix):
+                    for mention in sg:
+                        starts.append(mention.start)
+                        ends.append(mention.end)
+            starts = self.model.ops.xp.asarray(starts)
+            ends = self.model.ops.xp.asarray(ends)
             start_scores = span_scores[:, :, 0]
             end_scores = span_scores[:, :, 1]
             n_classes = start_scores.shape[1]
@@ -594,7 +590,7 @@ class SpanPredictor(TrainablePipe):
         *,
         nlp: Optional[Language] = None,
     ) -> None:
-        validate_get_examples(get_examples, "CoreferenceResolver.initialize")
+        validate_get_examples(get_examples, "SpanPredictor.initialize")

         X = []
         Y = []
@@ -612,31 +608,33 @@ class SpanPredictor(TrainablePipe):
         self.model.initialize(X=X, Y=Y)

     def score(self, examples, **kwargs):
-        """Score a batch of examples."""
-        # TODO This is basically the same as the main coref component - factor out?
-
+        """
+        Evaluate on reconstructing the correct spans around
+        gold heads.
+        """
         scores = []
-        for metric in (b_cubed, muc, ceafe):
-            evaluator = Evaluator(metric)
+        xp = self.model.ops.xp
+        for eg in examples:
+            starts = []
+            ends = []
+            pred_starts = []
+            pred_ends = []
+            ref = eg.reference
+            pred = eg.predicted
+            for key, gold_sg in ref.spans.items():
+                if key.startswith(self.output_prefix):
+                    pred_sg = pred.spans[key]
+                    for gold_mention, pred_mention in zip(gold_sg, pred_sg):
+                        starts.append(gold_mention.start)
+                        ends.append(gold_mention.end)
+                        pred_starts.append(pred_mention.start)
+                        pred_ends.append(pred_mention.end)

-            for ex in examples:
-                # XXX this is the only different part
-                p_clusters = doc2clusters(ex.predicted, self.output_prefix)
-                g_clusters = doc2clusters(ex.reference, self.output_prefix)
-
-                cluster_info = get_cluster_info(p_clusters, g_clusters)
-
-                evaluator.update(cluster_info)
-
-            score = {
-                "coref_f": evaluator.get_f1(),
-                "coref_p": evaluator.get_precision(),
-                "coref_r": evaluator.get_recall(),
-            }
-            scores.append(score)
-
-        out = {}
-        for field in ("f", "p", "r"):
-            fname = f"coref_{field}"
-            out[fname] = mean([ss[fname] for ss in scores])
-        return out
+            starts = xp.asarray(starts)
+            ends = xp.asarray(ends)
+            pred_starts = xp.asarray(pred_starts)
+            pred_ends = xp.asarray(pred_ends)
+            correct = (starts == pred_starts) * (ends == pred_ends)
+            accuracy = correct.mean()
+            scores.append(float(accuracy))
+        return {"span_accuracy": mean(scores)}
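The rewritten score() above reduces to an exact-match accuracy over span boundaries: a predicted span only counts as correct when both its start and its end match the gold mention for the same head. A minimal sketch of that arithmetic on toy arrays (hypothetical values; numpy stands in for model.ops.xp):

    import numpy as np

    starts = np.asarray([0, 4, 9])       # gold span starts
    ends = np.asarray([2, 6, 11])        # gold span ends
    pred_starts = np.asarray([0, 5, 9])  # predicted starts
    pred_ends = np.asarray([2, 6, 11])   # predicted ends

    # a span is correct only when both boundaries match
    correct = (starts == pred_starts) * (ends == pred_ends)
    print(correct.mean())  # 0.666..., averaged across examples as "span_accuracy"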