mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-20 18:54:21 +03:00 
			
		
		
		
	
		
			
				
	
	
		
			446 lines
		
	
	
		
			15 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			446 lines
		
	
	
		
			15 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from dataclasses import dataclass
 | |
| import warnings
 | |
| 
 | |
| from thinc.api import Model, Linear, Relu, Dropout
 | |
| from thinc.api import chain, noop, Embed, add, tuplify, concatenate
 | |
| from thinc.api import reduce_first, reduce_last, reduce_mean
 | |
| from thinc.types import Floats2d, Floats1d, Ints2d, Ragged
 | |
| from typing import List, Callable, Tuple, Any
 | |
| from ...tokens import Doc
 | |
| from ...util import registry
 | |
| from ..extract_spans import extract_spans
 | |
| 
 | |
| from .coref_util import get_candidate_mentions, select_non_crossing_spans, topk
 | |
| 
 | |
| 
 | |
| @registry.architectures("spacy.Coref.v1")
 | |
| def build_coref(
 | |
|     tok2vec: Model[List[Doc], List[Floats2d]],
 | |
|     get_mentions: Any = get_candidate_mentions,
 | |
|     hidden: int = 1000,
 | |
|     dropout: float = 0.3,
 | |
|     mention_limit: int = 3900,
 | |
|     #TODO this needs a better name. It limits the max mentions as a ratio of 
 | |
|     # the token count.
 | |
|     mention_limit_ratio: float = 0.4,
 | |
|     max_span_width: int = 20,
 | |
|     antecedent_limit: int = 50
 | |
| ):
 | |
|     dim = tok2vec.get_dim("nO") * 3
 | |
| 
 | |
|     span_embedder = build_span_embedder(get_mentions, max_span_width)
 | |
| 
 | |
|     with Model.define_operators({">>": chain, "&": tuplify, "+": add}):
 | |
| 
 | |
|         mention_scorer = (
 | |
|             Linear(nI=dim, nO=hidden)
 | |
|             >> Relu(nI=hidden, nO=hidden)
 | |
|             >> Dropout(dropout)
 | |
|             >> Linear(nI=hidden, nO=1)
 | |
|         )
 | |
|         mention_scorer.initialize()
 | |
| 
 | |
|         #TODO make feature_embed_size a param
 | |
|         feature_embed_size = 20
 | |
|         width_scorer = build_width_scorer(max_span_width, hidden, feature_embed_size)
 | |
| 
 | |
|         bilinear = Linear(nI=dim, nO=dim) >> Dropout(dropout)
 | |
|         bilinear.initialize()
 | |
| 
 | |
|         ms = (build_take_vecs() >> mention_scorer) + width_scorer
 | |
| 
 | |
|         model = (
 | |
|             (tok2vec & noop())
 | |
|             >> span_embedder
 | |
|             >> (ms & noop())
 | |
|             >> build_coarse_pruner(mention_limit, mention_limit_ratio)
 | |
|             >> build_ant_scorer(bilinear, Dropout(dropout), antecedent_limit)
 | |
|         )
 | |
|     return model
 | |
| 
 | |
| 
 | |
| @dataclass
 | |
| class SpanEmbeddings:
 | |
|     indices: Ints2d  # Array with 2 columns (for start and end index)
 | |
|     vectors: Ragged  # Ragged[Floats2d] # One vector per span
 | |
|     # NB: We assume that the indices refer to a concatenated Floats2d that
 | |
|     # has one row per token in the *batch* of documents. This makes it unambiguous
 | |
|     # which row is in which document, because if the lengths are e.g. [10, 5],
 | |
|     # a span starting at 11 must be starting at token 2 of doc 1. A bug could
 | |
|     # potentially cause you to have a span which crosses a doc boundary though,
 | |
|     # which would be bad.
 | |
|     # The lengths in the Ragged are not the tokens per doc, but the number of
 | |
|     # mentions per doc.
 | |
| 
 | |
|     def __add__(self, right):
 | |
|         out = self.vectors.data + right.vectors.data
 | |
|         return SpanEmbeddings(self.indices, Ragged(out, self.vectors.lengths))
 | |
| 
 | |
|     def __iadd__(self, right):
 | |
|         self.vectors.data += right.vectors.data
 | |
|         return self
 | |
| 
 | |
| 
 | |
| def build_width_scorer(max_span_width, hidden_size, feature_embed_size=20):
 | |
|     span_width_prior = (
 | |
|         Embed(nV=max_span_width, nO=feature_embed_size)
 | |
|         >> Linear(nI=feature_embed_size, nO=hidden_size)
 | |
|         >> Relu(nI=hidden_size, nO=hidden_size)
 | |
|         >> Dropout()
 | |
|         >> Linear(nI=hidden_size, nO=1)
 | |
|     )
 | |
|     span_width_prior.initialize()
 | |
|     model = Model(
 | |
|             "WidthScorer",
 | |
|             forward=width_score_forward,
 | |
|             layers=[span_width_prior])
 | |
|     model.set_ref("width_prior", span_width_prior)
 | |
|     return model
 | |
| 
 | |
| 
 | |
| def width_score_forward(model, embeds: SpanEmbeddings, is_train) -> Tuple[Floats1d, Callable]:
 | |
|     # calculate widths, subtracting 1 so it's 0-index
 | |
|     w_ffnn = model.get_ref("width_prior")
 | |
|     idxs = embeds.indices
 | |
|     widths = idxs[:,1] - idxs[:,0] - 1
 | |
|     wscores, width_b = w_ffnn(widths, is_train)
 | |
| 
 | |
|     lens = embeds.vectors.lengths
 | |
| 
 | |
|     def width_score_backward(d_score: Floats1d) -> SpanEmbeddings:
 | |
| 
 | |
|         dX = width_b(d_score)
 | |
|         vecs = Ragged(dX, lens)
 | |
|         return SpanEmbeddings(idxs, vecs)
 | |
| 
 | |
|     return wscores, width_score_backward
 | |
| 
 | |
| # model converting a Doc/Mention to span embeddings
 | |
| # get_mentions: Callable[Doc, Pairs[int]]
 | |
| def build_span_embedder(
 | |
|     get_mentions: Callable,
 | |
|     max_span_width: int = 20,
 | |
| ) -> Model[Tuple[List[Floats2d], List[Doc]], SpanEmbeddings]:
 | |
| 
 | |
|     with Model.define_operators({">>": chain, "|": concatenate}):
 | |
|         span_reduce = (extract_spans() >> 
 | |
|                 (reduce_first() | reduce_last() | reduce_mean()))
 | |
|     model = Model(
 | |
|         "SpanEmbedding",
 | |
|         forward=span_embeddings_forward,
 | |
|         attrs={
 | |
|             "get_mentions": get_mentions,
 | |
|             # XXX might be better to make this an implicit parameter in the
 | |
|             # mention generator
 | |
|             "max_span_width": max_span_width,
 | |
|         },
 | |
|         layers=[span_reduce],
 | |
|     )
 | |
|     model.set_ref("span_reducer", span_reduce)
 | |
|     return model
 | |
| 
 | |
| 
 | |
| def span_embeddings_forward(
 | |
|     model, inputs: Tuple[List[Floats2d], List[Doc]], is_train
 | |
| ) -> Tuple[SpanEmbeddings, Callable]:
 | |
|     ops = model.ops
 | |
|     xp = ops.xp
 | |
| 
 | |
|     tokvecs, docs = inputs
 | |
| 
 | |
|     # TODO fix this
 | |
|     dim = tokvecs[0].shape[1]
 | |
| 
 | |
|     get_mentions = model.attrs["get_mentions"]
 | |
|     max_span_width = model.attrs["max_span_width"]
 | |
|     mentions = ops.alloc2i(0, 2)
 | |
|     docmenlens = []  # number of mentions per doc
 | |
| 
 | |
|     for doc in docs:
 | |
|         starts, ends = get_mentions(doc, max_span_width)
 | |
|         docmenlens.append(len(starts))
 | |
|         cments = ops.asarray2i([starts, ends]).transpose()
 | |
| 
 | |
|         mentions = xp.concatenate( (mentions, cments) )
 | |
| 
 | |
|     # TODO support attention here
 | |
|     tokvecs = xp.concatenate(tokvecs)
 | |
|     doclens = [len(doc) for doc in docs]
 | |
|     tokvecs_r = Ragged(tokvecs, doclens)
 | |
|     mentions_r = Ragged(mentions, docmenlens)
 | |
| 
 | |
|     span_reduce = model.get_ref("span_reducer")
 | |
|     spanvecs, span_reduce_back = span_reduce( (tokvecs_r, mentions_r), is_train)
 | |
| 
 | |
|     embeds = Ragged(spanvecs, docmenlens)
 | |
| 
 | |
