Expose more hyperparameters

This commit is contained in:
Paul O'Leary McCann 2021-06-17 21:21:46 +09:00
parent 848fd102e7
commit a62121e3b4
2 changed files with 15 additions and 3 deletions

View File

@ -17,7 +17,11 @@ def build_coref(
hidden: int = 1000, hidden: int = 1000,
dropout: float = 0.3, dropout: float = 0.3,
mention_limit: int = 3900, mention_limit: int = 3900,
#TODO this needs a better name. It limits the max mentions as a ratio of
# the token count.
mention_limit_ratio: float = 0.4,
max_span_width: int = 20, max_span_width: int = 20,
antecedent_limit: int = 50
): ):
dim = tok2vec.get_dim("nO") * 3 dim = tok2vec.get_dim("nO") * 3
@ -42,8 +46,8 @@ def build_coref(
(tok2vec & noop()) (tok2vec & noop())
>> span_embedder >> span_embedder
>> (ms & noop()) >> (ms & noop())
>> build_coarse_pruner(mention_limit) >> build_coarse_pruner(mention_limit, mention_limit_ratio)
>> build_ant_scorer(bilinear, Dropout(dropout)) >> build_ant_scorer(bilinear, Dropout(dropout), antecedent_limit)
) )
return model return model
@ -220,12 +224,14 @@ def span_embeddings_forward(
def build_coarse_pruner( def build_coarse_pruner(
mention_limit: int, mention_limit: int,
mention_limit_ratio: float,
) -> Model[SpanEmbeddings, SpanEmbeddings]: ) -> Model[SpanEmbeddings, SpanEmbeddings]:
model = Model( model = Model(
"CoarsePruner", "CoarsePruner",
forward=coarse_prune, forward=coarse_prune,
attrs={ attrs={
"mention_limit": mention_limit, "mention_limit": mention_limit,
"mention_limit_ratio": mention_limit_ratio,
}, },
) )
return model return model
@ -241,6 +247,7 @@ def coarse_prune(
rawscores, spanembeds = inputs rawscores, spanembeds = inputs
scores = rawscores.flatten() scores = rawscores.flatten()
mention_limit = model.attrs["mention_limit"] mention_limit = model.attrs["mention_limit"]
mention_limit_ratio = model.attrs["mention_limit_ratio"]
# XXX: Issue here. Don't need docs to find crossing spans, but might for the limits. # XXX: Issue here. Don't need docs to find crossing spans, but might for the limits.
# In old code the limit can be: # In old code the limit can be:
# - hard number per doc # - hard number per doc
@ -258,8 +265,11 @@ def coarse_prune(
starts = spanembeds.indices[offset:hi, 0].tolist() starts = spanembeds.indices[offset:hi, 0].tolist()
ends = spanembeds.indices[offset:hi:, 1].tolist() ends = spanembeds.indices[offset:hi:, 1].tolist()
# calculate the doc length
doclen = ends[-1] - starts[0]
mlimit = min(mention_limit, int(mention_limit_ratio * doclen))
# csel is a 1d integer list # csel is a 1d integer list
csel = select_non_crossing_spans(tops, starts, ends, mention_limit) csel = select_non_crossing_spans(tops, starts, ends, mlimit)
# add the offset so these indices are absolute # add the offset so these indices are absolute
csel = [ii + offset for ii in csel] csel = [ii + offset for ii in csel]
# this should be constant because short choices are padded # this should be constant because short choices are padded

View File

@ -29,8 +29,10 @@ default_config = """
@architectures = "spacy.Coref.v1" @architectures = "spacy.Coref.v1"
max_span_width = 20 max_span_width = 20
mention_limit = 3900 mention_limit = 3900
mention_limit_ratio = 0.4
dropout = 0.3 dropout = 0.3
hidden = 1000 hidden = 1000
antecedent_limit = 50
[model.get_mentions] [model.get_mentions]
@misc = "spacy.CorefCandidateGenerator.v1" @misc = "spacy.CorefCandidateGenerator.v1"