Expose more hyperparameters

This commit is contained in:
Paul O'Leary McCann 2021-06-17 21:21:46 +09:00
parent 848fd102e7
commit a62121e3b4
2 changed files with 15 additions and 3 deletions

View File

@ -17,7 +17,11 @@ def build_coref(
hidden: int = 1000,
dropout: float = 0.3,
mention_limit: int = 3900,
# TODO: this needs a better name. It caps the maximum number of mentions at a
# ratio of the token count.
mention_limit_ratio: float = 0.4,
max_span_width: int = 20,
antecedent_limit: int = 50
):
dim = tok2vec.get_dim("nO") * 3
@ -42,8 +46,8 @@ def build_coref(
(tok2vec & noop())
>> span_embedder
>> (ms & noop())
>> build_coarse_pruner(mention_limit)
>> build_ant_scorer(bilinear, Dropout(dropout))
>> build_coarse_pruner(mention_limit, mention_limit_ratio)
>> build_ant_scorer(bilinear, Dropout(dropout), antecedent_limit)
)
return model
@ -220,12 +224,14 @@ def span_embeddings_forward(
def build_coarse_pruner(
    mention_limit: int,
    mention_limit_ratio: float,
) -> Model[SpanEmbeddings, SpanEmbeddings]:
    """Create the "CoarsePruner" model, whose forward function is `coarse_prune`.

    Both limits are stored in the model's attrs so the forward function can
    read them back at call time.

    mention_limit (int): Stored under the "mention_limit" attr.
    mention_limit_ratio (float): Stored under the "mention_limit_ratio" attr.
    RETURNS (Model[SpanEmbeddings, SpanEmbeddings]): The configured model.
    """
    # Keys must match the names the forward function looks up in model.attrs.
    pruner_attrs = {
        "mention_limit": mention_limit,
        "mention_limit_ratio": mention_limit_ratio,
    }
    return Model("CoarsePruner", forward=coarse_prune, attrs=pruner_attrs)
@ -241,6 +247,7 @@ def coarse_prune(
rawscores, spanembeds = inputs
scores = rawscores.flatten()
mention_limit = model.attrs["mention_limit"]
mention_limit_ratio = model.attrs["mention_limit_ratio"]
# XXX: Issue here. Don't need docs to find crossing spans, but might for the limits.
# In old code the limit can be:
# - hard number per doc
@ -258,8 +265,11 @@ def coarse_prune(
starts = spanembeds.indices[offset:hi, 0].tolist()
ends = spanembeds.indices[offset:hi:, 1].tolist()
# calculate the doc length
doclen = ends[-1] - starts[0]
mlimit = min(mention_limit, int(mention_limit_ratio * doclen))
# csel is a 1d integer list
csel = select_non_crossing_spans(tops, starts, ends, mention_limit)
csel = select_non_crossing_spans(tops, starts, ends, mlimit)
# add the offset so these indices are absolute
csel = [ii + offset for ii in csel]
# this should be constant because short choices are padded

View File

@ -29,8 +29,10 @@ default_config = """
@architectures = "spacy.Coref.v1"
max_span_width = 20
mention_limit = 3900
mention_limit_ratio = 0.4
dropout = 0.3
hidden = 1000
antecedent_limit = 50
[model.get_mentions]
@misc = "spacy.CorefCandidateGenerator.v1"