|     def backprop_span_embed(dY: SpanEmbeddings) -> Tuple[List[Floats2d], List[Doc]]:
 | |
|         grad, idxes = span_reduce_back(dY.vectors.data)
 | |
| 
 | |
|         oweights = []
 | |
|         offset = 0
 | |
|         for doclen in doclens:
 | |
|             hi = offset + doclen
 | |
|             oweights.append(grad.data[offset:hi])
 | |
|             offset = hi
 | |
| 
 | |
|         return oweights, docs
 | |
| 
 | |
|     return SpanEmbeddings(mentions, embeds), backprop_span_embed
 | |
| 
 | |
| 
 | |
| def build_coarse_pruner(
 | |
|     mention_limit: int,
 | |
|     mention_limit_ratio: float,
 | |
| ) -> Model[SpanEmbeddings, SpanEmbeddings]:
 | |
|     model = Model(
 | |
|         "CoarsePruner",
 | |
|         forward=coarse_prune,
 | |
|         attrs={
 | |
|             "mention_limit": mention_limit,
 | |
|             "mention_limit_ratio": mention_limit_ratio,
 | |
|         },
 | |
|     )
 | |
|     return model
 | |
| 
 | |
| 
 | |
| def coarse_prune(
 | |
|     model, inputs: Tuple[Floats1d, SpanEmbeddings], is_train
 | |
| ) -> Tuple[Tuple[Floats1d, SpanEmbeddings], Callable]:
 | |
|     """Given scores for mention, output the top non-crossing mentions.
 | |
| 
 | |
|     Mentions can contain other mentions, but candidate mentions cannot cross each other.
 | |
|     """
 | |
|     rawscores, spanembeds = inputs
 | |
|     scores = rawscores.flatten()
 | |
|     mention_limit = model.attrs["mention_limit"]
 | |
|     mention_limit_ratio = model.attrs["mention_limit_ratio"]
 | |
|     # XXX: Issue here. Don't need docs to find crossing spans, but might for the limits.
 | |
|     # In old code the limit can be:
 | |
|     # - hard number per doc
 | |
|     # - ratio of tokens in the doc
 | |
| 
 | |
|     offset = 0
 | |
|     selected = []
 | |
|     sellens = []
 | |
|     for menlen in spanembeds.vectors.lengths:
 | |
|         hi = offset + menlen
 | |
|         cscores = scores[offset:hi]
 | |
| 
 | |
|         # negate it so highest numbers come first
 | |
|         # This is relatively slow but can't be skipped.
 | |
|         tops = (model.ops.xp.argsort(-1 * cscores)).tolist()
 | |
|         starts = spanembeds.indices[offset:hi, 0].tolist()
 | |
|         ends = spanembeds.indices[offset:hi:, 1].tolist()
 | |
| 
 | |
|         # calculate the doc length
 | |
|         doclen = ends[-1] - starts[0]
 | |
|         # XXX seems to make more sense to use menlen than doclen here?
 | |
|         #mlimit = min(mention_limit, int(mention_limit_ratio * doclen))
 | |
|         mlimit = min(mention_limit, int(mention_limit_ratio * menlen))
 | |
|         # csel is a 1d integer list
 | |
|         csel = select_non_crossing_spans(tops, starts, ends, mlimit)
 | |
|         # add the offset so these indices are absolute
 | |
|         csel = [ii + offset for ii in csel]
 | |
|         # this should be constant because short choices are padded
 | |
|         sellens.append(len(csel))
 | |
|         selected += csel
 | |
|         offset += menlen
 | |
| 
 | |
|     selected = model.ops.asarray1i(selected)
 | |
|     top_spans = spanembeds.indices[selected]
 | |
|     top_vecs = spanembeds.vectors.data[selected]
 | |
| 
 | |
|     out = SpanEmbeddings(top_spans, Ragged(top_vecs, sellens))
 | |
| 
 | |
|     # save some variables so the embeds can be garbage collected
 | |
|     idxlen = spanembeds.indices.shape[0]
 | |
|     vecshape = spanembeds.vectors.data.shape
 | |
|     indices = spanembeds.indices
 | |
|     veclens = out.vectors.lengths
 | |
| 
 | |
|     def coarse_prune_backprop(
 | |
|         dY: Tuple[Floats1d, SpanEmbeddings]
 | |
|     ) -> Tuple[Floats1d, SpanEmbeddings]:
 | |
| 
 | |
|         dYscores, dYembeds = dY
 | |
| 
 | |
