diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py index fd36c84f7..86afb028a 100644 --- a/spacy/ml/models/coref.py +++ b/spacy/ml/models/coref.py @@ -17,7 +17,11 @@ def build_coref( hidden: int = 1000, dropout: float = 0.3, mention_limit: int = 3900, + # TODO: this needs a better name. It caps the number of mentions at a ratio of + # the token count. + mention_limit_ratio: float = 0.4, max_span_width: int = 20, + antecedent_limit: int = 50 ): dim = tok2vec.get_dim("nO") * 3 @@ -42,8 +46,8 @@ def build_coref( (tok2vec & noop()) >> span_embedder >> (ms & noop()) - >> build_coarse_pruner(mention_limit) - >> build_ant_scorer(bilinear, Dropout(dropout)) + >> build_coarse_pruner(mention_limit, mention_limit_ratio) + >> build_ant_scorer(bilinear, Dropout(dropout), antecedent_limit) ) return model @@ -220,12 +224,14 @@ def span_embeddings_forward( def build_coarse_pruner( mention_limit: int, + mention_limit_ratio: float, ) -> Model[SpanEmbeddings, SpanEmbeddings]: model = Model( "CoarsePruner", forward=coarse_prune, attrs={ "mention_limit": mention_limit, + "mention_limit_ratio": mention_limit_ratio, }, ) return model @@ -241,6 +247,7 @@ def coarse_prune( rawscores, spanembeds = inputs scores = rawscores.flatten() mention_limit = model.attrs["mention_limit"] + mention_limit_ratio = model.attrs["mention_limit_ratio"] # XXX: Issue here. Don't need docs to find crossing spans, but might for the limits.
# In old code the limit can be: # - hard number per doc @@ -258,8 +265,11 @@ def coarse_prune( starts = spanembeds.indices[offset:hi, 0].tolist() ends = spanembeds.indices[offset:hi:, 1].tolist() + # estimate the doc length: end of the last span minus start of the first + doclen = ends[-1] - starts[0] + mlimit = min(mention_limit, int(mention_limit_ratio * doclen)) # csel is a 1d integer list - csel = select_non_crossing_spans(tops, starts, ends, mention_limit) + csel = select_non_crossing_spans(tops, starts, ends, mlimit) # add the offset so these indices are absolute csel = [ii + offset for ii in csel] # this should be constant because short choices are padded diff --git a/spacy/pipeline/coref.py b/spacy/pipeline/coref.py index d065955c2..4caf02359 100644 --- a/spacy/pipeline/coref.py +++ b/spacy/pipeline/coref.py @@ -29,8 +29,10 @@ default_config = """ @architectures = "spacy.Coref.v1" max_span_width = 20 mention_limit = 3900 +mention_limit_ratio = 0.4 dropout = 0.3 hidden = 1000 +antecedent_limit = 50 [model.get_mentions] @misc = "spacy.CorefCandidateGenerator.v1"