mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-25 00:34:20 +03:00
Fix accidentally quadratic runtime in Example.split_sents (#5464)
* Tidy up train-from-config a bit * Fix accidentally quadratic perf in TokenAnnotation.brackets When we're reading in the gold data, we had a nested loop where we looped over the brackets for each token, looking for brackets that start on that word. This is accidentally quadratic, because we have one bracket per word (for the POS tags). So we had an O(N**2) behaviour here that ended up being pretty slow. To solve this I'm indexing the brackets by their starting word on the TokenAnnotations object, and having a property to provide the previous view. * Fixes
This commit is contained in:
parent
fda7355508
commit
609c0ba557
|
@ -193,10 +193,11 @@ def train_from_config(
|
|||
optimizer,
|
||||
train_batches,
|
||||
evaluate,
|
||||
training["dropout"],
|
||||
training["patience"],
|
||||
training["eval_frequency"],
|
||||
training["accumulate_gradient"]
|
||||
dropout=training["dropout"],
|
||||
accumulate_gradient=training["accumulate_gradient"],
|
||||
patience=training.get("patience", 0),
|
||||
max_steps=training.get("max_steps", 0),
|
||||
eval_frequency=training["eval_frequency"],
|
||||
)
|
||||
|
||||
msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
|
||||
|
@ -214,17 +215,17 @@ def train_from_config(
|
|||
progress = tqdm.tqdm(total=training["eval_frequency"], leave=False)
|
||||
finally:
|
||||
if output_path is not None:
|
||||
with nlp.use_params(optimizer.averages):
|
||||
final_model_path = output_path / "model-final"
|
||||
final_model_path = output_path / "model-final"
|
||||
if optimizer.averages:
|
||||
with nlp.use_params(optimizer.averages):
|
||||
nlp.to_disk(final_model_path)
|
||||
else:
|
||||
nlp.to_disk(final_model_path)
|
||||
msg.good("Saved model to output directory", final_model_path)
|
||||
# with msg.loading("Creating best model..."):
|
||||
# best_model_path = _collate_best_model(meta, output_path, nlp.pipe_names)
|
||||
# msg.good("Created best model", best_model_path)
|
||||
|
||||
|
||||
def create_train_batches(nlp, corpus, cfg):
|
||||
is_first = True
|
||||
epochs_todo = cfg.get("max_epochs", 0)
|
||||
while True:
|
||||
train_examples = list(corpus.train_dataset(
|
||||
nlp,
|
||||
|
@ -240,6 +241,11 @@ def create_train_batches(nlp, corpus, cfg):
|
|||
batches = util.minibatch_by_words(train_examples, size=cfg["batch_size"])
|
||||
for batch in batches:
|
||||
yield batch
|
||||
epochs_todo -= 1
|
||||
# We intentionally compare exactly to 0 here, so that max_epochs < 1
|
||||
# will not break.
|
||||
if epochs_todo == 0:
|
||||
break
|
||||
|
||||
|
||||
def create_evaluation_callback(nlp, optimizer, corpus, cfg):
|
||||
|
@ -270,8 +276,8 @@ def create_evaluation_callback(nlp, optimizer, corpus, cfg):
|
|||
|
||||
|
||||
def train_while_improving(
|
||||
nlp, optimizer, train_data, evaluate, dropout, patience, eval_frequency,
|
||||
accumulate_gradient
|
||||
nlp, optimizer, train_data, evaluate, *, dropout, eval_frequency,
|
||||
accumulate_gradient=1, patience=0, max_steps=0
|
||||
):
|
||||
"""Train until an evaluation stops improving. Works as a generator,
|
||||
with each iteration yielding a tuple `(batch, info, is_best_checkpoint)`,
|
||||
|
@ -281,6 +287,7 @@ def train_while_improving(
|
|||
|
||||
Positional arguments:
|
||||
nlp: The spaCy pipeline to evaluate.
|
||||
optimizer: The optimizer callable.
|
||||
train_data (Iterable[Batch]): A generator of batches, with the training
|
||||
data. Each batch should be a Sized[Tuple[Input, Annot]]. The training
|
||||
data iterable needs to take care of iterating over the epochs and
|
||||
|
@ -344,9 +351,12 @@ def train_while_improving(
|
|||
yield batch, info, is_best_checkpoint
|
||||
if is_best_checkpoint is not None:
|
||||
losses = {}
|
||||
# Stop if no improvement in `patience` updates
|
||||
# Stop if no improvement in `patience` updates (if specified)
|
||||
best_score, best_step = max(results)
|
||||
if (step - best_step) >= patience:
|
||||
if patience and (step - best_step) >= patience:
|
||||
break
|
||||
# Stop if we've exhausted our max steps (if specified)
|
||||
if max_steps and (step * accumulate_gradient) >= max_steps:
|
||||
break
|
||||
|
||||
|
||||
|
|
|
@ -53,7 +53,7 @@ cdef class TokenAnnotation:
|
|||
cdef public list deps
|
||||
cdef public list entities
|
||||
cdef public list sent_starts
|
||||
cdef public list brackets
|
||||
cdef public dict brackets_by_start
|
||||
|
||||
|
||||
cdef class DocAnnotation:
|
||||
|
|
|
@ -658,7 +658,18 @@ cdef class TokenAnnotation:
|
|||
self.deps = deps if deps else []
|
||||
self.entities = entities if entities else []
|
||||
self.sent_starts = sent_starts if sent_starts else []
|
||||
self.brackets = brackets if brackets else []
|
||||
self.brackets_by_start = {}
|
||||
if brackets:
|
||||
for b_start, b_end, b_label in brackets:
|
||||
self.brackets_by_start.setdefault(b_start, []).append((b_end, b_label))
|
||||
|
||||
@property
|
||||
def brackets(self):
|
||||
brackets = []
|
||||
for start, ends_labels in self.brackets_by_start.items():
|
||||
for end, label in ends_labels:
|
||||
brackets.append((start, end, label))
|
||||
return brackets
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, token_dict):
|
||||
|
@ -811,8 +822,10 @@ cdef class Example:
|
|||
s_lemmas, s_heads, s_deps, s_ents, s_sent_starts = [], [], [], [], []
|
||||
s_brackets = []
|
||||
sent_start_i = 0
|
||||
t = self.token_annotation
|
||||
cdef TokenAnnotation t = self.token_annotation
|
||||
split_examples = []
|
||||
cdef int b_start, b_end
|
||||
cdef unicode b_label
|
||||
for i in range(len(t.words)):
|
||||
if i > 0 and t.sent_starts[i] == 1:
|
||||
s_example.set_token_annotation(ids=s_ids,
|
||||
|
@ -836,9 +849,10 @@ cdef class Example:
|
|||
s_deps.append(t.get_dep(i))
|
||||
s_ents.append(t.get_entity(i))
|
||||
s_sent_starts.append(t.get_sent_start(i))
|
||||
s_brackets.extend((b[0] - sent_start_i,
|
||||
b[1] - sent_start_i, b[2])
|
||||
for b in t.brackets if b[0] == i)
|
||||
for b_end, b_label in t.brackets_by_start.get(i, []):
|
||||
s_brackets.append(
|
||||
(i - sent_start_i, b_end - sent_start_i, b_label)
|
||||
)
|
||||
i += 1
|
||||
s_example.set_token_annotation(ids=s_ids, words=s_words, tags=s_tags,
|
||||
pos=s_pos, morphs=s_morphs, lemmas=s_lemmas, heads=s_heads,
|
||||
|
@ -904,8 +918,10 @@ cdef class Example:
|
|||
examples = [examples]
|
||||
converted_examples = []
|
||||
for ex in examples:
|
||||
if isinstance(ex, Example):
|
||||
converted_examples.append(ex)
|
||||
# convert string to Doc to Example
|
||||
if isinstance(ex, str):
|
||||
elif isinstance(ex, str):
|
||||
if keep_raw_text:
|
||||
converted_examples.append(Example(doc=ex))
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue
Block a user