|         dXscores = model.ops.alloc1f(idxlen)
 | |
|         dXscores[selected] = dYscores.flatten()
 | |
| 
 | |
|         dXvecs = model.ops.alloc2f(*vecshape)
 | |
|         dXvecs[selected] = dYembeds.vectors.data
 | |
|         rout = Ragged(dXvecs, veclens)
 | |
|         dXembeds = SpanEmbeddings(indices, rout)
 | |
| 
 | |
|         # inflate for mention scorer
 | |
|         dXscores = model.ops.xp.expand_dims(dXscores, 1)
 | |
| 
 | |
|         return (dXscores, dXembeds)
 | |
| 
 | |
|     return (scores[selected], out), coarse_prune_backprop
 | |
| 
 | |
| 
 | |
| def build_take_vecs() -> Model[SpanEmbeddings, Floats2d]:
 | |
|     # this just gets vectors out of spanembeddings
 | |
|     # XXX Might be better to convert SpanEmbeddings to a tuple and use with_getitem
 | |
|     return Model("TakeVecs", forward=take_vecs_forward)
 | |
| 
 | |
| 
 | |
| def take_vecs_forward(model, inputs: SpanEmbeddings, is_train) -> Floats2d:
 | |
|     idxs = inputs.indices
 | |
|     lens = inputs.vectors.lengths
 | |
|     def backprop(dY: Floats2d) -> SpanEmbeddings:
 | |
|         vecs = Ragged(dY, lens)
 | |
|         return SpanEmbeddings(idxs, vecs)
 | |
| 
 | |
|     return inputs.vectors.data, backprop
 | |
| 
 | |
| 
 | |
| def build_ant_scorer(
 | |
|     bilinear, dropout, ant_limit=50
 | |
| ) -> Model[Tuple[Floats1d, SpanEmbeddings], List[Floats2d]]:
 | |
|     model = Model(
 | |
|         "AntScorer",
 | |
|         forward=ant_scorer_forward,
 | |
|         layers=[bilinear, dropout],
 | |
|         attrs={
 | |
|             "ant_limit": ant_limit,
 | |
|         },
 | |
|     )
 | |
|     model.set_ref("bilinear", bilinear)
 | |
|     model.set_ref("dropout", dropout)
 | |
|     return model
 | |
| 
 | |
| 
 | |
| def ant_scorer_forward(
 | |
|     model, inputs: Tuple[Floats1d, SpanEmbeddings], is_train
 | |
| ) -> Tuple[Tuple[List[Tuple[Floats2d, Ints2d]], Ints2d], Callable]:
 | |
|     ops = model.ops
 | |
|     xp = ops.xp
 | |
| 
 | |
|     ant_limit = model.attrs["ant_limit"]
 | |
|     # this contains the coarse bilinear in coref-hoi
 | |
|     # coarse bilinear is a single layer linear network
 | |
|     # TODO make these proper refs
 | |
|     bilinear = model.get_ref("bilinear")
 | |
|     dropout = model.get_ref("dropout")
 | |
| 
 | |
|     mscores, sembeds = inputs
 | |
|     vecs = sembeds.vectors  # ragged
 | |
| 
 | |
|     offset = 0
 | |
|     backprops = []
 | |
|     out = []
 | |
|     for ll in vecs.lengths:
 | |
|         hi = offset + ll
 | |
|         # each iteration is one doc
 | |
| 
 | |
|         # first calculate the pairwise product scores
 | |
|         cvecs = vecs.data[offset:hi]
 | |
|         pw_prod, prod_back = pairwise_product(bilinear, dropout, cvecs, is_train)
 | |
| 
 | |
|         # now calculate the pairwise mention scores
 | |
|         ms = mscores[offset:hi].flatten()
 | |
|         pw_sum, pw_sum_back = pairwise_sum(ops, ms)
 | |
| 
 | |
|         # make a mask so antecedents precede referrents
 | |
|         ant_range = xp.arange(0, cvecs.shape[0])
 | |
| 
 | |
|         # This will take the log of 0, which causes a warning, but we're doing
 | |
|         # it on purpose so we can just ignore the warning.
 | |
|         with warnings.catch_warnings():
 | |
|             warnings.filterwarnings('ignore', category=RuntimeWarning)
 | |
|             mask = xp.log(
 | |
|                 (xp.expand_dims(ant_range, 1) - xp.expand_dims(ant_range, 0)) >= 1
 | |
|             ).astype('f')
 | |
