mirror of https://github.com/explosion/spaCy.git
Forward/backward pass works
Evaluate does not work - predict hasn't been updated
commit d22a002641 (parent c4f9c24738)

@@ -4,7 +4,7 @@ import warnings
 from thinc.api import Model, Linear, Relu, Dropout
 from thinc.api import chain, noop, Embed, add, tuplify, concatenate
 from thinc.api import reduce_first, reduce_last, reduce_mean
-from thinc.api import PyTorchWrapper
+from thinc.api import PyTorchWrapper, ArgsKwargs
 from thinc.types import Floats2d, Floats1d, Ints1d, Ints2d, Ragged
 from typing import List, Callable, Tuple, Any
 from ...tokens import Doc
@@ -455,6 +455,7 @@ def pairwise_product(bilinear, dropout, vecs: Floats2d, is_train):
 from typing import List, Tuple
 
 import torch
+from thinc.util import xp2torch, torch2xp
 
 # TODO rename this to coref_util
 from .coref_util_wl import add_dummy
@@ -475,6 +476,7 @@ def build_wl_coref_model(
     # span predictor embeddings
     sp_embedding_size: int = 64,
 ):
+    dim = tok2vec.get_dim("nO")
 
     with Model.define_operators({">>": chain}):
         # TODO chain tok2vec with these models
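For context on the added `dim = tok2vec.get_dim("nO")` line: thinc layers expose their output width through `get_dim("nO")`, which is what lets the config-built tok2vec feed its size into the torch module below. A minimal sketch, using a plain `Linear` layer with made-up sizes as a stand-in for the tok2vec pipeline:

```python
# Minimal sketch: reading a thinc layer's output width via get_dim("nO").
# The Linear layer and its sizes are stand-ins, not the spaCy tok2vec.
from thinc.api import Linear

layer = Linear(nO=96, nI=64)
dim = layer.get_dim("nO")  # -> 96, the layer's output width
```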
@@ -483,6 +485,7 @@ def build_wl_coref_model(
         coref_scorer = PyTorchWrapper(
             CorefScorer(
                 device,
+                dim,
                 embedding_size,
                 hidden_size,
                 n_hidden_layers,
@@ -513,11 +516,20 @@ def build_wl_coref_model(
 
 def convert_coref_scorer_inputs(
     model: Model,
-    X: Floats2d,
+    X: List[Floats2d],
     is_train: bool
 ):
-    word_features = xp2torch(X, requires_grad=False)
-    return ArgsKwargs(args=(word_features, ), kwargs={}), lambda dX: []
+    # The input here is List[Floats2d], one for each doc
+    # just use the first
+    # TODO real batching
+    X = X[0]
+
+    word_features = xp2torch(X, requires_grad=is_train)
+    def backprop(args: ArgsKwargs) -> List[Floats2d]:
+        # convert to xp and wrap in list
+        gradients = torch2xp(args.args[0])
+        return [gradients]
+    return ArgsKwargs(args=(word_features, ), kwargs={}), backprop
 
 
 def convert_coref_scorer_outputs(
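The rewritten conversion function follows thinc's `PyTorchWrapper` protocol: `convert_inputs` returns an `ArgsKwargs` for the torch module plus a backprop callback that maps torch gradients back to array-API arrays. A self-contained sketch of the same pattern, with a toy `torch.nn.Linear` and a 64-dim size standing in for `CorefScorer` (both assumptions, not the spaCy code itself):

```python
# Sketch of the PyTorchWrapper input-conversion pattern used above;
# the toy module, sizes, and input are assumptions for illustration.
import torch
from thinc.api import PyTorchWrapper, ArgsKwargs, NumpyOps
from thinc.util import xp2torch, torch2xp

def convert_inputs(model, X, is_train):
    # take the first doc's features, matching the "no batching yet" TODO
    word_features = xp2torch(X[0], requires_grad=is_train)

    def backprop(args: ArgsKwargs):
        # torch gradient -> xp array, re-wrapped in a list to match the input
        return [torch2xp(args.args[0])]

    return ArgsKwargs(args=(word_features,), kwargs={}), backprop

wrapped = PyTorchWrapper(torch.nn.Linear(64, 64), convert_inputs=convert_inputs)
# relies on the default convert_outputs turning the torch tensor back into xp
Y, finish = wrapped.begin_update([NumpyOps().alloc2f(5, 64)])
```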
@@ -529,7 +541,7 @@ def convert_coref_scorer_outputs(
     scores, indices = outputs
 
     def convert_for_torch_backward(dY: Floats2d) -> ArgsKwargs:
-        dY_t = xp2torch(dY)
+        dY_t = xp2torch(dY[0])
         return ArgsKwargs(
             args=([scores],),
             kwargs={"grad_tensors": [dY_t]},
@@ -633,6 +645,7 @@ class CorefScorer(torch.nn.Module):
     def __init__(
         self,
         device: str,
+        dim: int, # tok2vec size
         dist_emb_size: int,
         hidden_size: int,
         n_layers: int,
@@ -650,7 +663,8 @@ class CorefScorer(torch.nn.Module):
         """
         # device, dist_emb_size, hidden_size, n_layers, dropout_rate
         self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate).to(device)
-        bert_emb = 1024
+        #TODO clean this up
+        bert_emb = dim
         pair_emb = bert_emb * 3 + self.pw.shape
         self.a_scorer = AnaphoricityScorer(
             pair_emb,

@@ -193,6 +193,11 @@ def select_non_crossing_spans(
     # selected.append(selected[0]) # this seems a bit weird?
     return selected
 
+def create_head_span_idxs(ops, doclen: int):
+    """Helper function to create single-token span indices."""
+    aa = ops.xp.arange(0, doclen)
+    bb = ops.xp.arange(0, doclen) + 1
+    return ops.asarray2i([aa, bb]).T
 
 def get_clusters_from_doc(doc) -> List[List[Tuple[int, int]]]:
     """Given a Doc, convert the cluster spans to simple int tuple lists."""
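The new `create_head_span_idxs` helper builds one end-exclusive, single-token span per word. A quick illustration of its output, with numpy standing in for `ops.xp` (an assumption for brevity):

```python
# What create_head_span_idxs produces for a 4-token doc, using numpy
# directly instead of a thinc Ops object.
import numpy as np

doclen = 4
aa = np.arange(0, doclen)
bb = np.arange(0, doclen) + 1
print(np.asarray([aa, bb], dtype="int32").T)
# [[0 1]
#  [1 2]
#  [2 3]
#  [3 4]]
```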
@@ -201,7 +206,13 @@ def get_clusters_from_doc(doc) -> List[List[Tuple[int, int]]]:
         cluster = []
         for span in val:
             # TODO check that there isn't an off-by-one error here
-            cluster.append((span.start, span.end))
+            #cluster.append((span.start, span.end))
+            # TODO This conversion should be happening earlier in processing
+            head_i = span.root.i
+            cluster.append( (head_i, head_i + 1) )
+
+        # don't want duplicates
+        cluster = list(set(cluster))
         out.append(cluster)
     return out
 
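The change above collapses each gold span to a single-token span over its syntactic head (`span.root`), then deduplicates, so gold clusters line up with the head-span indices used for scoring. A toy version, with hypothetical `(start, end, root_i)` triples standing in for spaCy `Span` objects:

```python
# Toy version of the head conversion above; the triples are hypothetical
# stand-ins for a Span covering [start, end) with root token index root_i.
spans = [(2, 5, 3), (7, 9, 7), (6, 9, 7), (10, 12, 10)]
cluster = []
for start, end, root_i in spans:
    cluster.append((root_i, root_i + 1))
# don't want duplicates: two spans can share a head
cluster = list(set(cluster))
print(sorted(cluster))  # [(3, 4), (7, 8), (10, 11)]
```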
@@ -210,7 +221,11 @@ def create_gold_scores(
     ments: Ints2d, clusters: List[List[Tuple[int, int]]]
 ) -> List[List[bool]]:
     """Given mentions considered for antecedents and gold clusters,
-    construct a gold score matrix. This does not include the placeholder."""
+    construct a gold score matrix. This does not include the placeholder.
+
+    In the gold matrix, the value of a true antecedent is True, and otherwise
+    it is False. These will be converted to 1/0 values later.
+    """
     # make a mapping of mentions to cluster id
     # id is not important but equality will be
     ment2cid = {}

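To make the expanded docstring concrete: entry `[i][j]` of the gold matrix is True iff mention `j` is a true antecedent of mention `i`, i.e. an earlier mention of the same gold cluster. A minimal sketch of that construction (not the spaCy implementation, which works on index arrays):

```python
# Toy sketch of the gold-matrix idea: [i][j] is True iff mention j is an
# earlier mention of the same cluster as mention i. Mentions are (start, end).
from typing import Dict, List, Tuple

def toy_gold_scores(
    ments: List[Tuple[int, int]],
    clusters: List[List[Tuple[int, int]]],
) -> List[List[bool]]:
    # make a mapping of mentions to cluster id; equality is what matters
    ment2cid: Dict[Tuple[int, int], int] = {}
    for cid, cluster in enumerate(clusters):
        for ment in cluster:
            ment2cid[ment] = cid
    out = []
    for ii, ment in enumerate(ments):
        cid = ment2cid.get(ment)
        row = []
        for jj, cand in enumerate(ments):
            # only earlier mentions of the same cluster count as antecedents
            row.append(cid is not None and jj < ii and ment2cid.get(cand) == cid)
        out.append(row)
    return out

print(toy_gold_scores([(0, 1), (3, 4), (5, 6)], [[(0, 1), (5, 6)]]))
# [[False, False, False], [False, False, False], [True, False, False]]
```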
@@ -18,6 +18,7 @@ from ..vocab import Vocab
 from ..ml.models.coref_util import (
     create_gold_scores,
     MentionClusters,
+    create_head_span_idxs,
     get_clusters_from_doc,
     get_predicted_clusters,
     DEFAULT_CLUSTER_PREFIX,
@@ -26,7 +27,8 @@ from ..ml.models.coref_util import (
 
 from ..coref_scorer import Evaluator, get_cluster_info, b_cubed, muc, ceafe
 
-default_config = """
+# TODO remove this - kept for reference for now
+old_default_config = """
 [model]
 @architectures = "spacy.Coref.v1"
 max_span_width = 20
@@ -49,6 +51,35 @@ rows = [2000, 2000, 1000, 1000, 1000, 1000]
 attrs = ["ORTH", "LOWER", "PREFIX", "SUFFIX", "SHAPE", "ID"]
 include_static_vectors = false
 
+[model.tok2vec.encode]
+@architectures = "spacy.MaxoutWindowEncoder.v2"
+width = ${model.tok2vec.embed.width}
+window_size = 1
+maxout_pieces = 3
+depth = 2
+"""
+
+default_config = """
+[model]
+@architectures = "spacy.WLCoref.v1"
+embedding_size = 20
+hidden_size = 1024
+n_hidden_layers = 1
+dropout = 0.3
+rough_k = 50
+a_scoring_batch_size = 512
+sp_embedding_size = 64
+
+[model.tok2vec]
+@architectures = "spacy.Tok2Vec.v2"
+
+[model.tok2vec.embed]
+@architectures = "spacy.MultiHashEmbed.v1"
+width = 64
+rows = [2000, 2000, 1000, 1000, 1000, 1000]
+attrs = ["ORTH", "LOWER", "PREFIX", "SUFFIX", "SHAPE", "ID"]
+include_static_vectors = false
+
 [model.tok2vec.encode]
 @architectures = "spacy.MaxoutWindowEncoder.v2"
 width = ${model.tok2vec.embed.width}
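The new default config selects the `spacy.WLCoref.v1` architecture. For reference, config strings like this are parsed with thinc's `Config`; fully resolving one would additionally require the architecture to be registered, so this sketch only shows the parsing step on a trimmed-down config:

```python
# Parse-only sketch; resolving "@architectures" needs the registered
# spacy.WLCoref.v1 factory, which is what the diff above provides.
from thinc.api import Config

cfg_str = """
[model]
@architectures = "spacy.WLCoref.v1"
embedding_size = 20
hidden_size = 1024
"""
config = Config().from_str(cfg_str)
print(config["model"]["@architectures"])  # spacy.WLCoref.v1
```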
@@ -210,7 +241,9 @@ class CoreferenceResolver(TrainablePipe):
         inputs = [example.predicted for example in examples]
         preds, backprop = self.model.begin_update(inputs)
         score_matrix, mention_idx = preds
+
         loss, d_scores = self.get_loss(examples, score_matrix, mention_idx)
+        # TODO check shape here
         backprop((d_scores, mention_idx))
 
         if sgd is not None:
@@ -292,15 +325,24 @@ class CoreferenceResolver(TrainablePipe):
         offset = 0
         gradients = []
         total_loss = 0
+        #TODO change this
+        # 1. do not handle batching (add it back later)
+        # 2. don't do index conversion (no mentions, just word indices)
+        # 3. convert words to spans (if necessary) in gold and predictions
+
+        # massage score matrix to be shaped correctly
+        score_matrix = [ (score_matrix, None) ]
         for example, (cscores, cidx) in zip(examples, score_matrix):
 
             ll = cscores.shape[0]
             hi = offset + ll
 
             clusters = get_clusters_from_doc(example.reference)
-            gscores = create_gold_scores(mention_idx[offset:hi], clusters)
+            span_idxs = create_head_span_idxs(ops, len(example.predicted))
+            gscores = create_gold_scores(span_idxs, clusters)
             gscores = ops.asarray2f(gscores)
-            top_gscores = xp.take_along_axis(gscores, cidx, axis=1)
+            #top_gscores = xp.take_along_axis(gscores, cidx, axis=1)
+            top_gscores = xp.take_along_axis(gscores, mention_idx, axis=1)
             # now add the placeholder
             gold_placeholder = ~top_gscores.any(axis=1).T
             gold_placeholder = xp.expand_dims(gold_placeholder, 1)
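The placeholder logic above marks rows that have no true antecedent among the selected candidates, so the dummy (no-antecedent) column becomes the gold answer for them. A small numpy illustration, with `xp` replaced by numpy and made-up toy scores and indices:

```python
# Numpy illustration of top_gscores + gold placeholder; arrays are toy data.
import numpy as np

gscores = np.array([[0, 1, 0],
                    [0, 0, 0]], dtype=bool)   # gold antecedent matrix
idx = np.array([[1, 2],
                [0, 2]])                      # candidate indices per mention
top_gscores = np.take_along_axis(gscores, idx, axis=1)
# a row with no true antecedent among its candidates gets the dummy column
gold_placeholder = np.expand_dims(~top_gscores.any(axis=1), 1)
print(np.concatenate([gold_placeholder, top_gscores], axis=1))
# [[False  True False]
#  [ True False False]]
```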
@@ -319,6 +361,8 @@ class CoreferenceResolver(TrainablePipe):
 
             offset = hi
 
+        # Undo the wrapping
+        gradients = gradients[0][0]
         return total_loss, gradients
 
     def initialize(