mirror of https://github.com/explosion/spaCy.git
commit 57165f9631
				|  | @ -1,14 +1,14 @@ | |||
| from typing import List, Tuple | ||||
| import torch | ||||
| 
 | ||||
| from thinc.api import Model, chain, tuplify | ||||
| from thinc.api import Model, chain | ||||
| from thinc.api import PyTorchWrapper, ArgsKwargs | ||||
| from thinc.types import Floats2d, Ints1d, Ints2d | ||||
| from thinc.types import Floats2d, Ints2d, Ints1d | ||||
| from thinc.util import xp2torch, torch2xp | ||||
| 
 | ||||
| from ...tokens import Doc | ||||
| from ...util import registry | ||||
| from .coref_util import add_dummy, get_sentence_ids | ||||
| from .coref_util import add_dummy | ||||
| 
 | ||||
| 
 | ||||
| @registry.architectures("spacy.Coref.v1") | ||||
|  | @ -19,7 +19,6 @@ def build_wl_coref_model( | |||
|     n_hidden_layers: int = 1,  # TODO rename to "depth"? | ||||
|     dropout: float = 0.3, | ||||
|     # pairs to keep per mention after rough scoring | ||||
|     # TODO change to meaningful name | ||||
|     rough_k: int = 50, | ||||
|     # TODO is this not a training loop setting? | ||||
|     a_scoring_batch_size: int = 512, | ||||
|  | @ -34,7 +33,6 @@ def build_wl_coref_model( | |||
|         dim = 768 | ||||
| 
 | ||||
|     with Model.define_operators({">>": chain}): | ||||
|         # TODO chain tok2vec with these models | ||||
|         coref_scorer = PyTorchWrapper( | ||||
|             CorefScorer( | ||||
|                 dim, | ||||
|  | @ -49,22 +47,14 @@ def build_wl_coref_model( | |||
|             convert_outputs=convert_coref_scorer_outputs, | ||||
|         ) | ||||
|         coref_model = tok2vec >> coref_scorer | ||||
|         # XXX just ignore this until the coref scorer is integrated | ||||
|         # span_predictor = PyTorchWrapper( | ||||
|         #    SpanPredictor( | ||||
|         # TODO this was hardcoded to 1024, check | ||||
|         #        hidden_size, | ||||
|         #        sp_embedding_size, | ||||
|         #    ), | ||||
|         #    convert_inputs=convert_span_predictor_inputs | ||||
|         # ) | ||||
|     # TODO combine models so output is uniform (just one forward pass) | ||||
|     # It may be reasonable to have an option to disable span prediction, | ||||
|     # and just return words as spans. | ||||
|     return coref_model | ||||
| 
 | ||||
| 
 | ||||
| def convert_coref_scorer_inputs(model: Model, X: List[Floats2d], is_train: bool): | ||||
| def convert_coref_scorer_inputs( | ||||
|     model: Model, | ||||
|     X: List[Floats2d], | ||||
|     is_train: bool | ||||
| ): | ||||
|     # The input here is List[Floats2d], one for each doc | ||||
|     # just use the first | ||||
|     # TODO real batching | ||||
|  | @ -76,7 +66,7 @@ def convert_coref_scorer_inputs(model: Model, X: List[Floats2d], is_train: bool) | |||
|         gradients = torch2xp(args.args[0]) | ||||
|         return [gradients] | ||||
| 
 | ||||
|     return ArgsKwargs(args=(word_features,), kwargs={}), backprop | ||||
|     return ArgsKwargs(args=(word_features, ), kwargs={}), backprop | ||||
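
Review note: for readers unfamiliar with Thinc's PyTorchWrapper, the converter pair used above follows a fixed contract: convert_inputs maps Thinc arrays to an ArgsKwargs of torch tensors and returns a callback that maps the torch gradients back, while convert_outputs does the same in reverse for the wrapped module's outputs. A minimal, generic sketch of that contract (illustrative single-tensor case, not the spaCy implementation):

    from thinc.api import ArgsKwargs, Model
    from thinc.util import torch2xp, xp2torch

    def convert_inputs(model: Model, X, is_train: bool):
        # Forward: hand a torch tensor to the wrapped module.
        X_torch = xp2torch(X, requires_grad=is_train)

        def backprop(args: ArgsKwargs):
            # args mirrors the wrapped call's inputs, with each tensor
            # replaced by its gradient; convert back for Thinc.
            return torch2xp(args.args[0])

        return ArgsKwargs(args=(X_torch,), kwargs={}), backprop

    def convert_outputs(model: Model, inputs_outputs, is_train: bool):
        _, Y_torch = inputs_outputs

        def backprop(dY):
            # Wrap the upstream gradient so the shim can call
            # torch.autograd.backward(Y_torch, grad_tensors=...).
            return ArgsKwargs(args=(Y_torch,), kwargs={"grad_tensors": xp2torch(dY)})

        return torch2xp(Y_torch), backprop
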
| 
 | ||||
| 
 | ||||
| def convert_coref_scorer_outputs(model: Model, inputs_outputs, is_train: bool): | ||||
|  | @ -95,46 +85,13 @@ def convert_coref_scorer_outputs(model: Model, inputs_outputs, is_train: bool): | |||
|     return (scores_xp, indices_xp), convert_for_torch_backward | ||||
| 
 | ||||
| 
 | ||||
| # TODO add docstring for this, maybe move to utils. | ||||
| # This might belong in the component. | ||||
| def _clusterize(model, scores: Floats2d, top_indices: Ints2d): | ||||
|     xp = model.ops.xp | ||||
|     antecedents = scores.argmax(axis=1) - 1 | ||||
|     not_dummy = antecedents >= 0 | ||||
|     coref_span_heads = xp.arange(0, len(scores))[not_dummy] | ||||
|     antecedents = top_indices[coref_span_heads, antecedents[not_dummy]] | ||||
|     n_words = scores.shape[0] | ||||
|     nodes = [GraphNode(i) for i in range(n_words)] | ||||
|     for i, j in zip(coref_span_heads.tolist(), antecedents.tolist()): | ||||
|         nodes[i].link(nodes[j]) | ||||
|         assert nodes[i] is not nodes[j] | ||||
| 
 | ||||
|     clusters = [] | ||||
|     for node in nodes: | ||||
|         if len(node.links) > 0 and not node.visited: | ||||
|             cluster = [] | ||||
|             stack = [node] | ||||
|             while stack: | ||||
|                 current_node = stack.pop() | ||||
|                 current_node.visited = True | ||||
|                 cluster.append(current_node.id) | ||||
|                 stack.extend(link for link in current_node.links if not link.visited) | ||||
|             assert len(cluster) > 1 | ||||
|             clusters.append(sorted(cluster)) | ||||
|     return sorted(clusters) | ||||
| 
 | ||||
| 
 | ||||
| class CorefScorer(torch.nn.Module): | ||||
|     """Combines all coref modules together to find coreferent spans. | ||||
| 
 | ||||
|     Attributes: | ||||
|         epochs_trained (int): number of epochs the model has been trained for | ||||
| 
 | ||||
|     """ | ||||
|     Combines all coref modules together to find coreferent token pairs. | ||||
|     Submodules (in the order of their usage in the pipeline): | ||||
|         rough_scorer (RoughScorer) | ||||
|         pw (PairwiseEncoder) | ||||
|         a_scorer (AnaphoricityScorer) | ||||
|         sp (SpanPredictor) | ||||
|         - rough_scorer (RoughScorer) that prunes candidate pairs | ||||
|         - pw (DistancePairwiseEncoder) that computes pairwise features | ||||
|         - a_scorer (AnaphoricityScorer) that produces the final scores | ||||
|     """ | ||||
| 
 | ||||
|     def __init__( | ||||
|  | @ -149,50 +106,64 @@ class CorefScorer(torch.nn.Module): | |||
|     ): | ||||
|         super().__init__() | ||||
|         """ | ||||
|         A newly created model is set to evaluation mode. | ||||
| 
 | ||||
|         Args: | ||||
|             epochs_trained (int): the number of epochs finished | ||||
|                 (useful for warm start) | ||||
|         dim: Size of the input features. | ||||
|         dist_emb_size: Size of the distance embeddings. | ||||
|         hidden_size: Size of the coreference candidate embeddings. | ||||
|         n_layers: Number of layers in the AnaphoricityScorer. | ||||
|         dropout_rate: Dropout probability to apply across all modules. | ||||
|         roughk: Number of candidates the RoughScorer returns. | ||||
|         batch_size: Internal batch-size for the more expensive scorer. | ||||
|         """ | ||||
|         self.dropout = torch.nn.Dropout(dropout_rate) | ||||
|         self.batch_size = batch_size | ||||
|         # Modules | ||||
|         self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate) | ||||
|         # TODO clean this up | ||||
|         bert_emb = dim | ||||
|         pair_emb = bert_emb * 3 + self.pw.shape | ||||
|         pair_emb = dim * 3 + self.pw.shape | ||||
|         self.a_scorer = AnaphoricityScorer( | ||||
|             pair_emb, | ||||
|             hidden_size, | ||||
|             n_layers, | ||||
|             dropout_rate | ||||
|         ) | ||||
|         self.lstm = torch.nn.LSTM( | ||||
|             input_size=dim, | ||||
|             hidden_size=dim, | ||||
|             batch_first=True, | ||||
|         ) | ||||
|         self.rough_scorer = RoughScorer(dim, dropout_rate, roughk) | ||||
|         self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate) | ||||
|         pair_emb = dim * 3 + self.pw.shape | ||||
|         self.a_scorer = AnaphoricityScorer( | ||||
|             pair_emb, hidden_size, n_layers, dropout_rate | ||||
|         ) | ||||
|         self.lstm = torch.nn.LSTM( | ||||
|             input_size=bert_emb, | ||||
|             hidden_size=bert_emb, | ||||
|             batch_first=True, | ||||
|         ) | ||||
|         self.dropout = torch.nn.Dropout(dropout_rate) | ||||
|         self.rough_scorer = RoughScorer(bert_emb, dropout_rate, roughk) | ||||
|         self.batch_size = batch_size | ||||
| 
 | ||||
|     def forward(self, word_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | ||||
|     def forward( | ||||
|         self, word_features: torch.Tensor | ||||
|     ) -> Tuple[torch.Tensor, torch.Tensor]: | ||||
|         """ | ||||
|         This is a massive method, but it made sense to me to not split it into | ||||
|         several ones to let one see the data flow. | ||||
|         1. LSTM encodes the incoming word_features. | ||||
|         2. The RoughScorer scores and prunes the candidates. | ||||
|         3. The DistancePairwiseEncoder embeds the distances between pairs. | ||||
|         4. The AnaphoricityScorer scores all pairs in mini-batches. | ||||
| 
 | ||||
|         Args: | ||||
|             word_features: torch.Tensor containing word encodings | ||||
|         Returns: | ||||
|             coreference scores and top indices | ||||
|         word_features: torch.Tensor containing word encodings | ||||
| 
 | ||||
|         returns: | ||||
|             coref_scores: n_words x roughk floats. | ||||
|             top_indices: n_words x roughk integers. | ||||
|         """ | ||||
|         # words           [n_words, span_emb] | ||||
|         # cluster_ids     [n_words] | ||||
|         self.lstm.flatten_parameters()  # XXX without this there's a warning | ||||
|         word_features = torch.unsqueeze(word_features, dim=0) | ||||
|         words, _ = self.lstm(word_features) | ||||
|         words = words.squeeze() | ||||
|         # words: n_words x dim | ||||
|         words = self.dropout(words) | ||||
|         # Obtain bilinear scores and leave only top-k antecedents for each word | ||||
|         # top_rough_scores  [n_words, n_ants] | ||||
|         # top_indices       [n_words, n_ants] | ||||
|         # top_rough_scores: (n_words x roughk) | ||||
|         # top_indices: (n_words x roughk) | ||||
|         top_rough_scores, top_indices = self.rough_scorer(words) | ||||
|         # Get pairwise features [n_words, n_ants, n_pw_features] | ||||
|         # Get pairwise features | ||||
|         # (n_words x roughk x n_pw_features) | ||||
|         pw = self.pw(top_indices) | ||||
|         batch_size = self.batch_size | ||||
|         a_scores_lst: List[torch.Tensor] = [] | ||||
|  | @ -272,13 +243,8 @@ class AnaphoricityScorer(torch.nn.Module): | |||
| 
 | ||||
|     def _ffnn(self, x: torch.Tensor) -> torch.Tensor: | ||||
|         """ | ||||
|         Calculates anaphoricity scores. | ||||
| 
 | ||||
|         Args: | ||||
|             x: tensor of shape [batch_size, n_ants, n_features] | ||||
| 
 | ||||
|         Returns: | ||||
|             tensor of shape [batch_size, n_ants] | ||||
|         x: tensor of shape (batch_size x roughk x n_features) | ||||
|         returns: tensor of shape (batch_size x roughk) | ||||
|         """ | ||||
|         x = self.out(self.hidden(x)) | ||||
|         return x.squeeze(2) | ||||
|  | @ -293,21 +259,18 @@ class AnaphoricityScorer(torch.nn.Module): | |||
|         """ | ||||
|         Builds the matrix used as input for AnaphoricityScorer. | ||||
| 
 | ||||
|         Args: | ||||
|             all_mentions (torch.Tensor): [n_mentions, mention_emb], | ||||
|                 all the valid mentions of the document, | ||||
|                 can be on a different device | ||||
|             mentions_batch (torch.Tensor): [batch_size, mention_emb], | ||||
|                 the mentions of the current batch, | ||||
|                 is expected to be on the current device | ||||
|             pw_batch (torch.Tensor): [batch_size, n_ants, pw_emb], | ||||
|                 pairwise features of the current batch, | ||||
|                 is expected to be on the current device | ||||
|             top_indices_batch (torch.Tensor): [batch_size, n_ants], | ||||
|                 indices of antecedents of each mention | ||||
|         all_mentions: (n_mentions x mention_emb), | ||||
|             all the valid mentions of the document, | ||||
|             can be on a different device | ||||
|         mentions_batch: (batch_size x mention_emb), | ||||
|             the mentions of the current batch. | ||||
|         pw_batch: (batch_size x roughk x pw_emb), | ||||
|             pairwise distance features of the current batch. | ||||
|         top_indices_batch: (batch_size x n_ants), | ||||
|             indices of antecedents of each mention | ||||
| 
 | ||||
|         Returns: | ||||
|             torch.Tensor: [batch_size, n_ants, pair_emb] | ||||
|             out: pairwise features (batch_size x n_ants x pair_emb) | ||||
|         """ | ||||
|         emb_size = mentions_batch.shape[1] | ||||
|         n_ants = pw_batch.shape[1] | ||||
|  | @ -322,16 +285,15 @@ class AnaphoricityScorer(torch.nn.Module): | |||
| 
 | ||||
| class RoughScorer(torch.nn.Module): | ||||
|     """ | ||||
|     Is needed to give a roughly estimate of the anaphoricity of two candidates, | ||||
|     only top scoring candidates are considered on later steps to reduce | ||||
|     computational complexity. | ||||
|     Cheaper module that gives a rough estimate of the anaphoricity of two | ||||
|     candidates; only the top-scoring candidates are considered in later | ||||
|     steps to reduce computational cost. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, features: int, dropout_rate: float, rough_k: float): | ||||
|         super().__init__() | ||||
|         self.dropout = torch.nn.Dropout(dropout_rate) | ||||
|         self.bilinear = torch.nn.Linear(features, features) | ||||
| 
 | ||||
|         self.k = rough_k | ||||
| 
 | ||||
|     def forward( | ||||
|  | @ -348,29 +310,27 @@ class RoughScorer(torch.nn.Module): | |||
|         pair_mask = torch.log((pair_mask > 0).to(torch.float)) | ||||
|         bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T) | ||||
|         rough_scores = pair_mask + bilinear_scores | ||||
| 
 | ||||
|         return self._prune(rough_scores) | ||||
| 
 | ||||
|     def _prune(self, rough_scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | ||||
|         """ | ||||
|         Selects top-k rough antecedent scores for each mention. | ||||
| 
 | ||||
|         Args: | ||||
|             rough_scores: tensor of shape [n_mentions, n_mentions], containing | ||||
|                 rough antecedent scores of each mention-antecedent pair. | ||||
| 
 | ||||
|         Returns: | ||||
|             FloatTensor of shape [n_mentions, k], top rough scores | ||||
|             LongTensor of shape [n_mentions, k], top indices | ||||
|         """ | ||||
|         top_scores, indices = torch.topk( | ||||
|             rough_scores, k=min(self.k, len(rough_scores)), dim=1, sorted=False | ||||
|         ) | ||||
| 
 | ||||
|         return top_scores, indices | ||||
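
Review note: the masking trick in RoughScorer.forward is easy to miss: pair_mask is 1 for antecedent candidates that appear strictly before the mention and 0 otherwise, so taking torch.log turns it into 0 / -inf and silently removes invalid pairs from the top-k selection. A standalone sketch of the same idea with made-up sizes (illustrative, not the spaCy code):

    import torch

    n_words, k = 6, 3
    scores = torch.randn(n_words, n_words)                   # bilinear pair scores
    idx = torch.arange(n_words)
    pair_mask = idx.unsqueeze(1) - idx.unsqueeze(0)          # i - j
    pair_mask = torch.log((pair_mask > 0).to(torch.float))   # 0 where j < i, -inf elsewhere
    rough_scores = scores + pair_mask
    top_scores, top_indices = torch.topk(
        rough_scores, k=min(k, n_words), dim=1, sorted=False
    )
    # top_indices[i] holds k candidate antecedent positions for word i;
    # words with no valid antecedent (the first word) keep only -inf scores.
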
| 
 | ||||
| 
 | ||||
| class DistancePairwiseEncoder(torch.nn.Module): | ||||
|     def __init__(self, embedding_size, dropout_rate): | ||||
|         """ | ||||
|         Takes top_indices, which is a ranked list of the most likely | ||||
|         anaphora candidates for each word. For each of these pairs it | ||||
|         looks up a distance embedding from a table, where the distance | ||||
|         corresponds to the log-distance. | ||||
| 
 | ||||
|         embedding_size: int, | ||||
|             Dimensionality of the distance-embeddings table. | ||||
|         dropout_rate: float, | ||||
|             Dropout probability. | ||||
|         """ | ||||
|         super().__init__() | ||||
|         emb_size = embedding_size | ||||
|         self.distance_emb = torch.nn.Embedding(9, emb_size) | ||||
|  | @ -378,11 +338,12 @@ class DistancePairwiseEncoder(torch.nn.Module): | |||
|         self.shape = emb_size | ||||
| 
 | ||||
|     def forward( | ||||
|         self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch | ||||
|         top_indices: torch.Tensor, | ||||
|         self, | ||||
|         top_indices: torch.Tensor | ||||
|     ) -> torch.Tensor: | ||||
|         word_ids = torch.arange(0, top_indices.size(0)) | ||||
|         distance = (word_ids.unsqueeze(1) - word_ids[top_indices]).clamp_min_(min=1) | ||||
|         distance = (word_ids.unsqueeze(1) - word_ids[top_indices] | ||||
|                     ).clamp_min_(min=1) | ||||
|         log_distance = distance.to(torch.float).log2().floor_() | ||||
|         log_distance = log_distance.clamp_max_(max=6).to(torch.long) | ||||
|         distance = torch.where(distance < 5, distance - 1, log_distance + 2) | ||||
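
Review note: the forward pass above buckets raw token distances into the 9 bins expected by torch.nn.Embedding(9, emb_size): exact bins for distances 1-4, then log2-sized bins, capped at bin 8. A small sketch of just that bucketing, using the same constants (illustrative helper name):

    import torch

    def distance_bucket(distance: torch.Tensor) -> torch.Tensor:
        # distance >= 1; exact buckets for 1-4, log2 buckets afterwards, capped at 8.
        distance = distance.clamp_min(1)
        log_distance = distance.to(torch.float).log2().floor_()
        log_distance = log_distance.clamp_max_(max=6).to(torch.long)
        return torch.where(distance < 5, distance - 1, log_distance + 2)

    d = torch.tensor([1, 2, 3, 4, 5, 8, 16, 64, 1000])
    print(distance_bucket(d))  # tensor([0, 1, 2, 3, 4, 5, 6, 8, 8])
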
|  |  | |||
|  | @ -3,7 +3,7 @@ import torch | |||
| 
 | ||||
| from thinc.api import Model, chain, tuplify | ||||
| from thinc.api import PyTorchWrapper, ArgsKwargs | ||||
| from thinc.types import Floats2d, Ints1d, Ints2d | ||||
| from thinc.types import Floats2d, Ints1d | ||||
| from thinc.util import xp2torch, torch2xp | ||||
| 
 | ||||
| from ...tokens import Doc | ||||
|  | @ -40,10 +40,9 @@ def convert_span_predictor_inputs( | |||
|     model: Model, X: Tuple[Ints1d, Floats2d, Ints1d], is_train: bool | ||||
| ): | ||||
|     tok2vec, (sent_ids, head_ids) = X | ||||
|     # Normally we shoudl use the input is_train, but for these two it's not relevant | ||||
|     # Normally we should use the input is_train, but for these two it's not relevant | ||||
| 
 | ||||
|     def backprop(args: ArgsKwargs) -> List[Floats2d]: | ||||
|         # convert to xp and wrap in list | ||||
|         gradients = torch2xp(args.args[1]) | ||||
|         return [[gradients], None] | ||||
| 
 | ||||
|  | @ -55,7 +54,6 @@ def convert_span_predictor_inputs( | |||
|         head_ids = xp2torch(head_ids[0], requires_grad=False) | ||||
| 
 | ||||
|     argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={}) | ||||
|     # TODO actually support backprop | ||||
|     return argskwargs, backprop | ||||
| 
 | ||||
| 
 | ||||
|  | @ -66,15 +64,13 @@ def predict_span_clusters( | |||
|     """ | ||||
|     Predicts span clusters based on the word clusters. | ||||
| 
 | ||||
|     Args: | ||||
|         doc (Doc): the document data | ||||
|         words (torch.Tensor): [n_words, emb_size] matrix containing | ||||
|             embeddings for each of the words in the text | ||||
|         clusters (List[List[int]]): a list of clusters where each cluster | ||||
|             is a list of word indices | ||||
|     span_predictor: a SpanPredictor instance | ||||
|     sent_ids: For each word, indicates which sentence it appears in. | ||||
|     words: Features for words. | ||||
|     clusters: Clusters inferred by the CorefScorer. | ||||
| 
 | ||||
|     Returns: | ||||
|         List[List[Span]]: span clusters | ||||
|         List[List[Tuple[int, int]]]: span clusters | ||||
|     """ | ||||
|     if not clusters: | ||||
|         return [] | ||||
|  | @ -141,29 +137,29 @@ class SpanPredictor(torch.nn.Module): | |||
|             # this use of dist_emb_size looks wrong but it was 64...? | ||||
|             torch.nn.Linear(256, dist_emb_size), | ||||
|         ) | ||||
|         # TODO make the Convs also parametrizeable | ||||
|         self.conv = torch.nn.Sequential( | ||||
|             torch.nn.Conv1d(64, 4, 3, 1, 1), torch.nn.Conv1d(4, 2, 3, 1, 1) | ||||
|         ) | ||||
|         # TODO make embeddings size a parameter | ||||
|         self.emb = torch.nn.Embedding(128, dist_emb_size)  # [-63, 63] + too_far | ||||
| 
 | ||||
|     def forward( | ||||
|         self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch | ||||
|         self, | ||||
|         sent_id, | ||||
|         words: torch.Tensor, | ||||
|         heads_ids: torch.Tensor, | ||||
|     ) -> torch.Tensor: | ||||
|         """ | ||||
|         Calculates span start/end scores of words for each span head in | ||||
|         heads_ids | ||||
|         Calculates span start/end scores of words for each span head. | ||||
| 
 | ||||
|         Args: | ||||
|             doc (Doc): the document data | ||||
|             words (torch.Tensor): contextual embeddings for each word in the | ||||
|                 document, [n_words, emb_size] | ||||
|             heads_ids (torch.Tensor): word indices of span heads | ||||
|         sent_id: Sentence id of each word. | ||||
|         words: Features for each word in the document. | ||||
|         heads_ids: Word indices of span heads. | ||||
| 
 | ||||
|         Returns: | ||||
|             torch.Tensor: span start/end scores, [n_heads, n_words, 2] | ||||
|             torch.Tensor: span start/end scores, (n_heads x n_words x 2) | ||||
|         """ | ||||
|         # If we don't receive heads, return empty | ||||
|         if heads_ids.nelement() == 0: | ||||
|  | @ -176,13 +172,13 @@ class SpanPredictor(torch.nn.Module): | |||
|         emb_ids = relative_positions + 63 | ||||
|         # "too_far" | ||||
|         emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127 | ||||
|         # Obtain "same sentence" boolean mask, [n_heads, n_words] | ||||
|         # Obtain "same sentence" boolean mask: (n_heads x n_words) | ||||
|         heads_ids = heads_ids.long() | ||||
|         same_sent = sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0) | ||||
|         # To save memory, only pass candidates from one sentence for each head | ||||
|         # pair_matrix contains concatenated span_head_emb + candidate_emb + distance_emb | ||||
|         # for each candidate among the words in the same sentence as span_head | ||||
|         # [n_heads, input_size * 2 + distance_emb_size] | ||||
|         # (n_heads x input_size * 2 + distance_emb_size) | ||||
|         rows, cols = same_sent.nonzero(as_tuple=True) | ||||
|         pair_matrix = torch.cat( | ||||
|             ( | ||||
|  | @ -194,17 +190,17 @@ class SpanPredictor(torch.nn.Module): | |||
|         ) | ||||
|         lengths = same_sent.sum(dim=1) | ||||
|         padding_mask = torch.arange(0, lengths.max().item()).unsqueeze(0) | ||||
|         padding_mask = padding_mask < lengths.unsqueeze(1)  # [n_heads, max_sent_len] | ||||
|         # [n_heads, max_sent_len, input_size * 2 + distance_emb_size] | ||||
|         # (n_heads x max_sent_len) | ||||
|         padding_mask = padding_mask < lengths.unsqueeze(1) | ||||
|         # (n_heads x max_sent_len x input_size * 2 + distance_emb_size) | ||||
|         # This is necessary to allow the convolution layer to look at several | ||||
|         # word scores | ||||
|         padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1]) | ||||
|         padded_pairs[padding_mask] = pair_matrix | ||||
| 
 | ||||
|         res = self.ffnn(padded_pairs)  # [n_heads, n_candidates, last_layer_output] | ||||
|         res = self.ffnn(padded_pairs)  # (n_heads x n_candidates x last_layer_output) | ||||
|         res = self.conv(res.permute(0, 2, 1)).permute( | ||||
|             0, 2, 1 | ||||
|         )  # [n_heads, n_candidates, 2] | ||||
|         )  # (n_heads x n_candidates x 2) | ||||
| 
 | ||||
|         scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float("-inf")) | ||||
|         scores[rows, cols] = res[padding_mask] | ||||
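
Review note: the distance embedding in SpanPredictor.forward uses a 128-entry table: relative offsets between a candidate word and its span head within [-63, 63] each get their own id (0..126), and anything farther shares a single "too_far" id (127), matching torch.nn.Embedding(128, dist_emb_size) in __init__. A standalone sketch of that bucketing (illustrative helper name):

    import torch

    def relative_position_bucket(relative_positions: torch.Tensor) -> torch.Tensor:
        # Offsets in [-63, 63] map to ids 0..126; everything else becomes 127.
        emb_ids = relative_positions + 63
        emb_ids[(emb_ids < 0) | (emb_ids > 126)] = 127
        return emb_ids

    rel = torch.tensor([-100, -63, 0, 63, 100])
    print(relative_position_bucket(rel))  # tensor([127,   0,  63, 126, 127])
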
|  |  | |||
|  | @ -350,9 +350,7 @@ class CoreferenceResolver(TrainablePipe): | |||
| 
 | ||||
|     def score(self, examples, **kwargs): | ||||
|         """Score a batch of examples using LEA. | ||||
| 
 | ||||
|         For details on how LEA works and why to use it see the paper: | ||||
| 
 | ||||
|         Which Coreference Evaluation Metric Do You Trust? A Proposal for a Link-based Entity Aware Metric | ||||
|         Moosavi and Strube, 2016 | ||||
|         https://api.semanticscholar.org/CorpusID:17606580 | ||||
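
Review note: since the docstring only names the metric, here is an illustrative, simplified LEA computation for reference (not spaCy's implementation; the paper's singleton self-links are omitted, and mentions are assumed to be hashable ids shared between key and response):

    from itertools import combinations
    from typing import List, Set

    def links(entity: Set) -> Set[frozenset]:
        # All unordered mention pairs (coreference links) inside an entity.
        return {frozenset(pair) for pair in combinations(entity, 2)}

    def lea_recall(key: List[Set], response: List[Set]) -> float:
        # Importance-weighted fraction of each key entity's links resolved by the response.
        num, denom = 0.0, 0.0
        for k in key:
            k_links = links(k)
            if not k_links:  # singletons skipped in this simplified sketch
                continue
            resolved = sum(len(k_links & links(r)) for r in response)
            num += len(k) * resolved / len(k_links)
            denom += len(k)
        return num / denom if denom else 0.0

    # Precision is the same computation with key and response swapped;
    # F1 is their harmonic mean.
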
|  |  | |||