Merge branch 'feature/coref' into fix/coref-alignment

Paul O'Leary McCann 2022-07-11 19:14:37 +09:00 committed by GitHub
commit 6d9eafeb37
6 changed files with 110 additions and 65 deletions

View File

@@ -1,8 +1,8 @@
from typing import List, Tuple
from typing import List, Tuple, Callable, cast
from thinc.api import Model, chain
from thinc.api import PyTorchWrapper, ArgsKwargs
from thinc.types import Floats2d
from thinc.types import Floats2d, Ints2d
from thinc.util import torch, xp2torch, torch2xp
from ...tokens import Doc
@@ -23,8 +23,8 @@ def build_wl_coref_model(
antecedent_limit: int = 50,
antecedent_batch_size: int = 512,
tok2vec_size: int = 768, # tok2vec size
):
# TODO add model return types
) -> Model[List[Doc], Tuple[Floats2d, Ints2d]]:
with Model.define_operators({">>": chain}):
coref_clusterer = PyTorchWrapper(
@@ -44,27 +44,24 @@ def build_wl_coref_model(
return coref_model
def convert_coref_clusterer_inputs(
model: Model, X: List[Floats2d], is_train: bool
):
def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
# The input here is List[Floats2d], one for each doc
# just use the first
# TODO real batching
X = X[0]
word_features = xp2torch(X, requires_grad=is_train)
# TODO fix or remove type annotations
def backprop(args: ArgsKwargs): #-> List[Floats2d]:
def backprop(args: ArgsKwargs) -> List[Floats2d]:
# convert to xp and wrap in list
gradients = torch2xp(args.args[0])
gradients = cast(Floats2d, torch2xp(args.args[0]))
return [gradients]
return ArgsKwargs(args=(word_features,), kwargs={}), backprop
def convert_coref_clusterer_outputs(
model: Model, inputs_outputs, is_train: bool
):
model: Model, inputs_outputs, is_train: bool
) -> Tuple[Tuple[Floats2d, Ints2d], Callable]:
_, outputs = inputs_outputs
scores, indices = outputs
@@ -75,8 +72,8 @@ def convert_coref_clusterer_outputs(
kwargs={"grad_tensors": [dY_t]},
)
scores_xp = torch2xp(scores)
indices_xp = torch2xp(indices)
scores_xp = cast(Floats2d, torch2xp(scores))
indices_xp = cast(Ints2d, torch2xp(indices))
return (scores_xp, indices_xp), convert_for_torch_backward
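The two converters above are the glue Thinc runs on either side of the wrapped PyTorch module. A self-contained sketch of the same pattern, using a stand-in torch module and toy shapes rather than the exact code from this commit:

```python
from typing import List, cast

import numpy
import torch
from thinc.api import ArgsKwargs, Model, PyTorchWrapper
from thinc.types import Floats2d
from thinc.util import torch2xp, xp2torch


def convert_inputs(model: Model, X: List[Floats2d], is_train: bool):
    # Move the first doc's features to torch, as the diff does.
    word_features = xp2torch(X[0], requires_grad=is_train)

    def backprop(args: ArgsKwargs) -> List[Floats2d]:
        # Map the torch gradient back to an xp array and re-wrap in a list.
        return [cast(Floats2d, torch2xp(args.args[0]))]

    return ArgsKwargs(args=(word_features,), kwargs={}), backprop


# A stand-in torch module; the commit wraps CorefClusterer the same way.
wrapped = PyTorchWrapper(torch.nn.Linear(4, 2), convert_inputs=convert_inputs)
Y = wrapped.predict([numpy.zeros((3, 4), dtype="f")])
```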
@@ -114,9 +111,7 @@ class CorefClusterer(torch.nn.Module):
self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)
pair_emb = dim * 3 + self.pw.shape
self.a_scorer = AnaphoricityScorer(
pair_emb, hidden_size, n_layers, dropout
)
self.a_scorer = AnaphoricityScorer(pair_emb, hidden_size, n_layers, dropout)
self.lstm = torch.nn.LSTM(
input_size=dim,
hidden_size=dim,
@@ -155,10 +150,10 @@ class CorefClusterer(torch.nn.Module):
a_scores_lst: List[torch.Tensor] = []
for i in range(0, len(words), batch_size):
pw_batch = pw[i:i + batch_size]
words_batch = words[i:i + batch_size]
top_indices_batch = top_indices[i:i + batch_size]
top_rough_scores_batch = top_rough_scores[i:i + batch_size]
pw_batch = pw[i : i + batch_size]
words_batch = words[i : i + batch_size]
top_indices_batch = top_indices[i : i + batch_size]
top_rough_scores_batch = top_rough_scores[i : i + batch_size]
# a_scores_batch [batch_size, n_ants]
a_scores_batch = self.a_scorer(

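The loop above scores antecedents in fixed-stride batches; the slicing pattern in isolation (toy sizes, with `batch_size` standing in for `antecedent_batch_size`):

```python
words = list(range(5))
batch_size = 2
for i in range(0, len(words), batch_size):
    print(words[i : i + batch_size])
# [0, 1]
# [2, 3]
# [4]
```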
View File

@@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import List, Tuple, cast
from thinc.api import Model, chain, tuplify
from thinc.api import PyTorchWrapper, ArgsKwargs
@@ -35,7 +35,6 @@ def build_span_predictor(
),
convert_inputs=convert_span_predictor_inputs,
)
# TODO use proper parameter for prefix
head_info = build_get_head_metadata(prefix)
model = (tok2vec & head_info) >> span_predictor
@@ -43,15 +42,17 @@
def convert_span_predictor_inputs(
model: Model, X: Tuple[List[Floats2d], Tuple[List[Ints1d], List[Ints1d]]], is_train: bool
model: Model,
X: Tuple[List[Floats2d], Tuple[List[Ints1d], List[Ints1d]]],
is_train: bool,
):
tok2vec, (sent_ids, head_ids) = X
# Normally we should use the input is_train, but for these two it's not relevant
# TODO fix the type here, or remove it
def backprop(args: ArgsKwargs): #-> Tuple[List[Floats2d], None]:
gradients = torch2xp(args.args[1])
def backprop(args: ArgsKwargs) -> Tuple[List[Floats2d], None]:
gradients = cast(Floats2d, torch2xp(args.args[1]))
# The sent_ids and head_ids are None because no gradients
return [[gradients], None]
return ([gradients], None)
word_features = xp2torch(tok2vec[0], requires_grad=is_train)
sent_ids_tensor = xp2torch(sent_ids[0], requires_grad=False)
@@ -96,7 +97,6 @@ def predict_span_clusters(
def build_get_head_metadata(prefix):
# TODO this name is awful, fix it
model = Model(
"HeadDataProvider", attrs={"prefix": prefix}, forward=head_data_forward
)
@@ -142,7 +142,6 @@ class SpanPredictor(torch.nn.Module):
raise ValueError("max_distance has to be an even number")
# input size = single token size
# 64 = probably distance emb size
# TODO check that dist_emb_size use is correct
self.ffnn = torch.nn.Sequential(
torch.nn.Linear(input_size * 2 + dist_emb_size, hidden_size),
torch.nn.ReLU(),
@@ -159,7 +158,6 @@ class SpanPredictor(torch.nn.Module):
torch.nn.Conv1d(dist_emb_size, conv_channels, kernel_size, 1, 1),
torch.nn.Conv1d(conv_channels, 2, kernel_size, 1, 1),
)
# TODO make embeddings size a parameter
self.max_distance = max_distance
# handle distances between +-(max_distance - 2 / 2)
self.emb = torch.nn.Embedding(max_distance, dist_emb_size)
@@ -211,9 +209,7 @@ class SpanPredictor(torch.nn.Module):
dim=1,
)
lengths = same_sent.sum(dim=1)
padding_mask = torch.arange(
0, lengths.max().item(), device=device
).unsqueeze(0)
padding_mask = torch.arange(0, lengths.max().item(), device=device).unsqueeze(0)
# (n_heads x max_sent_len)
padding_mask = padding_mask < lengths.unsqueeze(1)
# (n_heads x max_sent_len x input_size * 2 + distance_emb_size)
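The padding-mask construction above can be checked on toy values:

```python
import torch

lengths = torch.tensor([2, 4, 1])  # tokens per sentence; toy values
padding_mask = torch.arange(0, lengths.max().item()).unsqueeze(0)
# (n_heads x max_sent_len)
padding_mask = padding_mask < lengths.unsqueeze(1)
print(padding_mask)
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True],
#         [ True, False, False, False]])
```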

View File

@@ -95,7 +95,7 @@ def make_coref(
class CoreferenceResolver(TrainablePipe):
"""Pipeline component for coreference resolution.
DOCS: https://spacy.io/api/coref (TODO)
DOCS: https://spacy.io/api/coref
"""
def __init__(
@@ -118,8 +118,10 @@ class CoreferenceResolver(TrainablePipe):
are stored in.
span_cluster_prefix (str): Prefix for the key in doc.spans to store the
coref clusters in.
scorer (Optional[Callable]): The scoring method. Defaults to
Scorer.score_coref_clusters.
DOCS: https://spacy.io/api/coref#init (TODO)
DOCS: https://spacy.io/api/coref#init
"""
self.vocab = vocab
self.model = model
@@ -133,11 +135,12 @@ class CoreferenceResolver(TrainablePipe):
def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
"""Apply the pipeline's model to a batch of docs, without modifying them.
Return the list of predicted clusters.
docs (Iterable[Doc]): The documents to predict.
RETURNS: The models prediction for each document.
RETURNS (List[MentionClusters]): The model's prediction for each document.
DOCS: https://spacy.io/api/coref#predict (TODO)
DOCS: https://spacy.io/api/coref#predict
"""
out = []
for doc in docs:
@@ -163,7 +166,7 @@ class CoreferenceResolver(TrainablePipe):
docs (Iterable[Doc]): The documents to modify.
clusters: The span clusters, produced by CoreferenceResolver.predict.
DOCS: https://spacy.io/api/coref#set_annotations (TODO)
DOCS: https://spacy.io/api/coref#set_annotations
"""
docs = list(docs)
if len(docs) != len(clusters_by_doc):
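For reference, once `set_annotations` has run, the clusters live in `doc.spans`. A hedged sketch of reading them back, where the key naming (`span_cluster_prefix` plus a 1-based cluster index) is assumed here for illustration:

```python
def print_coref_clusters(doc, prefix: str = "coref_clusters") -> None:
    # Keys assumed to look like "coref_clusters_1", "coref_clusters_2", ...
    for key, group in doc.spans.items():
        if key.startswith(prefix):
            print(key, [span.text for span in group])
```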
@@ -204,7 +207,7 @@ class CoreferenceResolver(TrainablePipe):
Updated using the component name as the key.
RETURNS (Dict[str, float]): The updated losses dictionary.
DOCS: https://spacy.io/api/coref#update (TODO)
DOCS: https://spacy.io/api/coref#update
"""
if losses is None:
losses = {}
@@ -225,12 +228,10 @@ class CoreferenceResolver(TrainablePipe):
predicted docs in coref training.
"""
)
# TODO check this causes no issues (in practice it runs)
preds, backprop = self.model.begin_update([eg.predicted])
score_matrix, mention_idx = preds
loss, d_scores = self.get_loss([eg], score_matrix, mention_idx)
total_loss += loss
# TODO check shape here
backprop((d_scores, mention_idx))
if sgd is not None:
@@ -239,7 +240,12 @@ class CoreferenceResolver(TrainablePipe):
return losses
def rehearse(self, examples, *, sgd=None, losses=None, **config):
raise NotImplementedError
# TODO this should be added later
raise NotImplementedError(
Errors.E931.format(
parent="CoreferenceResolver", method="add_label", name=self.name
)
)
def add_label(self, label: str) -> int:
"""Technically this method should be implemented from TrainablePipe,
@@ -264,7 +270,7 @@ class CoreferenceResolver(TrainablePipe):
scores: Scores representing the model's predictions.
RETURNS (Tuple[float, float]): The loss and the gradient.
DOCS: https://spacy.io/api/coref#get_loss (TODO)
DOCS: https://spacy.io/api/coref#get_loss
"""
ops = self.model.ops
xp = ops.xp
@@ -289,9 +295,8 @@ class CoreferenceResolver(TrainablePipe):
span_idxs = create_head_span_idxs(ops, len(example.predicted))
gscores = create_gold_scores(span_idxs, clusters)
# TODO fix type here. This is bools but asarray2f wants ints.
# Note on type here. This is bools but asarray2f wants ints.
gscores = ops.asarray2f(gscores) # type: ignore
# top_gscores = xp.take_along_axis(gscores, cidx, axis=1)
top_gscores = xp.take_along_axis(gscores, mention_idx, axis=1)
# now add the placeholder
gold_placeholder = ~top_gscores.any(axis=1).T
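For readers unfamiliar with `take_along_axis`: it selects, per mention, the gold scores of the antecedent candidates kept after rough scoring, and the placeholder then marks mentions with no kept gold antecedent. A toy numpy check:

```python
import numpy as np

gscores = np.array([[1.0, 0.0, 1.0],   # toy gold scores per candidate
                    [0.0, 0.0, 0.0]])
mention_idx = np.array([[0, 2],        # candidate indices kept per mention
                        [1, 2]])
top_gscores = np.take_along_axis(gscores, mention_idx, axis=1)
print(top_gscores)                     # [[1. 1.] [0. 0.]]
print(~top_gscores.any(axis=1).T)      # [False  True]
```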
@@ -323,7 +328,7 @@ class CoreferenceResolver(TrainablePipe):
returns a representative sample of gold-standard Example objects.
nlp (Language): The current nlp object the component is part of.
DOCS: https://spacy.io/api/coref#initialize (TODO)
DOCS: https://spacy.io/api/coref#initialize
"""
validate_get_examples(get_examples, "CoreferenceResolver.initialize")

View File

@@ -383,7 +383,7 @@ class EntityLinker(TrainablePipe):
no prediction.
docs (Iterable[Doc]): The documents to predict.
RETURNS (List[str]): The models prediction for each document.
RETURNS (List[str]): The model's prediction for each document.
DOCS: https://spacy.io/api/entitylinker#predict
"""

View File

@@ -29,7 +29,7 @@ distance_embedding_size = 64
conv_channels = 4
window_size = 1
max_distance = 128
prefix = coref_head_clusters
prefix = "coref_head_clusters"
[model.tok2vec]
@architectures = "spacy.Tok2Vec.v2"
@@ -95,6 +95,8 @@ class SpanPredictor(TrainablePipe):
"""Pipeline component to resolve one-token spans to full spans.
Used in coreference resolution.
DOCS: https://spacy.io/api/span_predictor
"""
def __init__(
@@ -119,6 +121,14 @@ class SpanPredictor(TrainablePipe):
}
def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
"""Apply the pipeline's model to a batch of docs, without modifying them.
Return the list of predicted span clusters.
docs (Iterable[Doc]): The documents to predict.
RETURNS (List[MentionClusters]): The model's prediction for each document.
DOCS: https://spacy.io/api/span_predictor#predict
"""
# for now pretend there's just one doc
out = []
@@ -151,6 +161,13 @@ class SpanPredictor(TrainablePipe):
return out
def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None:
"""Modify a batch of Doc objects, using pre-computed scores.
docs (Iterable[Doc]): The documents to modify.
clusters: The span clusters, produced by SpanPredictor.predict.
DOCS: https://spacy.io/api/span_predictor#set_annotations
"""
for doc, clusters in zip(docs, clusters_by_doc):
for ii, cluster in enumerate(clusters):
spans = [doc[mm[0] : mm[1]] for mm in cluster]
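The offsets-to-spans mapping above, as a standalone sketch (the exact `doc.spans` key naming is assumed here for illustration):

```python
import spacy

nlp = spacy.blank("en")
doc = nlp("John saw Mary . He waved .")
clusters = [[(0, 1), (4, 5)]]  # one cluster of (start, end) token offsets
for ii, cluster in enumerate(clusters):
    spans = [doc[mm[0] : mm[1]] for mm in cluster]
    doc.spans[f"coref_head_clusters_{ii + 1}"] = spans
print({key: [s.text for s in group] for key, group in doc.spans.items()})
# {'coref_head_clusters_1': ['John', 'He']}
```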
@@ -166,6 +183,15 @@ class SpanPredictor(TrainablePipe):
) -> Dict[str, float]:
"""Learn from a batch of documents and gold-standard information,
updating the pipe's model. Delegates to predict and get_loss.
examples (Iterable[Example]): A batch of Example objects.
drop (float): The dropout rate.
sgd (thinc.api.Optimizer): The optimizer.
losses (Dict[str, float]): Optional record of the loss during training.
Updated using the component name as the key.
RETURNS (Dict[str, float]): The updated losses dictionary.
DOCS: https://spacy.io/api/span_predictor#update
"""
if losses is None:
losses = {}
@@ -229,6 +255,15 @@ class SpanPredictor(TrainablePipe):
examples: Iterable[Example],
span_scores: Floats3d,
):
"""Find the loss and gradient of loss for the batch of documents and
their predicted scores.
examples (Iterable[Example]): The batch of examples.
scores: Scores representing the model's predictions.
RETURNS (Tuple[float, float]): The loss and the gradient.
DOCS: https://spacy.io/api/span_predictor#get_loss
"""
ops = self.model.ops
# NOTE This is doing fake batching, and should always get a list of one example
@@ -285,6 +320,15 @@ class SpanPredictor(TrainablePipe):
*,
nlp: Optional[Language] = None,
) -> None:
"""Initialize the pipe for training, using a representative set
of data examples.
get_examples (Callable[[], Iterable[Example]]): Function that
returns a representative sample of gold-standard Example objects.
nlp (Language): The current nlp object the component is part of.
DOCS: https://spacy.io/api/span_predictor#initialize
"""
validate_get_examples(get_examples, "SpanPredictor.initialize")
X = []

View File

@@ -587,8 +587,8 @@ consists of either two or three subnetworks:
run once for each batch.
- **lower**: Construct a feature-specific vector for each `(token, feature)`
pair. This is also run once for each batch. Constructing the state
representation is then a matter of summing the component features and
applying the non-linearity.
representation is then a matter of summing the component features and applying
the non-linearity.
- **upper** (optional): A feed-forward network that predicts scores from the
state representation. If not present, the output from the lower model is used
as action scores directly.
@@ -628,8 +628,8 @@ same signature, but the `use_upper` argument was `True` by default.
> ```
Build a tagger model, using a provided token-to-vector component. The tagger
model adds a linear layer with softmax activation to predict scores given
the token vectors.
model adds a linear layer with softmax activation to predict scores given the
token vectors.
| Name | Description |
| ----------- | ------------------------------------------------------------------------------------------ |
@@ -920,8 +920,8 @@ A function that reads an existing `KnowledgeBase` from file.
A function that takes as input a [`KnowledgeBase`](/api/kb) and a
[`Span`](/api/span) object denoting a named entity, and returns a list of
plausible [`Candidate`](/api/kb/#candidate) objects. The default
`CandidateGenerator` uses the text of a mention to find its potential
aliases in the `KnowledgeBase`. Note that this function is case-dependent.
`CandidateGenerator` uses the text of a mention to find its potential aliases in
the `KnowledgeBase`. Note that this function is case-dependent.
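Roughly, the default case-dependent lookup amounts to the following (a sketch; the registered `spacy.CandidateGenerator.v1` is the authoritative version):

```python
from typing import Iterable

from spacy.kb import Candidate, KnowledgeBase
from spacy.tokens import Span


def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]:
    # Case-dependent: "Douglas" and "douglas" are distinct aliases in the KB.
    return kb.get_alias_candidates(mention.text)
```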
## Coreference Architectures
@@ -975,7 +975,11 @@ The `Coref` model architecture is a Thinc `Model`.
> [model]
> @architectures = "spacy.SpanPredictor.v1"
> hidden_size = 1024
> dist_emb_size = 64
> distance_embedding_size = 64
> conv_channels = 4
> window_size = 1
> max_distance = 128
> prefix = "coref_head_clusters"
>
> [model.tok2vec]
> @architectures = "spacy-transformers.TransformerListener.v1"
@@ -986,13 +990,14 @@ The `SpanPredictor` model architecture is a Thinc `Model`.
The `SpanPredictor` model architecture is a Thinc `Model`.
| Name | Description |
| ------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `tok2vec` | The [`tok2vec`](#tok2vec) layer of the model. ~~Model~~ |
| `distance_embedding_size` | A representation of the distance between two candidates. ~~int~~ |
| `dropout` | The dropout to use internally. Unlike some Thinc models, this has separate dropout for the internal PyTorch layers. ~~float~~ |
| `hidden_size` | Size of the main internal layers. ~~int~~ |
| `depth` | Depth of the internal network. ~~int~~ |
| `antecedent_limit` | How many candidate antecedents to keep after rough scoring. This has a significant effect on memory usage. Typical values would be 50 to 200, or higher for very long documents. ~~int~~ |
| `antecedent_batch_size` | Internal batch size. ~~int~~ |
| **CREATES** | The model using the architecture. ~~Model[List[Doc], TupleFloats2d]~~ |
| Name | Description |
| ------------------------- | ----------------------------------------------------------------------------------------------------------------------------- |
| `tok2vec` | The [`tok2vec`](#tok2vec) layer of the model. ~~Model~~ |
| `distance_embedding_size` | A representation of the distance between two candidates. ~~int~~ |
| `dropout` | The dropout to use internally. Unlike some Thinc models, this has separate dropout for the internal PyTorch layers. ~~float~~ |
| `hidden_size` | Size of the main internal layers. ~~int~~ |
| `conv_channels` | The number of channels in the internal CNN. ~~int~~ |
| `window_size` | The number of neighboring tokens to consider in the internal CNN. `1` means consider one token on each side. ~~int~~ |
| `max_distance` | The longest possible length of a predicted span. ~~int~~ |
| `prefix` | The prefix that indicates spans to use for input data. ~~string~~ |
| **CREATES** | The model using the architecture. ~~Model[List[Doc], TupleFloats2d]~~ |