| 
 | |
|         scores = pw_prod + pw_sum + mask
 | |
| 
 | |
|         top_limit = min(ant_limit, len(scores))
 | |
|         top_scores, top_scores_idx = topk(xp, scores, top_limit)
 | |
|         # now add the placeholder
 | |
|         placeholder = ops.alloc2f(scores.shape[0], 1)
 | |
|         top_scores = xp.concatenate( (placeholder, top_scores), 1)
 | |
| 
 | |
|         out.append((top_scores, top_scores_idx))
 | |
| 
 | |
|         # In the full model these scores can be further refined. In the current
 | |
|         # state of this model we're done here, so this pruning is less important,
 | |
|         # but it's still helpful for reducing memory usage (since scores can be
 | |
|         # garbage collected when the loop exits).
 | |
| 
 | |
|         offset += ll
 | |
|         backprops.append((prod_back, pw_sum_back))
 | |
| 
 | |
|     # save vars for gc
 | |
|     vecshape = vecs.data.shape
 | |
|     veclens = vecs.lengths
 | |
|     scoreshape = mscores.shape
 | |
|     idxes = sembeds.indices
 | |
| 
 | |
|     def backprop(
 | |
|         dYs: Tuple[List[Tuple[Floats2d, Ints2d]], Ints2d]
 | |
|     ) -> Tuple[Floats2d, SpanEmbeddings]:
 | |
|         dYscores, dYembeds = dYs
 | |
|         dXembeds = Ragged(ops.alloc2f(*vecshape), veclens)
 | |
|         dXscores = ops.alloc1f(*scoreshape)
 | |
| 
 | |
|         offset = 0
 | |
|         for dy, (prod_back, pw_sum_back), ll in zip(dYscores, backprops, veclens):
 | |
|             hi = offset + ll
 | |
|             dyscore, dyidx = dy
 | |
|             # remove the placeholder
 | |
|             dyscore = dyscore[:, 1:]
 | |
|             # the full score grid is square
 | |
| 
 | |
|             fullscore = ops.alloc2f(ll, ll)
 | |
|             for ii, (ridx, rscores) in enumerate(zip(dyidx, dyscore)):
 | |
|                 fullscore[ii][ridx] = rscores
 | |
| 
 | |
|             dXembeds.data[offset : hi] = prod_back(fullscore)
 | |
|             dXscores[offset : hi] = pw_sum_back(fullscore)
 | |
| 
 | |
|             offset = hi
 | |
|         # make it fit back into the linear
 | |
|         dXscores = xp.expand_dims(dXscores, 1)
 | |
|         return (dXscores, SpanEmbeddings(idxes, dXembeds))
 | |
| 
 | |
|     return (out, sembeds.indices), backprop
 | |
| 
 | |
| 
 | |
| def pairwise_sum(ops, mention_scores: Floats1d) -> Tuple[Floats2d, Callable]:
 | |
|     """Find the most likely mention-antecedent pairs."""
 | |
|     # This doesn't use multiplication because two items with low mention scores
 | |
|     # don't make a good candidate pair.
 | |
| 
 | |
|     pw_sum = ops.xp.expand_dims(mention_scores, 1) + ops.xp.expand_dims(
 | |
|         mention_scores, 0
 | |
|     )
 | |
| 
 | |
|     def backward(d_pwsum: Floats2d) -> Floats1d:
 | |
|         # For the backward pass, the gradient is distributed over the whole row and
 | |
|         # column, so pull it all in.
 | |
|         
 | |
|         out = d_pwsum.sum(axis=0) + d_pwsum.sum(axis=1)
 | |
| 
 | |
|         return out
 | |
| 
 | |
|     return pw_sum, backward
 | |
| 
 | |
| 
 | |
| def pairwise_product(bilinear, dropout, vecs: Floats2d, is_train):
 | |
|     # A neat side effect of this is that we don't have to pass the backprops
 | |
|     # around separately because the closure handles them.
 | |
|     source, source_b = bilinear(vecs, is_train)
 | |
|     target, target_b = dropout(vecs.T, is_train)
 | |
|     pw_prod = source @ target
 | |
| 
 | |
|     def backward(d_prod: Floats2d) -> Floats2d:
 | |
|         dS = source_b(d_prod @ target.T)
 | |
|         dT = target_b(source.T @ d_prod)
 | |
|         dX = dS + dT.T
 | |
|         return dX
 | |
| 
 | |
|     return pw_prod, backward
 |