cleanup components API (#5726)

* add keyword separator for update functions and drop unused "state"

* few more Example tests and various small fixes

* consistently return losses after update call

* eliminate unused tensors field across pipe components

* fix name

* fix arg name
Sofie Van Landeghem, 2020-07-09 19:43:39 +02:00 (committed by GitHub)
parent ac4297ee39
commit dd207a28be
10 changed files with 109 additions and 109 deletions
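Taken together, the changes make one consistent `update` contract across the pipeline: hyperparameters are keyword-only, the `losses` dict is created on demand, and it is always returned. A minimal sketch of the new calling convention (the `nlp`, `batch` and `optimizer` names are placeholders, not part of this diff):

    # update() now always returns the losses dict, keyed by component name,
    # whether or not a `losses` dict was passed in.
    losses = nlp.update(batch, drop=0.2, sgd=optimizer)
    # Passing drop/sgd positionally, as older scripts did, now raises an error.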

----------------------------------------------------------------------

@@ -69,6 +69,9 @@ class Warnings(object):
     W027 = ("Found a large training file of {size} bytes. Note that it may "
             "be more efficient to split your training data into multiple "
             "smaller JSON files instead.")
+    W028 = ("Doc.from_array was called with a vector of type '{type}', "
+            "but is expecting one of type 'uint64' instead. This may result "
+            "in problems with the vocab further on in the pipeline.")
     W030 = ("Some entities could not be aligned in the text \"{text}\" with "
             "entities \"{entities}\". Use "
             "`spacy.gold.biluo_tags_from_offsets(nlp.make_doc(text), entities)`"

----------------------------------------------------------------------

@@ -329,8 +329,8 @@ def _fix_legacy_dict_data(example_dict):
     for key, value in old_token_dict.items():
         if key in ("text", "ids", "brackets"):
             pass
-        elif key in remapping:
-            token_dict[remapping[key]] = value
+        elif key.lower() in remapping:
+            token_dict[remapping[key.lower()]] = value
         else:
             raise KeyError(Errors.E983.format(key=key, dict="token_annotation", keys=remapping.keys()))
     text = example_dict.get("text", example_dict.get("raw"))

----------------------------------------------------------------------

@@ -513,20 +513,23 @@ class Language(object):
     ):
         """Update the models in the pipeline.

-        examples (iterable): A batch of `Example` objects.
+        examples (Iterable[Example]): A batch of examples
         dummy: Should not be set - serves to catch backwards-incompatible scripts.
         drop (float): The dropout rate.
-        sgd (callable): An optimizer.
-        losses (dict): Dictionary to update with the loss, keyed by component.
-        component_cfg (dict): Config parameters for specific pipeline
+        sgd (Optimizer): An optimizer.
+        losses (Dict[str, float]): Dictionary to update with the loss, keyed by component.
+        component_cfg (Dict[str, Dict]): Config parameters for specific pipeline
             components, keyed by component name.
+        RETURNS (Dict[str, float]): The updated losses dictionary

         DOCS: https://spacy.io/api/language#update
         """
         if dummy is not None:
             raise ValueError(Errors.E989)
+        if losses is None:
+            losses = {}
         if len(examples) == 0:
-            return
+            return losses
         if not isinstance(examples, Iterable):
             raise TypeError(Errors.E978.format(name="language", method="update", types=type(examples)))
         wrong_types = set([type(eg) for eg in examples if not isinstance(eg, Example)])
@@ -552,6 +555,7 @@ class Language(object):
         for name, proc in self.pipeline:
             if hasattr(proc, "model"):
                 proc.model.finish_update(sgd)
+        return losses

     def rehearse(self, examples, sgd=None, losses=None, config=None):
         """Make a "rehearsal" update to the models in the pipeline, to prevent
@@ -757,18 +761,17 @@ class Language(object):
     ):
         """Process texts as a stream, and yield `Doc` objects in order.

-        texts (iterator): A sequence of texts to process.
+        texts (Iterable[str]): A sequence of texts to process.
         as_tuples (bool): If set to True, inputs should be a sequence of
             (text, context) tuples. Output will then be a sequence of
             (doc, context) tuples. Defaults to False.
         batch_size (int): The number of texts to buffer.
-        disable (list): Names of the pipeline components to disable.
+        disable (List[str]): Names of the pipeline components to disable.
         cleanup (bool): If True, unneeded strings are freed to control memory
             use. Experimental.
-        component_cfg (dict): An optional dictionary with extra keyword
+        component_cfg (Dict[str, Dict]): An optional dictionary with extra keyword
             arguments for specific components.
-        n_process (int): Number of processors to process texts, only supported
-            in Python3. If -1, set `multiprocessing.cpu_count()`.
+        n_process (int): Number of processors to process texts. If -1, set `multiprocessing.cpu_count()`.
         YIELDS (Doc): Documents in the order of the original text.

         DOCS: https://spacy.io/api/language#pipe
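A short usage sketch of the clarified `pipe` signature (the texts and the `nlp` object are illustrative):

    texts = ["First document.", "Second document."]
    # n_process=-1 would use multiprocessing.cpu_count() worker processes
    for doc in nlp.pipe(texts, batch_size=50, n_process=1):
        print(doc.text)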

----------------------------------------------------------------------

@@ -58,12 +58,8 @@ class Pipe(object):
         Both __call__ and pipe should delegate to the `predict()`
         and `set_annotations()` methods.
         """
-        predictions = self.predict([doc])
-        if isinstance(predictions, tuple) and len(predictions) == 2:
-            scores, tensors = predictions
-            self.set_annotations([doc], scores, tensors=tensors)
-        else:
-            self.set_annotations([doc], predictions)
+        scores = self.predict([doc])
+        self.set_annotations([doc], scores)
         return doc

     def pipe(self, stream, batch_size=128):
@@ -73,12 +69,8 @@ class Pipe(object):
         and `set_annotations()` methods.
         """
         for docs in util.minibatch(stream, size=batch_size):
-            predictions = self.predict(docs)
-            if isinstance(predictions, tuple) and len(tuple) == 2:
-                scores, tensors = predictions
-                self.set_annotations(docs, scores, tensors=tensors)
-            else:
-                self.set_annotations(docs, predictions)
+            scores = self.predict(docs)
+            self.set_annotations(docs, scores)
             yield from docs

     def predict(self, docs):
@@ -87,7 +79,7 @@ class Pipe(object):
         """
         raise NotImplementedError

-    def set_annotations(self, docs, scores, tensors=None):
+    def set_annotations(self, docs, scores):
         """Modify a batch of documents, using pre-computed scores."""
         raise NotImplementedError
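With `tensors` gone, the base-class contract is simply: `predict` returns scores, and `set_annotations` consumes exactly those scores. A hypothetical subclass (not part of this diff) that follows the new contract:

    class DocLengthMarker(Pipe):
        # predict() returns whatever set_annotations() expects - nothing else
        def predict(self, docs):
            return [len(doc) for doc in docs]

        def set_annotations(self, docs, scores):
            for doc, n_tokens in zip(docs, scores):
                doc.user_data["n_tokens"] = n_tokens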
@@ -281,9 +273,10 @@ class Tagger(Pipe):
                 idx += 1
         doc.is_tagged = True

-    def update(self, examples, drop=0., sgd=None, losses=None, set_annotations=False):
-        if losses is not None and self.name not in losses:
-            losses[self.name] = 0.
+    def update(self, examples, *, drop=0., sgd=None, losses=None, set_annotations=False):
+        if losses is None:
+            losses = {}
+        losses.setdefault(self.name, 0.0)

         try:
             if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
@@ -303,11 +296,11 @@ class Tagger(Pipe):
         if sgd not in (None, False):
             self.model.finish_update(sgd)

-        if losses is not None:
-            losses[self.name] += loss
+        losses[self.name] += loss
         if set_annotations:
             docs = [eg.predicted for eg in examples]
             self.set_annotations(docs, self._scores2guesses(tag_scores))
+        return losses

     def rehearse(self, examples, drop=0., sgd=None, losses=None):
         """Perform a 'rehearsal' update, where we try to match the output of
@@ -635,7 +628,7 @@ class MultitaskObjective(Tagger):
     def labels(self, value):
         self.cfg["labels"] = value

-    def set_annotations(self, docs, dep_ids, tensors=None):
+    def set_annotations(self, docs, dep_ids):
         pass

     def begin_training(self, get_examples=lambda: [], pipeline=None,
@@ -732,7 +725,7 @@ class ClozeMultitask(Pipe):
         self.cfg = cfg
         self.distance = CosineDistance(ignore_zeros=True, normalize=False)  # TODO: in config

-    def set_annotations(self, docs, dep_ids, tensors=None):
+    def set_annotations(self, docs, dep_ids):
         pass

     def begin_training(self, get_examples=lambda: [], pipeline=None,
@@ -761,7 +754,7 @@ class ClozeMultitask(Pipe):
         loss = self.distance.get_loss(prediction, target)
         return loss, gradient

-    def update(self, examples, drop=0., set_annotations=False, sgd=None, losses=None):
+    def update(self, examples, *, drop=0., set_annotations=False, sgd=None, losses=None):
         pass

     def rehearse(self, examples, drop=0., sgd=None, losses=None):
@@ -809,8 +802,8 @@ class TextCategorizer(Pipe):
     def pipe(self, stream, batch_size=128):
         for docs in util.minibatch(stream, size=batch_size):
-            scores, tensors = self.predict(docs)
-            self.set_annotations(docs, scores, tensors=tensors)
+            scores = self.predict(docs)
+            self.set_annotations(docs, scores)
             yield from docs

     def predict(self, docs):
@@ -820,22 +813,25 @@ class TextCategorizer(Pipe):
             # Handle cases where there are no tokens in any docs.
             xp = get_array_module(tensors)
             scores = xp.zeros((len(docs), len(self.labels)))
-            return scores, tensors
+            return scores
         scores = self.model.predict(docs)
         scores = self.model.ops.asarray(scores)
-        return scores, tensors
+        return scores

-    def set_annotations(self, docs, scores, tensors=None):
+    def set_annotations(self, docs, scores):
         for i, doc in enumerate(docs):
             for j, label in enumerate(self.labels):
                 doc.cats[label] = float(scores[i, j])

-    def update(self, examples, state=None, drop=0., set_annotations=False, sgd=None, losses=None):
+    def update(self, examples, *, drop=0., set_annotations=False, sgd=None, losses=None):
+        if losses is None:
+            losses = {}
+        losses.setdefault(self.name, 0.0)
         try:
             if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
                 # Handle cases where there are no tokens in any docs.
-                return
+                return losses
         except AttributeError:
             types = set([type(eg) for eg in examples])
             raise TypeError(Errors.E978.format(name="TextCategorizer", method="update", types=types))
@@ -847,12 +843,11 @@ class TextCategorizer(Pipe):
         bp_scores(d_scores)
         if sgd is not None:
             self.model.finish_update(sgd)
-        if losses is not None:
-            losses.setdefault(self.name, 0.0)
-            losses[self.name] += loss
+        losses[self.name] += loss
         if set_annotations:
             docs = [eg.predicted for eg in examples]
             self.set_annotations(docs, scores=scores)
+        return losses

     def rehearse(self, examples, drop=0., sgd=None, losses=None):
         if self._rehearsal_model is None:
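One consequence of the new pattern: even a batch with no usable tokens returns a well-formed dict instead of `None`. A sketch, assuming a pipeline with a "textcat" component and an `optimizer`:

    textcat = nlp.get_pipe("textcat")
    losses = textcat.update([], sgd=optimizer)
    assert losses == {"textcat": 0.0}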
@@ -1076,12 +1071,13 @@ class EntityLinker(Pipe):
             sgd = self.create_optimizer()
         return sgd

-    def update(self, examples, state=None, set_annotations=False, drop=0.0, sgd=None, losses=None):
+    def update(self, examples, *, set_annotations=False, drop=0.0, sgd=None, losses=None):
         self.require_kb()
-        if losses is not None:
-            losses.setdefault(self.name, 0.0)
+        if losses is None:
+            losses = {}
+        losses.setdefault(self.name, 0.0)
         if not examples:
-            return 0
+            return losses
         sentence_docs = []
         try:
             docs = [eg.predicted for eg in examples]
@@ -1124,20 +1120,19 @@ class EntityLinker(Pipe):
             return 0.0
         sentence_encodings, bp_context = self.model.begin_update(sentence_docs)
         loss, d_scores = self.get_similarity_loss(
-            scores=sentence_encodings,
+            sentence_encodings=sentence_encodings,
             examples=examples
         )
         bp_context(d_scores)
         if sgd is not None:
             self.model.finish_update(sgd)
-        if losses is not None:
-            losses[self.name] += loss
+        losses[self.name] += loss
         if set_annotations:
             self.set_annotations(docs, predictions)
-        return loss
+        return losses

-    def get_similarity_loss(self, examples, scores):
+    def get_similarity_loss(self, examples, sentence_encodings):
         entity_encodings = []
         for eg in examples:
             kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
@@ -1149,41 +1144,23 @@ class EntityLinker(Pipe):
         entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")

-        if scores.shape != entity_encodings.shape:
+        if sentence_encodings.shape != entity_encodings.shape:
             raise RuntimeError(Errors.E147.format(method="get_similarity_loss", msg="gold entities do not match up"))

-        gradients = self.distance.get_grad(scores, entity_encodings)
-        loss = self.distance.get_loss(scores, entity_encodings)
+        gradients = self.distance.get_grad(sentence_encodings, entity_encodings)
+        loss = self.distance.get_loss(sentence_encodings, entity_encodings)
         loss = loss / len(entity_encodings)
         return loss, gradients

-    def get_loss(self, examples, scores):
-        cats = []
-        for eg in examples:
-            kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
-            for ent in eg.predicted.ents:
-                kb_id = kb_ids[ent.start]
-                if kb_id:
-                    cats.append([1.0])
-
-        cats = self.model.ops.asarray(cats, dtype="float32")
-        if len(scores) != len(cats):
-            raise RuntimeError(Errors.E147.format(method="get_loss", msg="gold entities do not match up"))
-
-        d_scores = (scores - cats)
-        loss = (d_scores ** 2).sum()
-        loss = loss / len(cats)
-        return loss, d_scores
-
     def __call__(self, doc):
-        kb_ids, tensors = self.predict([doc])
-        self.set_annotations([doc], kb_ids, tensors=tensors)
+        kb_ids = self.predict([doc])
+        self.set_annotations([doc], kb_ids)
         return doc

     def pipe(self, stream, batch_size=128):
         for docs in util.minibatch(stream, size=batch_size):
-            kb_ids, tensors = self.predict(docs)
-            self.set_annotations(docs, kb_ids, tensors=tensors)
+            kb_ids = self.predict(docs)
+            self.set_annotations(docs, kb_ids)
             yield from docs

     def predict(self, docs):
@@ -1191,10 +1168,9 @@ class EntityLinker(Pipe):
         self.require_kb()
         entity_count = 0
         final_kb_ids = []
-        final_tensors = []

         if not docs:
-            return final_kb_ids, final_tensors
+            return final_kb_ids
         if isinstance(docs, Doc):
             docs = [docs]
@@ -1228,21 +1204,18 @@ class EntityLinker(Pipe):
                         if to_discard and ent.label_ in to_discard:
                             # ignoring this entity - setting to NIL
                             final_kb_ids.append(self.NIL)
-                            final_tensors.append(sentence_encoding)
                         else:
                             candidates = self.kb.get_candidates(ent.text)
                             if not candidates:
                                 # no prediction possible for this entity - setting to NIL
                                 final_kb_ids.append(self.NIL)
-                                final_tensors.append(sentence_encoding)
                             elif len(candidates) == 1:
                                 # shortcut for efficiency reasons: take the 1 candidate
                                 # TODO: thresholding
                                 final_kb_ids.append(candidates[0].entity_)
-                                final_tensors.append(sentence_encoding)
                             else:
                                 random.shuffle(candidates)
@@ -1271,14 +1244,13 @@ class EntityLinker(Pipe):
                                 best_index = scores.argmax().item()
                                 best_candidate = candidates[best_index]
                                 final_kb_ids.append(best_candidate.entity_)
-                                final_tensors.append(sentence_encoding)

-        if not (len(final_tensors) == len(final_kb_ids) == entity_count):
+        if not (len(final_kb_ids) == entity_count):
             raise RuntimeError(Errors.E147.format(method="predict", msg="result variables not of equal length"))

-        return final_kb_ids, final_tensors
+        return final_kb_ids

-    def set_annotations(self, docs, kb_ids, tensors=None):
+    def set_annotations(self, docs, kb_ids):
         count_ents = len([ent for doc in docs for ent in doc.ents])
         if count_ents != len(kb_ids):
             raise ValueError(Errors.E148.format(ents=count_ents, ids=len(kb_ids)))
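After this change the linker predicts a flat list of KB IDs, one per entity, and `set_annotations` validates that list against the entity count (raising E148 on mismatch). A sketch, assuming a pipeline with a trained "entity_linker" and its knowledge base:

    entity_linker = nlp.get_pipe("entity_linker")
    docs = list(nlp.pipe(["Douglas Adams wrote novels."]))
    kb_ids = entity_linker.predict(docs)  # list of KB ID strings, no tensors
    entity_linker.set_annotations(docs, kb_ids)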
@@ -1394,11 +1366,7 @@ class Sentencizer(Pipe):
     def pipe(self, stream, batch_size=128):
         for docs in util.minibatch(stream, size=batch_size):
             predictions = self.predict(docs)
-            if isinstance(predictions, tuple) and len(tuple) == 2:
-                scores, tensors = predictions
-                self.set_annotations(docs, scores, tensors=tensors)
-            else:
-                self.set_annotations(docs, predictions)
+            self.set_annotations(docs, predictions)
             yield from docs

     def predict(self, docs):
@@ -1429,7 +1397,7 @@ class Sentencizer(Pipe):
             guesses.append(doc_guesses)
         return guesses

-    def set_annotations(self, docs, batch_tag_ids, tensors=None):
+    def set_annotations(self, docs, batch_tag_ids):
         if isinstance(docs, Doc):
             docs = [docs]
         cdef Doc doc

----------------------------------------------------------------------

@@ -57,7 +57,7 @@ class SimpleNER(Pipe):
         scores = self.model.predict(docs)
         return scores

-    def set_annotations(self, docs: List[Doc], scores: List[Floats2d], tensors=None):
+    def set_annotations(self, docs: List[Doc], scores: List[Floats2d]):
         """Set entities on a batch of documents from a batch of scores."""
         tag_names = self.get_tag_names()
         for i, doc in enumerate(docs):
@@ -67,9 +67,12 @@ class SimpleNER(Pipe):
             tags = iob_to_biluo(tags)
             doc.ents = spans_from_biluo_tags(doc, tags)

-    def update(self, examples, set_annotations=False, drop=0.0, sgd=None, losses=None):
+    def update(self, examples, *, set_annotations=False, drop=0.0, sgd=None, losses=None):
+        if losses is None:
+            losses = {}
+        losses.setdefault("ner", 0.0)
         if not any(_has_ner(eg) for eg in examples):
-            return 0
+            return losses
         docs = [eg.predicted for eg in examples]
         set_dropout_rate(self.model, drop)
         scores, bp_scores = self.model.begin_update(docs)
@@ -79,10 +82,8 @@ class SimpleNER(Pipe):
             self.set_annotations(docs, scores)
         if sgd is not None:
             self.model.finish_update(sgd)
-        if losses is not None:
-            losses.setdefault("ner", 0.0)
-            losses["ner"] += loss
-        return loss
+        losses["ner"] += loss
+        return losses

     def get_loss(self, examples, scores):
         loss = 0
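Note that `SimpleNER` reports its loss under the hard-coded "ner" key rather than `self.name`. A sketch, with the "simple_ner" pipe name, `examples` and `optimizer` assumed from surrounding training code:

    simple_ner = nlp.get_pipe("simple_ner")
    losses = simple_ner.update(examples, sgd=optimizer)
    print(losses["ner"])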

----------------------------------------------------------------------

@@ -83,12 +83,14 @@ class Tok2Vec(Pipe):
             assert tokvecs.shape[0] == len(doc)
             doc.tensor = tokvecs

-    def update(self, examples, drop=0.0, sgd=None, losses=None, set_annotations=False):
+    def update(self, examples, *, drop=0.0, sgd=None, losses=None, set_annotations=False):
         """Update the model.

-        examples (iterable): A batch of examples
+        examples (Iterable[Example]): A batch of examples
         drop (float): The droput rate.
-        sgd (callable): An optimizer.
-        RETURNS (dict): Results from the update.
+        sgd (Optimizer): An optimizer.
+        losses (Dict[str, float]): Dictionary to update with the loss, keyed by component.
+        set_annotations (bool): whether or not to update the examples with the predictions
+        RETURNS (Dict[str, float]): The updated losses dictionary
         """
         if losses is None:
             losses = {}
@@ -124,6 +126,7 @@ class Tok2Vec(Pipe):
             self.listeners[-1].receive(batch_id, tokvecs, backprop)
         if set_annotations:
             self.set_annotations(docs, tokvecs)
+        return losses

     def get_loss(self, docs, golds, scores):
         pass
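A sketch of the documented contract, assuming a "tok2vec" pipe, a batch of `Example` objects, and an `optimizer`; with `set_annotations=True` the predicted vectors are also written to each `doc.tensor`:

    tok2vec = nlp.get_pipe("tok2vec")
    losses = tok2vec.update(examples, drop=0.1, sgd=optimizer, set_annotations=True)
    print(losses)  # loss keyed by the component's name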

----------------------------------------------------------------------

@@ -153,7 +153,7 @@ cdef class Parser:
         doc (Doc): The document to be processed.
         """
         states = self.predict([doc])
-        self.set_annotations([doc], states, tensors=None)
+        self.set_annotations([doc], states)
         return doc

     def pipe(self, docs, int batch_size=256):
@@ -170,7 +170,7 @@ cdef class Parser:
             for subbatch in util.minibatch(by_length, size=max(batch_size//4, 2)):
                 subbatch = list(subbatch)
                 parse_states = self.predict(subbatch)
-                self.set_annotations(subbatch, parse_states, tensors=None)
+                self.set_annotations(subbatch, parse_states)
             yield from batch_in_order

     def predict(self, docs):
@@ -222,7 +222,7 @@ cdef class Parser:
             unfinished.clear()
         free_activations(&activations)

-    def set_annotations(self, docs, states, tensors=None):
+    def set_annotations(self, docs, states):
         cdef StateClass state
         cdef Doc doc
         for i, (state, doc) in enumerate(zip(states, docs)):
@@ -263,7 +263,7 @@ cdef class Parser:
                 states[i].push_hist(guess)
         free(is_valid)

-    def update(self, examples, drop=0., set_annotations=False, sgd=None, losses=None):
+    def update(self, examples, *, drop=0., set_annotations=False, sgd=None, losses=None):
         cdef StateClass state
         if losses is None:
             losses = {}
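The keyword-only star applies to the Cython components as well, so training scripts that passed `drop` positionally now fail loudly instead of silently misbinding arguments. Sketch (names assumed):

    parser = nlp.get_pipe("parser")
    # parser.update(examples, 0.2, optimizer)  # TypeError under the new signature
    losses = parser.update(examples, drop=0.2, sgd=optimizer)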

----------------------------------------------------------------------

@@ -302,7 +302,7 @@ def test_multiple_predictions():
         def predict(self, docs):
             return ([1, 2, 3], [4, 5, 6])

-        def set_annotations(self, docs, scores, tensors=None):
+        def set_annotations(self, docs, scores):
             return docs

     nlp = Language()

----------------------------------------------------------------------

@@ -1,3 +1,4 @@
+import numpy
 from spacy.errors import AlignmentError
 from spacy.gold import biluo_tags_from_offsets, offsets_from_biluo_tags
 from spacy.gold import spans_from_biluo_tags, iob_to_biluo
@@ -154,6 +155,27 @@ def test_gold_biluo_misalign(en_vocab):
     assert tags == ["O", "O", "O", "-", "-", "-"]


+def test_example_constructor(en_vocab):
+    words = ["I", "like", "stuff"]
+    tags = ["NOUN", "VERB", "NOUN"]
+    tag_ids = [en_vocab.strings.add(tag) for tag in tags]
+    predicted = Doc(en_vocab, words=words)
+    reference = Doc(en_vocab, words=words)
+    reference = reference.from_array("TAG", numpy.array(tag_ids, dtype="uint64"))
+    example = Example(predicted, reference)
+    tags = example.get_aligned("TAG", as_string=True)
+    assert tags == ["NOUN", "VERB", "NOUN"]
+
+
+def test_example_from_dict_tags(en_vocab):
+    words = ["I", "like", "stuff"]
+    tags = ["NOUN", "VERB", "NOUN"]
+    predicted = Doc(en_vocab, words=words)
+    example = Example.from_dict(predicted, {"TAGS": tags})
+    tags = example.get_aligned("TAG", as_string=True)
+    assert tags == ["NOUN", "VERB", "NOUN"]
+
+
 def test_example_from_dict_no_ner(en_vocab):
     words = ["a", "b", "c", "d"]
     spaces = [True, True, False, True]

----------------------------------------------------------------------

@@ -803,7 +803,7 @@ cdef class Doc:
         attrs = [(IDS[id_.upper()] if hasattr(id_, "upper") else id_)
                  for id_ in attrs]
         if array.dtype != numpy.uint64:
-            warnings.warn(Warnings.W101.format(type=array.dtype))
+            warnings.warn(Warnings.W028.format(type=array.dtype))
         if SENT_START in attrs and HEAD in attrs:
             raise ValueError(Errors.E032)
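The renumbered warning fires whenever `Doc.from_array` receives a non-uint64 array; a sketch mirroring the `test_example_constructor` test above (`nlp` assumed):

    import numpy
    from spacy.tokens import Doc

    doc = Doc(nlp.vocab, words=["I", "like", "stuff"])
    tag_ids = [nlp.vocab.strings.add(t) for t in ["NOUN", "VERB", "NOUN"]]
    doc.from_array("TAG", numpy.array(tag_ids, dtype="uint64"))  # no warning
    # an array of dtype "int32" here would trigger W028 instead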