diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py
index 4be22dd96..d59619498 100644
--- a/spacy/ml/models/coref.py
+++ b/spacy/ml/models/coref.py
@@ -46,11 +46,7 @@ def build_wl_coref_model(
     return coref_model
 
 
-def convert_coref_clusterer_inputs(
-    model: Model,
-    X: List[Floats2d],
-    is_train: bool
-):
+def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
     # The input here is List[Floats2d], one for each doc
     # just use the first
     # TODO real batching
@@ -62,7 +58,7 @@ def convert_coref_clusterer_inputs(
         gradients = torch2xp(args.args[0])
         return [gradients]
 
-    return ArgsKwargs(args=(word_features, ), kwargs={}), backprop
+    return ArgsKwargs(args=(word_features,), kwargs={}), backprop
 
 
 def convert_coref_clusterer_outputs(model: Model, inputs_outputs, is_train: bool):
@@ -115,12 +111,7 @@ class CorefClusterer(torch.nn.Module):
         # Modules
         self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)
         pair_emb = dim * 3 + self.pw.shape
-        self.a_scorer = AnaphoricityScorer(
-            pair_emb,
-            hidden_size,
-            n_layers,
-            dropout
-        )
+        self.a_scorer = AnaphoricityScorer(pair_emb, hidden_size, n_layers, dropout)
         self.lstm = torch.nn.LSTM(
             input_size=dim,
             hidden_size=dim,
@@ -129,13 +120,9 @@ class CorefClusterer(torch.nn.Module):
         self.rough_scorer = RoughScorer(dim, dropout, roughk)
         self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)
         pair_emb = dim * 3 + self.pw.shape
-        self.a_scorer = AnaphoricityScorer(
-            pair_emb, hidden_size, n_layers, dropout
-        )
+        self.a_scorer = AnaphoricityScorer(pair_emb, hidden_size, n_layers, dropout)
 
-    def forward(
-        self, word_features: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, word_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         1. LSTM encodes the incoming word_features.
         2. The RoughScorer scores and prunes the candidates.
@@ -350,13 +337,9 @@ class DistancePairwiseEncoder(torch.nn.Module):
         self.dropout = torch.nn.Dropout(dropout)
         self.shape = emb_size
 
-    def forward(
-        self,
-        top_indices: torch.Tensor
-    ) -> torch.Tensor:
+    def forward(self, top_indices: torch.Tensor) -> torch.Tensor:
         word_ids = torch.arange(0, top_indices.size(0))
-        distance = (word_ids.unsqueeze(1) - word_ids[top_indices]
-                    ).clamp_min_(min=1)
+        distance = (word_ids.unsqueeze(1) - word_ids[top_indices]).clamp_min_(min=1)
         log_distance = distance.to(torch.float).log2().floor_()
         log_distance = log_distance.clamp_max_(max=6).to(torch.long)
         distance = torch.where(distance < 5, distance - 1, log_distance + 2)
diff --git a/spacy/ml/models/coref_util.py b/spacy/ml/models/coref_util.py
index 3e28ca8ee..8d0ff7bb0 100644
--- a/spacy/ml/models/coref_util.py
+++ b/spacy/ml/models/coref_util.py
@@ -7,6 +7,7 @@ MentionClusters = List[List[Tuple[int, int]]]
 
 DEFAULT_CLUSTER_PREFIX = "coref_clusters"
 
+
 class GraphNode:
     def __init__(self, node_id: int):
         self.id = node_id
@@ -30,6 +31,7 @@ def get_sentence_ids(doc):
         out.append(sent_id)
     return out
 
+
 def doc2clusters(doc: Doc, prefix=DEFAULT_CLUSTER_PREFIX) -> MentionClusters:
     """Given a doc, give the mention clusters.
 
@@ -100,7 +102,6 @@ def get_predicted_clusters(
 
     return predicted_clusters
 
-
 def select_non_crossing_spans(
     idxs: List[int], starts: List[int], ends: List[int], limit: int
 ) -> List[int]:
@@ -150,12 +151,14 @@ def select_non_crossing_spans(
     #     selected.append(selected[0]) # this seems a bit weird?
     return selected
 
+
 def create_head_span_idxs(ops, doclen: int):
     """Helper function to create single-token span indices."""
     aa = ops.xp.arange(0, doclen)
     bb = ops.xp.arange(0, doclen) + 1
     return ops.asarray2i([aa, bb]).T
 
+
 def get_clusters_from_doc(doc) -> List[List[Tuple[int, int]]]:
     """Given a Doc, convert the cluster spans to simple int tuple lists."""
     out = []
@@ -163,10 +166,10 @@ def get_clusters_from_doc(doc) -> List[List[Tuple[int, int]]]:
         cluster = []
         for span in val:
             # TODO check that there isn't an off-by-one error here
-            #cluster.append((span.start, span.end))
+            # cluster.append((span.start, span.end))
             # TODO This conversion should be happening earlier in processing
             head_i = span.root.i
-            cluster.append( (head_i, head_i + 1) )
+            cluster.append((head_i, head_i + 1))
 
         # don't want duplicates
         cluster = list(set(cluster))
diff --git a/spacy/ml/models/span_predictor.py b/spacy/ml/models/span_predictor.py
index 03101edf9..a8c4d1aaa 100644
--- a/spacy/ml/models/span_predictor.py
+++ b/spacy/ml/models/span_predictor.py
@@ -18,7 +18,7 @@ def build_span_predictor(
     conv_channels: int = 4,
     window_size: int = 1,
     max_distance: int = 128,
-    prefix: str = "coref_head_clusters"
+    prefix: str = "coref_head_clusters",
 ):
     # TODO add model return types
     # TODO fix this
@@ -36,7 +36,7 @@ def build_span_predictor(
             distance_embedding_size,
             conv_channels,
             window_size,
-            max_distance
+            max_distance,
         ),
         convert_inputs=convert_span_predictor_inputs,
     )
@@ -133,20 +133,17 @@ def head_data_forward(model, docs, is_train):
 # TODO this should maybe have a different name from the component
 class SpanPredictor(torch.nn.Module):
     def __init__(
-            self,
-            input_size: int,
-            hidden_size: int,
-            dist_emb_size: int,
-            conv_channels: int,
-            window_size: int,
-            max_distance: int
-
+        self,
+        input_size: int,
+        hidden_size: int,
+        dist_emb_size: int,
+        conv_channels: int,
+        window_size: int,
+        max_distance: int,
     ):
         super().__init__()
         if max_distance % 2 != 0:
-            raise ValueError(
-                "max_distance has to be an even number"
-            )
+            raise ValueError("max_distance has to be an even number")
         # input size = single token size
         # 64 = probably distance emb size
         # TODO check that dist_emb_size use is correct
@@ -164,7 +161,7 @@ class SpanPredictor(torch.nn.Module):
         kernel_size = window_size * 2 + 1
         self.conv = torch.nn.Sequential(
             torch.nn.Conv1d(dist_emb_size, conv_channels, kernel_size, 1, 1),
-            torch.nn.Conv1d(conv_channels, 2, kernel_size, 1, 1)
+            torch.nn.Conv1d(conv_channels, 2, kernel_size, 1, 1),
         )
         # TODO make embeddings size a parameter
         self.max_distance = max_distance