Fix accidentally quadratic runtime in Example.split_sents (#5464)

* Tidy up train-from-config a bit

* Fix accidentally quadratic perf in TokenAnnotation.brackets

When reading in the gold data, we had a nested loop: for each
token, we looped over all the brackets, looking for brackets
that start on that word. This is accidentally quadratic, because
we have one bracket per word (for the POS tags). So we had
O(N**2) behaviour here that ended up being pretty slow.
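
As a rough illustration (plain Python, not the actual spaCy code), the
problematic pattern looks like this: every token triggers a full scan of
the bracket list, so the total work grows with tokens * brackets.

```python
# Hypothetical sketch of the accidentally quadratic pattern: for each
# of the N tokens we scan all N brackets to find the ones that start
# on that token, giving O(N**2) work in total.
def brackets_starting_at_each_token(words, brackets):
    per_token = []
    for i in range(len(words)):
        # Full scan of `brackets` on every iteration.
        per_token.append([b for b in brackets if b[0] == i])
    return per_token
```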

To solve this, I'm indexing the brackets by their starting word
on the TokenAnnotation object, and adding a property to provide
the previous view.
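
A minimal sketch of that idea (a simplified standalone class for
illustration, not the real Cython TokenAnnotation): build a dict keyed
by the start token once, and expose the old flat list of
(start, end, label) triples through a property.

```python
# Illustrative stand-in for the indexed bracket storage; the names
# mirror the change but the class itself is hypothetical.
class BracketsByStart:
    def __init__(self, brackets):
        # Single O(N) pass, indexing brackets by their start token.
        self.brackets_by_start = {}
        for start, end, label in brackets or []:
            self.brackets_by_start.setdefault(start, []).append((end, label))

    @property
    def brackets(self):
        # Rebuilds the previous flat (start, end, label) view on demand.
        return [
            (start, end, label)
            for start, ends_labels in self.brackets_by_start.items()
            for end, label in ends_labels
        ]
```

With the index in place, finding the brackets that start on token i is a
single dict lookup (brackets_by_start.get(i, [])) instead of a scan over
all brackets.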

* Fixes
Matthew Honnibal 2020-05-20 18:48:18 +02:00 committed by GitHub
parent fda7355508
commit 609c0ba557
3 changed files with 47 additions and 21 deletions


@@ -193,10 +193,11 @@ def train_from_config(
optimizer,
train_batches,
evaluate,
training["dropout"],
training["patience"],
training["eval_frequency"],
training["accumulate_gradient"]
dropout=training["dropout"],
accumulate_gradient=training["accumulate_gradient"],
patience=training.get("patience", 0),
max_steps=training.get("max_steps", 0),
eval_frequency=training["eval_frequency"],
)
msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
@@ -214,17 +215,17 @@ def train_from_config(
progress = tqdm.tqdm(total=training["eval_frequency"], leave=False)
finally:
if output_path is not None:
with nlp.use_params(optimizer.averages):
final_model_path = output_path / "model-final"
final_model_path = output_path / "model-final"
if optimizer.averages:
with nlp.use_params(optimizer.averages):
nlp.to_disk(final_model_path)
else:
nlp.to_disk(final_model_path)
msg.good("Saved model to output directory", final_model_path)
# with msg.loading("Creating best model..."):
# best_model_path = _collate_best_model(meta, output_path, nlp.pipe_names)
# msg.good("Created best model", best_model_path)
def create_train_batches(nlp, corpus, cfg):
is_first = True
epochs_todo = cfg.get("max_epochs", 0)
while True:
train_examples = list(corpus.train_dataset(
nlp,
@@ -240,6 +241,11 @@ def create_train_batches(nlp, corpus, cfg):
batches = util.minibatch_by_words(train_examples, size=cfg["batch_size"])
for batch in batches:
yield batch
epochs_todo -= 1
# We intentionally compare exactly to 0 here, so that max_epochs < 1
# will not break.
if epochs_todo == 0:
break
def create_evaluation_callback(nlp, optimizer, corpus, cfg):
@@ -270,8 +276,8 @@ def create_evaluation_callback(nlp, optimizer, corpus, cfg):
def train_while_improving(
nlp, optimizer, train_data, evaluate, dropout, patience, eval_frequency,
accumulate_gradient
nlp, optimizer, train_data, evaluate, *, dropout, eval_frequency,
accumulate_gradient=1, patience=0, max_steps=0
):
"""Train until an evaluation stops improving. Works as a generator,
with each iteration yielding a tuple `(batch, info, is_best_checkpoint)`,
@@ -281,6 +287,7 @@ def train_while_improving(
Positional arguments:
nlp: The spaCy pipeline to evaluate.
optimizer: The optimizer callable.
train_data (Iterable[Batch]): A generator of batches, with the training
data. Each batch should be a Sized[Tuple[Input, Annot]]. The training
data iterable needs to take care of iterating over the epochs and
@@ -344,9 +351,12 @@ def train_while_improving(
yield batch, info, is_best_checkpoint
if is_best_checkpoint is not None:
losses = {}
# Stop if no improvement in `patience` updates
# Stop if no improvement in `patience` updates (if specified)
best_score, best_step = max(results)
if (step - best_step) >= patience:
if patience and (step - best_step) >= patience:
break
# Stop if we've exhausted our max steps (if specified)
if max_steps and (step * accumulate_gradient) >= max_steps:
break


@@ -53,7 +53,7 @@ cdef class TokenAnnotation:
cdef public list deps
cdef public list entities
cdef public list sent_starts
cdef public list brackets
cdef public dict brackets_by_start
cdef class DocAnnotation:


@@ -658,7 +658,18 @@ cdef class TokenAnnotation:
self.deps = deps if deps else []
self.entities = entities if entities else []
self.sent_starts = sent_starts if sent_starts else []
self.brackets = brackets if brackets else []
self.brackets_by_start = {}
if brackets:
for b_start, b_end, b_label in brackets:
self.brackets_by_start.setdefault(b_start, []).append((b_end, b_label))
@property
def brackets(self):
brackets = []
for start, ends_labels in self.brackets_by_start.items():
for end, label in ends_labels:
brackets.append((start, end, label))
return brackets
@classmethod
def from_dict(cls, token_dict):
@@ -811,8 +822,10 @@ cdef class Example:
s_lemmas, s_heads, s_deps, s_ents, s_sent_starts = [], [], [], [], []
s_brackets = []
sent_start_i = 0
t = self.token_annotation
cdef TokenAnnotation t = self.token_annotation
split_examples = []
cdef int b_start, b_end
cdef unicode b_label
for i in range(len(t.words)):
if i > 0 and t.sent_starts[i] == 1:
s_example.set_token_annotation(ids=s_ids,
@@ -836,9 +849,10 @@ cdef class Example:
s_deps.append(t.get_dep(i))
s_ents.append(t.get_entity(i))
s_sent_starts.append(t.get_sent_start(i))
s_brackets.extend((b[0] - sent_start_i,
b[1] - sent_start_i, b[2])
for b in t.brackets if b[0] == i)
for b_end, b_label in t.brackets_by_start.get(i, []):
s_brackets.append(
(i - sent_start_i, b_end - sent_start_i, b_label)
)
i += 1
s_example.set_token_annotation(ids=s_ids, words=s_words, tags=s_tags,
pos=s_pos, morphs=s_morphs, lemmas=s_lemmas, heads=s_heads,
@@ -904,8 +918,10 @@ cdef class Example:
examples = [examples]
converted_examples = []
for ex in examples:
if isinstance(ex, Example):
converted_examples.append(ex)
# convert string to Doc to Example
if isinstance(ex, str):
elif isinstance(ex, str):
if keep_raw_text:
converted_examples.append(Example(doc=ex))
else: