diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py index 22234390e..fe7084fc5 100644 --- a/spacy/ml/models/coref.py +++ b/spacy/ml/models/coref.py @@ -1,8 +1,8 @@ -from typing import List, Tuple +from typing import List, Tuple, Callable, cast from thinc.api import Model, chain from thinc.api import PyTorchWrapper, ArgsKwargs -from thinc.types import Floats2d +from thinc.types import Floats2d, Ints2d from thinc.util import torch, xp2torch, torch2xp from ...tokens import Doc @@ -23,8 +23,8 @@ def build_wl_coref_model( antecedent_limit: int = 50, antecedent_batch_size: int = 512, tok2vec_size: int = 768, # tok2vec size -): - # TODO add model return types +) -> Model[List[Doc], Tuple[Floats2d, Ints2d]]: + with Model.define_operators({">>": chain}): coref_clusterer = PyTorchWrapper( @@ -44,27 +44,24 @@ def build_wl_coref_model( return coref_model -def convert_coref_clusterer_inputs( - model: Model, X: List[Floats2d], is_train: bool -): +def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool): # The input here is List[Floats2d], one for each doc # just use the first # TODO real batching X = X[0] word_features = xp2torch(X, requires_grad=is_train) - # TODO fix or remove type annotations - def backprop(args: ArgsKwargs): #-> List[Floats2d]: + def backprop(args: ArgsKwargs) -> List[Floats2d]: # convert to xp and wrap in list - gradients = torch2xp(args.args[0]) + gradients = cast(Floats2d, torch2xp(args.args[0])) return [gradients] return ArgsKwargs(args=(word_features,), kwargs={}), backprop def convert_coref_clusterer_outputs( - model: Model, inputs_outputs, is_train: bool -): + model: Model, inputs_outputs, is_train: bool +) -> Tuple[Tuple[Floats2d, Ints2d], Callable]: _, outputs = inputs_outputs scores, indices = outputs @@ -75,8 +72,8 @@ def convert_coref_clusterer_outputs( kwargs={"grad_tensors": [dY_t]}, ) - scores_xp = torch2xp(scores) - indices_xp = torch2xp(indices) + scores_xp = cast(Floats2d, torch2xp(scores)) + indices_xp = cast(Ints2d, torch2xp(indices)) return (scores_xp, indices_xp), convert_for_torch_backward @@ -114,9 +111,7 @@ class CorefClusterer(torch.nn.Module): self.pw = DistancePairwiseEncoder(dist_emb_size, dropout) pair_emb = dim * 3 + self.pw.shape - self.a_scorer = AnaphoricityScorer( - pair_emb, hidden_size, n_layers, dropout - ) + self.a_scorer = AnaphoricityScorer(pair_emb, hidden_size, n_layers, dropout) self.lstm = torch.nn.LSTM( input_size=dim, hidden_size=dim, @@ -155,10 +150,10 @@ class CorefClusterer(torch.nn.Module): a_scores_lst: List[torch.Tensor] = [] for i in range(0, len(words), batch_size): - pw_batch = pw[i:i + batch_size] - words_batch = words[i:i + batch_size] - top_indices_batch = top_indices[i:i + batch_size] - top_rough_scores_batch = top_rough_scores[i:i + batch_size] + pw_batch = pw[i : i + batch_size] + words_batch = words[i : i + batch_size] + top_indices_batch = top_indices[i : i + batch_size] + top_rough_scores_batch = top_rough_scores[i : i + batch_size] # a_scores_batch [batch_size, n_ants] a_scores_batch = self.a_scorer( diff --git a/spacy/ml/models/span_predictor.py b/spacy/ml/models/span_predictor.py index d44e632bd..1947b7833 100644 --- a/spacy/ml/models/span_predictor.py +++ b/spacy/ml/models/span_predictor.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import List, Tuple, cast from thinc.api import Model, chain, tuplify from thinc.api import PyTorchWrapper, ArgsKwargs @@ -35,7 +35,6 @@ def build_span_predictor( ), convert_inputs=convert_span_predictor_inputs, ) - # TODO use 
proper parameter for prefix head_info = build_get_head_metadata(prefix) model = (tok2vec & head_info) >> span_predictor @@ -43,15 +42,17 @@ def build_span_predictor( def convert_span_predictor_inputs( - model: Model, X: Tuple[List[Floats2d], Tuple[List[Ints1d], List[Ints1d]]], is_train: bool + model: Model, + X: Tuple[List[Floats2d], Tuple[List[Ints1d], List[Ints1d]]], + is_train: bool, ): tok2vec, (sent_ids, head_ids) = X # Normally we should use the input is_train, but for these two it's not relevant # TODO fix the type here, or remove it - def backprop(args: ArgsKwargs): #-> Tuple[List[Floats2d], None]: - gradients = torch2xp(args.args[1]) + def backprop(args: ArgsKwargs) -> Tuple[List[Floats2d], None]: + gradients = cast(Floats2d, torch2xp(args.args[1])) # The sent_ids and head_ids are None because no gradients - return [[gradients], None] + return ([gradients], None) word_features = xp2torch(tok2vec[0], requires_grad=is_train) sent_ids_tensor = xp2torch(sent_ids[0], requires_grad=False) @@ -96,7 +97,6 @@ def predict_span_clusters( def build_get_head_metadata(prefix): - # TODO this name is awful, fix it model = Model( "HeadDataProvider", attrs={"prefix": prefix}, forward=head_data_forward ) @@ -142,7 +142,6 @@ class SpanPredictor(torch.nn.Module): raise ValueError("max_distance has to be an even number") # input size = single token size # 64 = probably distance emb size - # TODO check that dist_emb_size use is correct self.ffnn = torch.nn.Sequential( torch.nn.Linear(input_size * 2 + dist_emb_size, hidden_size), torch.nn.ReLU(), @@ -159,7 +158,6 @@ class SpanPredictor(torch.nn.Module): torch.nn.Conv1d(dist_emb_size, conv_channels, kernel_size, 1, 1), torch.nn.Conv1d(conv_channels, 2, kernel_size, 1, 1), ) - # TODO make embeddings size a parameter self.max_distance = max_distance # handle distances between +-(max_distance - 2 / 2) self.emb = torch.nn.Embedding(max_distance, dist_emb_size) @@ -211,9 +209,7 @@ class SpanPredictor(torch.nn.Module): dim=1, ) lengths = same_sent.sum(dim=1) - padding_mask = torch.arange( - 0, lengths.max().item(), device=device - ).unsqueeze(0) + padding_mask = torch.arange(0, lengths.max().item(), device=device).unsqueeze(0) # (n_heads x max_sent_len) padding_mask = padding_mask < lengths.unsqueeze(1) # (n_heads x max_sent_len x input_size * 2 + distance_emb_size) diff --git a/spacy/pipeline/coref.py b/spacy/pipeline/coref.py index af40d9b06..7cf4fa44a 100644 --- a/spacy/pipeline/coref.py +++ b/spacy/pipeline/coref.py @@ -95,7 +95,7 @@ def make_coref( class CoreferenceResolver(TrainablePipe): """Pipeline component for coreference resolution. - DOCS: https://spacy.io/api/coref (TODO) + DOCS: https://spacy.io/api/coref """ def __init__( @@ -118,8 +118,10 @@ class CoreferenceResolver(TrainablePipe): are stored in. span_cluster_prefix (str): Prefix for the key in doc.spans to store the coref clusters in. + scorer (Optional[Callable]): The scoring method. Defaults to + Scorer.score_coref_clusters. - DOCS: https://spacy.io/api/coref#init (TODO) + DOCS: https://spacy.io/api/coref#init """ self.vocab = vocab self.model = model @@ -133,11 +135,12 @@ class CoreferenceResolver(TrainablePipe): def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]: """Apply the pipeline's model to a batch of docs, without modifying them. + Return the list of predicted clusters. docs (Iterable[Doc]): The documents to predict. - RETURNS: The models prediction for each document. + RETURNS (List[MentionClusters]): The model's prediction for each document. 
- DOCS: https://spacy.io/api/coref#predict (TODO) + DOCS: https://spacy.io/api/coref#predict """ out = [] for doc in docs: @@ -163,7 +166,7 @@ class CoreferenceResolver(TrainablePipe): docs (Iterable[Doc]): The documents to modify. clusters: The span clusters, produced by CoreferenceResolver.predict. - DOCS: https://spacy.io/api/coref#set_annotations (TODO) + DOCS: https://spacy.io/api/coref#set_annotations """ docs = list(docs) if len(docs) != len(clusters_by_doc): @@ -204,7 +207,7 @@ class CoreferenceResolver(TrainablePipe): Updated using the component name as the key. RETURNS (Dict[str, float]): The updated losses dictionary. - DOCS: https://spacy.io/api/coref#update (TODO) + DOCS: https://spacy.io/api/coref#update """ if losses is None: losses = {} @@ -225,12 +228,10 @@ class CoreferenceResolver(TrainablePipe): predicted docs in coref training. """ ) - # TODO check this causes no issues (in practice it runs) preds, backprop = self.model.begin_update([eg.predicted]) score_matrix, mention_idx = preds loss, d_scores = self.get_loss([eg], score_matrix, mention_idx) total_loss += loss - # TODO check shape here backprop((d_scores, mention_idx)) if sgd is not None: @@ -239,7 +240,12 @@ class CoreferenceResolver(TrainablePipe): return losses def rehearse(self, examples, *, sgd=None, losses=None, **config): - raise NotImplementedError + # TODO this should be added later + raise NotImplementedError( + Errors.E931.format( + parent="CoreferenceResolver", method="add_label", name=self.name + ) + ) def add_label(self, label: str) -> int: """Technically this method should be implemented from TrainablePipe, @@ -264,7 +270,7 @@ class CoreferenceResolver(TrainablePipe): scores: Scores representing the model's predictions. RETURNS (Tuple[float, float]): The loss and the gradient. - DOCS: https://spacy.io/api/coref#get_loss (TODO) + DOCS: https://spacy.io/api/coref#get_loss """ ops = self.model.ops xp = ops.xp @@ -289,9 +295,8 @@ class CoreferenceResolver(TrainablePipe): span_idxs = create_head_span_idxs(ops, len(example.predicted)) gscores = create_gold_scores(span_idxs, clusters) - # TODO fix type here. This is bools but asarray2f wants ints. + # Note on type here. This is bools but asarray2f wants ints. gscores = ops.asarray2f(gscores) # type: ignore - # top_gscores = xp.take_along_axis(gscores, cidx, axis=1) top_gscores = xp.take_along_axis(gscores, mention_idx, axis=1) # now add the placeholder gold_placeholder = ~top_gscores.any(axis=1).T @@ -323,7 +328,7 @@ class CoreferenceResolver(TrainablePipe): returns a representative sample of gold-standard Example objects. nlp (Language): The current nlp object the component is part of. - DOCS: https://spacy.io/api/coref#initialize (TODO) + DOCS: https://spacy.io/api/coref#initialize """ validate_get_examples(get_examples, "CoreferenceResolver.initialize") diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index aa7985a9c..36a291a88 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -383,7 +383,7 @@ class EntityLinker(TrainablePipe): no prediction. docs (Iterable[Doc]): The documents to predict. - RETURNS (List[str]): The models prediction for each document. + RETURNS (List[str]): The model's prediction for each document. 
DOCS: https://spacy.io/api/entitylinker#predict """ diff --git a/spacy/pipeline/span_predictor.py b/spacy/pipeline/span_predictor.py index aee11ba8e..6b43c1a56 100644 --- a/spacy/pipeline/span_predictor.py +++ b/spacy/pipeline/span_predictor.py @@ -29,7 +29,7 @@ distance_embedding_size = 64 conv_channels = 4 window_size = 1 max_distance = 128 -prefix = coref_head_clusters +prefix = "coref_head_clusters" [model.tok2vec] @architectures = "spacy.Tok2Vec.v2" @@ -95,6 +95,8 @@ class SpanPredictor(TrainablePipe): """Pipeline component to resolve one-token spans to full spans. Used in coreference resolution. + + DOCS: https://spacy.io/api/span_predictor """ def __init__( @@ -119,6 +121,14 @@ class SpanPredictor(TrainablePipe): } def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]: + """Apply the pipeline's model to a batch of docs, without modifying them. + Return the list of predicted span clusters. + + docs (Iterable[Doc]): The documents to predict. + RETURNS (List[MentionClusters]): The model's prediction for each document. + + DOCS: https://spacy.io/api/span_predictor#predict + """ # for now pretend there's just one doc out = [] @@ -151,6 +161,13 @@ class SpanPredictor(TrainablePipe): return out def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None: + """Modify a batch of Doc objects, using pre-computed scores. + + docs (Iterable[Doc]): The documents to modify. + clusters: The span clusters, produced by SpanPredictor.predict. + + DOCS: https://spacy.io/api/span_predictor#set_annotations + """ for doc, clusters in zip(docs, clusters_by_doc): for ii, cluster in enumerate(clusters): spans = [doc[mm[0] : mm[1]] for mm in cluster] @@ -166,6 +183,15 @@ class SpanPredictor(TrainablePipe): ) -> Dict[str, float]: """Learn from a batch of documents and gold-standard information, updating the pipe's model. Delegates to predict and get_loss. + + examples (Iterable[Example]): A batch of Example objects. + drop (float): The dropout rate. + sgd (thinc.api.Optimizer): The optimizer. + losses (Dict[str, float]): Optional record of the loss during training. + Updated using the component name as the key. + RETURNS (Dict[str, float]): The updated losses dictionary. + + DOCS: https://spacy.io/api/span_predictor#update """ if losses is None: losses = {} @@ -229,6 +255,15 @@ class SpanPredictor(TrainablePipe): examples: Iterable[Example], span_scores: Floats3d, ): + """Find the loss and gradient of loss for the batch of documents and + their predicted scores. + + examples (Iterable[Examples]): The batch of examples. + scores: Scores representing the model's predictions. + RETURNS (Tuple[float, float]): The loss and the gradient. + + DOCS: https://spacy.io/api/span_predictor#get_loss + """ ops = self.model.ops # NOTE This is doing fake batching, and should always get a list of one example @@ -285,6 +320,15 @@ class SpanPredictor(TrainablePipe): *, nlp: Optional[Language] = None, ) -> None: + """Initialize the pipe for training, using a representative set + of data examples. + + get_examples (Callable[[], Iterable[Example]]): Function that + returns a representative sample of gold-standard Example objects. + nlp (Language): The current nlp object the component is part of. 
+ + DOCS: https://spacy.io/api/span_predictor#initialize + """ validate_get_examples(get_examples, "SpanPredictor.initialize") X = [] diff --git a/website/docs/api/architectures.md b/website/docs/api/architectures.md index 4e70eee87..e881864a9 100644 --- a/website/docs/api/architectures.md +++ b/website/docs/api/architectures.md @@ -587,8 +587,8 @@ consists of either two or three subnetworks: run once for each batch. - **lower**: Construct a feature-specific vector for each `(token, feature)` pair. This is also run once for each batch. Constructing the state - representation is then a matter of summing the component features and - applying the non-linearity. + representation is then a matter of summing the component features and applying + the non-linearity. - **upper** (optional): A feed-forward network that predicts scores from the state representation. If not present, the output from the lower model is used as action scores directly. @@ -628,8 +628,8 @@ same signature, but the `use_upper` argument was `True` by default. > ``` Build a tagger model, using a provided token-to-vector component. The tagger -model adds a linear layer with softmax activation to predict scores given -the token vectors. +model adds a linear layer with softmax activation to predict scores given the +token vectors. | Name | Description | | ----------- | ------------------------------------------------------------------------------------------ | @@ -920,8 +920,8 @@ A function that reads an existing `KnowledgeBase` from file. A function that takes as input a [`KnowledgeBase`](/api/kb) and a [`Span`](/api/span) object denoting a named entity, and returns a list of plausible [`Candidate`](/api/kb/#candidate) objects. The default -`CandidateGenerator` uses the text of a mention to find its potential -aliases in the `KnowledgeBase`. Note that this function is case-dependent. +`CandidateGenerator` uses the text of a mention to find its potential aliases in +the `KnowledgeBase`. Note that this function is case-dependent. ## Coreference Architectures @@ -975,7 +975,11 @@ The `Coref` model architecture is a Thinc `Model`. > [model] > @architectures = "spacy.SpanPredictor.v1" > hidden_size = 1024 -> dist_emb_size = 64 +> distance_embedding_size = 64 +> conv_channels = 4 +> window_size = 1 +> max_distance = 128 +> prefix = "coref_head_clusters" > > [model.tok2vec] > @architectures = "spacy-transformers.TransformerListener.v1" @@ -986,13 +990,14 @@ The `Coref` model architecture is a Thinc `Model`. The `SpanPredictor` model architecture is a Thinc `Model`. -| Name | Description | -| ------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `tok2vec` | The [`tok2vec`](#tok2vec) layer of the model. ~~Model~~ | -| `distance_embedding_size` | A representation of the distance between two candidates. ~~int~~ | -| `dropout` | The dropout to use internally. Unlike some Thinc models, this has separate dropout for the internal PyTorch layers. ~~float~~ | -| `hidden_size` | Size of the main internal layers. ~~int~~ | -| `depth` | Depth of the internal network. ~~int~~ | -| `antecedent_limit` | How many candidate antecedents to keep after rough scoring. This has a significant effect on memory usage. Typical values would be 50 to 200, or higher for very long documents. ~~int~~ | -| `antecedent_batch_size` | Internal batch size. 
~~int~~ | -| **CREATES** | The model using the architecture. ~~Model[List[Doc], TupleFloats2d]~~ | +| Name | Description | +| ------------------------- | ----------------------------------------------------------------------------------------------------------------------------- | +| `tok2vec` | The [`tok2vec`](#tok2vec) layer of the model. ~~Model~~ | +| `distance_embedding_size` | A representation of the distance between two candidates. ~~int~~ | +| `dropout` | The dropout to use internally. Unlike some Thinc models, this has separate dropout for the internal PyTorch layers. ~~float~~ | +| `hidden_size` | Size of the main internal layers. ~~int~~ | +| `conv_channels` | The number of channels in the internal CNN. ~~int~~ | +| `window_size` | The number of neighboring tokens to consider in the internal CNN. `1` means consider one token on each side. ~~int~~ | +| `max_distance` | The longest possible length of a predicted span. ~~int~~ | +| `prefix` | The prefix that indicates spans to use for input data. ~~string~~ | +| **CREATES** | The model using the architecture. ~~Model[List[Doc], TupleFloats2d]~~ |
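
As a reading aid for the type change above (`build_wl_coref_model` now returns `Model[List[Doc], Tuple[Floats2d, Ints2d]]`), here is a minimal sketch of how a `(scores, indices)` pair can be decoded into per-word antecedent choices. This is illustrative only and not part of the patch: the array shapes and the convention that column 0 of `scores` acts as a "no antecedent" placeholder are assumptions drawn from the wl-coref design this component follows.

```python
import numpy as np

n_words = 4
# scores: (n_words, n_ants + 1); column 0 is assumed to be the
# "no antecedent" placeholder column
scores = np.array([
    [0.9, 0.1, 0.0],
    [0.2, 0.7, 0.1],
    [0.8, 0.1, 0.1],
    [0.1, 0.2, 0.7],
])
# indices: (n_words, n_ants); candidate antecedent word indices per word
indices = np.array([
    [0, 0],
    [0, 2],
    [0, 1],
    [1, 2],
])

best = scores.argmax(axis=1)
# best[i] == 0 means word i starts a new chain; otherwise shift by one to
# skip the placeholder column and look up the chosen antecedent's word index
antecedents = {
    i: int(indices[i, best[i] - 1]) for i in range(n_words) if best[i] != 0
}
print(antecedents)  # {1: 0, 3: 2}
```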