mirror of https://github.com/explosion/spaCy.git
synced 2025-07-17 19:52:18 +03:00
Merge branch 'feature/coref' into fix/coref-alignment
commit 6d9eafeb37
@@ -1,8 +1,8 @@
-from typing import List, Tuple
+from typing import List, Tuple, Callable, cast
 
 from thinc.api import Model, chain
 from thinc.api import PyTorchWrapper, ArgsKwargs
-from thinc.types import Floats2d
+from thinc.types import Floats2d, Ints2d
 from thinc.util import torch, xp2torch, torch2xp
 
 from ...tokens import Doc
@@ -23,8 +23,8 @@ def build_wl_coref_model(
     antecedent_limit: int = 50,
     antecedent_batch_size: int = 512,
     tok2vec_size: int = 768,  # tok2vec size
-):
-    # TODO add model return types
+) -> Model[List[Doc], Tuple[Floats2d, Ints2d]]:
 
     with Model.define_operators({">>": chain}):
         coref_clusterer = PyTorchWrapper(
@@ -44,27 +44,24 @@ def build_wl_coref_model(
     return coref_model
 
 
-def convert_coref_clusterer_inputs(
-    model: Model, X: List[Floats2d], is_train: bool
-):
+def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
     # The input here is List[Floats2d], one for each doc
     # just use the first
     # TODO real batching
     X = X[0]
     word_features = xp2torch(X, requires_grad=is_train)
 
-    # TODO fix or remove type annotations
-    def backprop(args: ArgsKwargs): #-> List[Floats2d]:
+    def backprop(args: ArgsKwargs) -> List[Floats2d]:
         # convert to xp and wrap in list
-        gradients = torch2xp(args.args[0])
+        gradients = cast(Floats2d, torch2xp(args.args[0]))
         return [gradients]
 
     return ArgsKwargs(args=(word_features,), kwargs={}), backprop
 
 
 def convert_coref_clusterer_outputs(
-    model: Model, inputs_outputs, is_train: bool
-):
+    model: Model, inputs_outputs, is_train: bool
+) -> Tuple[Tuple[Floats2d, Ints2d], Callable]:
     _, outputs = inputs_outputs
     scores, indices = outputs
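The backprop and cast changes above follow the general Thinc/PyTorch bridging pattern: `convert_inputs` packs arrays into an `ArgsKwargs` of torch tensors, and returns a callback that maps torch gradients back into the array types the annotations promise. A minimal self-contained sketch of that pattern, using a stand-in `torch.nn.Linear` instead of the real clusterer (names here are illustrative, not part of the diff):

```python
from typing import List, cast

import torch
from thinc.api import ArgsKwargs, Model, PyTorchWrapper, torch2xp, xp2torch
from thinc.types import Floats2d


def convert_inputs(model: Model, X: List[Floats2d], is_train: bool):
    # Fake batching as in the diff: take the first (and only) doc's features.
    word_features = xp2torch(X[0], requires_grad=is_train)

    def backprop(args: ArgsKwargs) -> List[Floats2d]:
        # torch2xp is typed as returning a generic array, so cast to the
        # concrete type the annotation promises, exactly as the diff does.
        return [cast(Floats2d, torch2xp(args.args[0]))]

    return ArgsKwargs(args=(word_features,), kwargs={}), backprop


wrapped = PyTorchWrapper(torch.nn.Linear(64, 32), convert_inputs=convert_inputs)
```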
@@ -75,8 +72,8 @@ def convert_coref_clusterer_outputs(
             kwargs={"grad_tensors": [dY_t]},
         )
 
-    scores_xp = torch2xp(scores)
-    indices_xp = torch2xp(indices)
+    scores_xp = cast(Floats2d, torch2xp(scores))
+    indices_xp = cast(Ints2d, torch2xp(indices))
     return (scores_xp, indices_xp), convert_for_torch_backward
@@ -114,9 +111,7 @@ class CorefClusterer(torch.nn.Module):
         self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)
 
         pair_emb = dim * 3 + self.pw.shape
-        self.a_scorer = AnaphoricityScorer(
-            pair_emb, hidden_size, n_layers, dropout
-        )
+        self.a_scorer = AnaphoricityScorer(pair_emb, hidden_size, n_layers, dropout)
         self.lstm = torch.nn.LSTM(
             input_size=dim,
             hidden_size=dim,
@@ -155,10 +150,10 @@ class CorefClusterer(torch.nn.Module):
         a_scores_lst: List[torch.Tensor] = []
 
         for i in range(0, len(words), batch_size):
-            pw_batch = pw[i:i + batch_size]
-            words_batch = words[i:i + batch_size]
-            top_indices_batch = top_indices[i:i + batch_size]
-            top_rough_scores_batch = top_rough_scores[i:i + batch_size]
+            pw_batch = pw[i : i + batch_size]
+            words_batch = words[i : i + batch_size]
+            top_indices_batch = top_indices[i : i + batch_size]
+            top_rough_scores_batch = top_rough_scores[i : i + batch_size]
 
             # a_scores_batch  [batch_size, n_ants]
             a_scores_batch = self.a_scorer(
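The slice rewrites above are purely Black's formatting for slices; behavior is identical. The loop itself is the standard fixed-size batching idiom for bounding peak memory when scoring large tensors, which can be illustrated in isolation (hypothetical `scorer`, sizes picked for the example):

```python
import torch


def score_in_batches(scorer, words: torch.Tensor, batch_size: int = 512) -> torch.Tensor:
    # Score fixed-size slices so peak memory stays bounded, then re-join.
    pieces = []
    for i in range(0, len(words), batch_size):
        pieces.append(scorer(words[i : i + batch_size]))
    return torch.cat(pieces)


scores = score_in_batches(torch.nn.Linear(16, 1), torch.randn(1000, 16))
assert scores.shape == (1000, 1)
```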
@@ -1,4 +1,4 @@
-from typing import List, Tuple
+from typing import List, Tuple, cast
 
 from thinc.api import Model, chain, tuplify
 from thinc.api import PyTorchWrapper, ArgsKwargs
@@ -35,7 +35,6 @@ def build_span_predictor(
         ),
         convert_inputs=convert_span_predictor_inputs,
     )
-    # TODO use proper parameter for prefix
     head_info = build_get_head_metadata(prefix)
     model = (tok2vec & head_info) >> span_predictor
 
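The context line `model = (tok2vec & head_info) >> span_predictor` relies on the operators this file binds; assuming `&` is bound to `tuplify` here (the file imports it alongside `chain`), both branches receive the same input and their outputs are passed forward as a tuple. A reduced sketch with trivial layers (illustrative only):

```python
import numpy
from thinc.api import Linear, tuplify

# Run two layers on the same input; the output is a tuple of both results.
two_views = tuplify(Linear(4), Linear(2))
X = numpy.zeros((5, 8), dtype="f")
two_views.initialize(X=X)
hidden, extra = two_views.predict(X)  # shapes (5, 4) and (5, 2)
```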
@@ -43,15 +42,17 @@ def build_span_predictor(
 
 
 def convert_span_predictor_inputs(
-    model: Model, X: Tuple[List[Floats2d], Tuple[List[Ints1d], List[Ints1d]]], is_train: bool
+    model: Model,
+    X: Tuple[List[Floats2d], Tuple[List[Ints1d], List[Ints1d]]],
+    is_train: bool,
 ):
     tok2vec, (sent_ids, head_ids) = X
     # Normally we should use the input is_train, but for these two it's not relevant
-    # TODO fix the type here, or remove it
-    def backprop(args: ArgsKwargs): #-> Tuple[List[Floats2d], None]:
-        gradients = torch2xp(args.args[1])
+    def backprop(args: ArgsKwargs) -> Tuple[List[Floats2d], None]:
+        gradients = cast(Floats2d, torch2xp(args.args[1]))
         # The sent_ids and head_ids are None because no gradients
-        return [[gradients], None]
+        return ([gradients], None)
 
     word_features = xp2torch(tok2vec[0], requires_grad=is_train)
     sent_ids_tensor = xp2torch(sent_ids[0], requires_grad=False)
@@ -96,7 +97,6 @@ def predict_span_clusters(
 
 
 def build_get_head_metadata(prefix):
-    # TODO this name is awful, fix it
     model = Model(
         "HeadDataProvider", attrs={"prefix": prefix}, forward=head_data_forward
     )
@@ -142,7 +142,6 @@ class SpanPredictor(torch.nn.Module):
             raise ValueError("max_distance has to be an even number")
         # input size = single token size
         # 64 = probably distance emb size
-        # TODO check that dist_emb_size use is correct
         self.ffnn = torch.nn.Sequential(
             torch.nn.Linear(input_size * 2 + dist_emb_size, hidden_size),
             torch.nn.ReLU(),
@@ -159,7 +158,6 @@ class SpanPredictor(torch.nn.Module):
             torch.nn.Conv1d(dist_emb_size, conv_channels, kernel_size, 1, 1),
             torch.nn.Conv1d(conv_channels, 2, kernel_size, 1, 1),
         )
-        # TODO make embeddings size a parameter
         self.max_distance = max_distance
         # handle distances between +-(max_distance - 2 / 2)
         self.emb = torch.nn.Embedding(max_distance, dist_emb_size)
@@ -211,9 +209,7 @@ class SpanPredictor(torch.nn.Module):
             dim=1,
         )
         lengths = same_sent.sum(dim=1)
-        padding_mask = torch.arange(
-            0, lengths.max().item(), device=device
-        ).unsqueeze(0)
+        padding_mask = torch.arange(0, lengths.max().item(), device=device).unsqueeze(0)
         # (n_heads x max_sent_len)
         padding_mask = padding_mask < lengths.unsqueeze(1)
         # (n_heads x max_sent_len x input_size * 2 + distance_emb_size)
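The collapsed `padding_mask` one-liner is the usual length-mask construction: broadcast an `arange` over the maximum length against the per-row lengths. In isolation:

```python
import torch

lengths = torch.tensor([3, 1, 2])
positions = torch.arange(0, int(lengths.max().item())).unsqueeze(0)  # (1, max_len)
mask = positions < lengths.unsqueeze(1)  # (n_rows, max_len)
# tensor([[ True,  True,  True],
#         [ True, False, False],
#         [ True,  True, False]])
```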
@@ -95,7 +95,7 @@ def make_coref(
 class CoreferenceResolver(TrainablePipe):
     """Pipeline component for coreference resolution.
 
-    DOCS: https://spacy.io/api/coref (TODO)
+    DOCS: https://spacy.io/api/coref
     """
 
     def __init__(
@@ -118,8 +118,10 @@ class CoreferenceResolver(TrainablePipe):
             are stored in.
         span_cluster_prefix (str): Prefix for the key in doc.spans to store the
             coref clusters in.
+        scorer (Optional[Callable]): The scoring method. Defaults to
+            Scorer.score_coref_clusters.
 
-        DOCS: https://spacy.io/api/coref#init (TODO)
+        DOCS: https://spacy.io/api/coref#init
         """
         self.vocab = vocab
         self.model = model
@@ -133,11 +135,12 @@ class CoreferenceResolver(TrainablePipe):
 
     def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
         """Apply the pipeline's model to a batch of docs, without modifying them.
+        Return the list of predicted clusters.
 
         docs (Iterable[Doc]): The documents to predict.
-        RETURNS: The models prediction for each document.
+        RETURNS (List[MentionClusters]): The model's prediction for each document.
 
-        DOCS: https://spacy.io/api/coref#predict (TODO)
+        DOCS: https://spacy.io/api/coref#predict
         """
         out = []
         for doc in docs:
@@ -163,7 +166,7 @@ class CoreferenceResolver(TrainablePipe):
         docs (Iterable[Doc]): The documents to modify.
         clusters: The span clusters, produced by CoreferenceResolver.predict.
 
-        DOCS: https://spacy.io/api/coref#set_annotations (TODO)
+        DOCS: https://spacy.io/api/coref#set_annotations
         """
         docs = list(docs)
         if len(docs) != len(clusters_by_doc):
@@ -204,7 +207,7 @@ class CoreferenceResolver(TrainablePipe):
             Updated using the component name as the key.
         RETURNS (Dict[str, float]): The updated losses dictionary.
 
-        DOCS: https://spacy.io/api/coref#update (TODO)
+        DOCS: https://spacy.io/api/coref#update
         """
         if losses is None:
             losses = {}
@@ -225,12 +228,10 @@ class CoreferenceResolver(TrainablePipe):
                     predicted docs in coref training.
                     """
                 )
-            # TODO check this causes no issues (in practice it runs)
             preds, backprop = self.model.begin_update([eg.predicted])
             score_matrix, mention_idx = preds
             loss, d_scores = self.get_loss([eg], score_matrix, mention_idx)
             total_loss += loss
-            # TODO check shape here
             backprop((d_scores, mention_idx))
 
         if sgd is not None:
@@ -239,7 +240,12 @@ class CoreferenceResolver(TrainablePipe):
         return losses
 
     def rehearse(self, examples, *, sgd=None, losses=None, **config):
-        raise NotImplementedError
+        # TODO this should be added later
+        raise NotImplementedError(
+            Errors.E931.format(
+                parent="CoreferenceResolver", method="add_label", name=self.name
+            )
+        )
 
     def add_label(self, label: str) -> int:
         """Technically this method should be implemented from TrainablePipe,
@@ -264,7 +270,7 @@ class CoreferenceResolver(TrainablePipe):
         scores: Scores representing the model's predictions.
         RETURNS (Tuple[float, float]): The loss and the gradient.
 
-        DOCS: https://spacy.io/api/coref#get_loss (TODO)
+        DOCS: https://spacy.io/api/coref#get_loss
         """
         ops = self.model.ops
         xp = ops.xp
@@ -289,9 +295,8 @@ class CoreferenceResolver(TrainablePipe):
 
         span_idxs = create_head_span_idxs(ops, len(example.predicted))
         gscores = create_gold_scores(span_idxs, clusters)
-        # TODO fix type here. This is bools but asarray2f wants ints.
+        # Note on type here. This is bools but asarray2f wants ints.
         gscores = ops.asarray2f(gscores)  # type: ignore
-        # top_gscores = xp.take_along_axis(gscores, cidx, axis=1)
         top_gscores = xp.take_along_axis(gscores, mention_idx, axis=1)
         # now add the placeholder
         gold_placeholder = ~top_gscores.any(axis=1).T
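Dropping the commented-out `cidx` line leaves the live call, which gathers, for each mention, the gold scores of only the antecedent candidates the model kept. `take_along_axis` selects a different set of columns per row; a small numpy illustration:

```python
import numpy

gscores = numpy.asarray([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]], dtype="f")
mention_idx = numpy.asarray([[2, 1], [0, 2]])
top_gscores = numpy.take_along_axis(gscores, mention_idx, axis=1)
# array([[0., 1.],
#        [1., 1.]], dtype=float32)
```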
@@ -323,7 +328,7 @@ class CoreferenceResolver(TrainablePipe):
             returns a representative sample of gold-standard Example objects.
         nlp (Language): The current nlp object the component is part of.
 
-        DOCS: https://spacy.io/api/coref#initialize (TODO)
+        DOCS: https://spacy.io/api/coref#initialize
         """
         validate_get_examples(get_examples, "CoreferenceResolver.initialize")
 
@@ -383,7 +383,7 @@ class EntityLinker(TrainablePipe):
         no prediction.
 
         docs (Iterable[Doc]): The documents to predict.
-        RETURNS (List[str]): The models prediction for each document.
+        RETURNS (List[str]): The model's prediction for each document.
 
         DOCS: https://spacy.io/api/entitylinker#predict
         """
@@ -29,7 +29,7 @@ distance_embedding_size = 64
 conv_channels = 4
 window_size = 1
 max_distance = 128
-prefix = coref_head_clusters
+prefix = "coref_head_clusters"
 
 [model.tok2vec]
 @architectures = "spacy.Tok2Vec.v2"
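Quoting the value makes the intended string explicit in the config. One quick way to check how a snippet parses, using `spacy.util.load_config_from_str` (a sanity check written for this note, not part of the diff):

```python
from spacy.util import load_config_from_str

cfg = load_config_from_str(
    """
    [model]
    prefix = "coref_head_clusters"
    max_distance = 128
    """
)
assert cfg["model"]["prefix"] == "coref_head_clusters"
assert cfg["model"]["max_distance"] == 128  # unquoted literals parse as JSON types
```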
@@ -95,6 +95,8 @@ class SpanPredictor(TrainablePipe):
     """Pipeline component to resolve one-token spans to full spans.
 
     Used in coreference resolution.
+
+    DOCS: https://spacy.io/api/span_predictor
     """
 
     def __init__(
@@ -119,6 +121,14 @@ class SpanPredictor(TrainablePipe):
         }
 
     def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
+        """Apply the pipeline's model to a batch of docs, without modifying them.
+        Return the list of predicted span clusters.
+
+        docs (Iterable[Doc]): The documents to predict.
+        RETURNS (List[MentionClusters]): The model's prediction for each document.
+
+        DOCS: https://spacy.io/api/span_predictor#predict
+        """
         # for now pretend there's just one doc
 
         out = []
@@ -151,6 +161,13 @@ class SpanPredictor(TrainablePipe):
         return out
 
     def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None:
+        """Modify a batch of Doc objects, using pre-computed scores.
+
+        docs (Iterable[Doc]): The documents to modify.
+        clusters: The span clusters, produced by SpanPredictor.predict.
+
+        DOCS: https://spacy.io/api/span_predictor#set_annotations
+        """
         for doc, clusters in zip(docs, clusters_by_doc):
             for ii, cluster in enumerate(clusters):
                 spans = [doc[mm[0] : mm[1]] for mm in cluster]
@@ -166,6 +183,15 @@ class SpanPredictor(TrainablePipe):
     ) -> Dict[str, float]:
+        """Learn from a batch of documents and gold-standard information,
+        updating the pipe's model. Delegates to predict and get_loss.
+
+        examples (Iterable[Example]): A batch of Example objects.
+        drop (float): The dropout rate.
+        sgd (thinc.api.Optimizer): The optimizer.
+        losses (Dict[str, float]): Optional record of the loss during training.
+            Updated using the component name as the key.
+        RETURNS (Dict[str, float]): The updated losses dictionary.
+
+        DOCS: https://spacy.io/api/span_predictor#update
+        """
         if losses is None:
             losses = {}
@@ -229,6 +255,15 @@ class SpanPredictor(TrainablePipe):
         examples: Iterable[Example],
         span_scores: Floats3d,
     ):
+        """Find the loss and gradient of loss for the batch of documents and
+        their predicted scores.
+
+        examples (Iterable[Examples]): The batch of examples.
+        scores: Scores representing the model's predictions.
+        RETURNS (Tuple[float, float]): The loss and the gradient.
+
+        DOCS: https://spacy.io/api/span_predictor#get_loss
+        """
         ops = self.model.ops
 
         # NOTE This is doing fake batching, and should always get a list of one example
@@ -285,6 +320,15 @@ class SpanPredictor(TrainablePipe):
         *,
         nlp: Optional[Language] = None,
     ) -> None:
+        """Initialize the pipe for training, using a representative set
+        of data examples.
+
+        get_examples (Callable[[], Iterable[Example]]): Function that
+            returns a representative sample of gold-standard Example objects.
+        nlp (Language): The current nlp object the component is part of.
+
+        DOCS: https://spacy.io/api/span_predictor#initialize
+        """
         validate_get_examples(get_examples, "SpanPredictor.initialize")
 
         X = []
@@ -587,8 +587,8 @@ consists of either two or three subnetworks:
   run once for each batch.
 - **lower**: Construct a feature-specific vector for each `(token, feature)`
   pair. This is also run once for each batch. Constructing the state
-  representation is then a matter of summing the component features and
-  applying the non-linearity.
+  representation is then a matter of summing the component features and applying
+  the non-linearity.
 - **upper** (optional): A feed-forward network that predicts scores from the
   state representation. If not present, the output from the lower model is used
   as action scores directly.
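The reflowed bullet describes a cheap state-vector construction: the expensive per-token vectors are precomputed once per batch, and each parser state just sums the vectors for its feature slots and applies the non-linearity. Schematically (made-up sizes; an illustration of the sentence above, not the parser's actual code):

```python
import numpy

hidden_width = 64
n_feature_slots = 6  # e.g. specific positions on the stack and buffer
# "lower" output: one precomputed vector per (token, feature) pair
feature_vectors = numpy.random.rand(n_feature_slots, hidden_width).astype("f")
# state representation: sum the pieces, then apply the non-linearity
state = numpy.maximum(feature_vectors.sum(axis=0), 0.0)  # e.g. ReLU
```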
@@ -628,8 +628,8 @@ same signature, but the `use_upper` argument was `True` by default.
 > ```
 
 Build a tagger model, using a provided token-to-vector component. The tagger
-model adds a linear layer with softmax activation to predict scores given
-the token vectors.
+model adds a linear layer with softmax activation to predict scores given the
+token vectors.
 
 | Name        | Description |
 | ----------- | ------------------------------------------------------------------------------------------ |
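In Thinc terms, the tagger described here amounts to a two-layer chain; a sketch under the assumption that the output layer is a softmax over the tag set (spaCy's actual builder also handles width inference and initialization details):

```python
from thinc.api import Model, Softmax, chain, with_array


def sketch_tagger(tok2vec: Model, n_tags: int) -> Model:
    # A linear layer with softmax activation over each token's vector.
    return chain(tok2vec, with_array(Softmax(n_tags)))
```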
@@ -920,8 +920,8 @@ A function that reads an existing `KnowledgeBase` from file.
 A function that takes as input a [`KnowledgeBase`](/api/kb) and a
 [`Span`](/api/span) object denoting a named entity, and returns a list of
 plausible [`Candidate`](/api/kb/#candidate) objects. The default
-`CandidateGenerator` uses the text of a mention to find its potential
-aliases in the `KnowledgeBase`. Note that this function is case-dependent.
+`CandidateGenerator` uses the text of a mention to find its potential aliases in
+the `KnowledgeBase`. Note that this function is case-dependent.
 
 ## Coreference Architectures
 
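The default generator's lookup can be approximated in a few lines; a sketch assuming the `KnowledgeBase.get_alias_candidates` method (the shipped `CandidateGenerator` wraps this kind of lookup):

```python
from typing import Iterable

from spacy.kb import Candidate, KnowledgeBase
from spacy.tokens import Span


def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]:
    # Case-dependent: "Rome" and "rome" are distinct aliases in the KB.
    return kb.get_alias_candidates(mention.text)
```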
@@ -975,7 +975,11 @@ The `Coref` model architecture is a Thinc `Model`.
 > [model]
 > @architectures = "spacy.SpanPredictor.v1"
 > hidden_size = 1024
-> dist_emb_size = 64
+> distance_embedding_size = 64
+> conv_channels = 4
+> window_size = 1
+> max_distance = 128
+> prefix = "coref_head_clusters"
 >
 > [model.tok2vec]
 > @architectures = "spacy-transformers.TransformerListener.v1"
@@ -986,13 +990,14 @@ The `Coref` model architecture is a Thinc `Model`.
 
 The `SpanPredictor` model architecture is a Thinc `Model`.
 
-| Name                      | Description |
-| ------------------------- | ----------- |
-| `tok2vec`                 | The [`tok2vec`](#tok2vec) layer of the model. ~~Model~~ |
-| `distance_embedding_size` | A representation of the distance between two candidates. ~~int~~ |
-| `dropout`                 | The dropout to use internally. Unlike some Thinc models, this has separate dropout for the internal PyTorch layers. ~~float~~ |
-| `hidden_size`             | Size of the main internal layers. ~~int~~ |
-| `depth`                   | Depth of the internal network. ~~int~~ |
-| `antecedent_limit`        | How many candidate antecedents to keep after rough scoring. This has a significant effect on memory usage. Typical values would be 50 to 200, or higher for very long documents. ~~int~~ |
-| `antecedent_batch_size`   | Internal batch size. ~~int~~ |
-| **CREATES**               | The model using the architecture. ~~Model[List[Doc], TupleFloats2d]~~ |
+| Name                      | Description |
+| ------------------------- | ----------- |
+| `tok2vec`                 | The [`tok2vec`](#tok2vec) layer of the model. ~~Model~~ |
+| `distance_embedding_size` | A representation of the distance between two candidates. ~~int~~ |
+| `dropout`                 | The dropout to use internally. Unlike some Thinc models, this has separate dropout for the internal PyTorch layers. ~~float~~ |
+| `hidden_size`             | Size of the main internal layers. ~~int~~ |
+| `conv_channels`           | The number of channels in the internal CNN. ~~int~~ |
+| `window_size`             | The number of neighboring tokens to consider in the internal CNN. `1` means consider one token on each side. ~~int~~ |
+| `max_distance`            | The longest possible length of a predicted span. ~~int~~ |
+| `prefix`                  | The prefix that indicates spans to use for input data. ~~string~~ |
+| **CREATES**               | The model using the architecture. ~~Model[List[Doc], TupleFloats2d]~~ |
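Assuming the factory names used elsewhere in this branch (`coref` and `span_predictor`), wiring the two components into a pipeline would look roughly like this; note the coref work later shipped in `spacy-experimental`, where the factories are named differently:

```python
import spacy

nlp = spacy.blank("en")
# Factory names as registered on this branch; both components still need
# a trained model and a tok2vec before they can be used for inference.
nlp.add_pipe("coref")
nlp.add_pipe("span_predictor")
```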