mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-23 15:54:13 +03:00
cleanup components API (#5726)
* add keyword separator for update functions and drop unused "state" * few more Example tests and various small fixes * consistently return losses after update call * eliminate unused tensors field across pipe components * fix name * fix arg name
This commit is contained in:
parent
ac4297ee39
commit
dd207a28be
|
@ -69,6 +69,9 @@ class Warnings(object):
|
|||
W027 = ("Found a large training file of {size} bytes. Note that it may "
|
||||
"be more efficient to split your training data into multiple "
|
||||
"smaller JSON files instead.")
|
||||
W028 = ("Doc.from_array was called with a vector of type '{type}', "
|
||||
"but is expecting one of type 'uint64' instead. This may result "
|
||||
"in problems with the vocab further on in the pipeline.")
|
||||
W030 = ("Some entities could not be aligned in the text \"{text}\" with "
|
||||
"entities \"{entities}\". Use "
|
||||
"`spacy.gold.biluo_tags_from_offsets(nlp.make_doc(text), entities)`"
|
||||
|
|
|
@ -329,8 +329,8 @@ def _fix_legacy_dict_data(example_dict):
|
|||
for key, value in old_token_dict.items():
|
||||
if key in ("text", "ids", "brackets"):
|
||||
pass
|
||||
elif key in remapping:
|
||||
token_dict[remapping[key]] = value
|
||||
elif key.lower() in remapping:
|
||||
token_dict[remapping[key.lower()]] = value
|
||||
else:
|
||||
raise KeyError(Errors.E983.format(key=key, dict="token_annotation", keys=remapping.keys()))
|
||||
text = example_dict.get("text", example_dict.get("raw"))
|
||||
|
|
|
@ -513,20 +513,23 @@ class Language(object):
|
|||
):
|
||||
"""Update the models in the pipeline.
|
||||
|
||||
examples (iterable): A batch of `Example` objects.
|
||||
examples (Iterable[Example]): A batch of examples
|
||||
dummy: Should not be set - serves to catch backwards-incompatible scripts.
|
||||
drop (float): The dropout rate.
|
||||
sgd (callable): An optimizer.
|
||||
losses (dict): Dictionary to update with the loss, keyed by component.
|
||||
component_cfg (dict): Config parameters for specific pipeline
|
||||
sgd (Optimizer): An optimizer.
|
||||
losses (Dict[str, float]): Dictionary to update with the loss, keyed by component.
|
||||
component_cfg (Dict[str, Dict]): Config parameters for specific pipeline
|
||||
components, keyed by component name.
|
||||
RETURNS (Dict[str, float]): The updated losses dictionary
|
||||
|
||||
DOCS: https://spacy.io/api/language#update
|
||||
"""
|
||||
if dummy is not None:
|
||||
raise ValueError(Errors.E989)
|
||||
if losses is None:
|
||||
losses = {}
|
||||
if len(examples) == 0:
|
||||
return
|
||||
return losses
|
||||
if not isinstance(examples, Iterable):
|
||||
raise TypeError(Errors.E978.format(name="language", method="update", types=type(examples)))
|
||||
wrong_types = set([type(eg) for eg in examples if not isinstance(eg, Example)])
|
||||
|
@ -552,6 +555,7 @@ class Language(object):
|
|||
for name, proc in self.pipeline:
|
||||
if hasattr(proc, "model"):
|
||||
proc.model.finish_update(sgd)
|
||||
return losses
|
||||
|
||||
def rehearse(self, examples, sgd=None, losses=None, config=None):
|
||||
"""Make a "rehearsal" update to the models in the pipeline, to prevent
|
||||
|
@ -757,18 +761,17 @@ class Language(object):
|
|||
):
|
||||
"""Process texts as a stream, and yield `Doc` objects in order.
|
||||
|
||||
texts (iterator): A sequence of texts to process.
|
||||
texts (Iterable[str]): A sequence of texts to process.
|
||||
as_tuples (bool): If set to True, inputs should be a sequence of
|
||||
(text, context) tuples. Output will then be a sequence of
|
||||
(doc, context) tuples. Defaults to False.
|
||||
batch_size (int): The number of texts to buffer.
|
||||
disable (list): Names of the pipeline components to disable.
|
||||
disable (List[str]): Names of the pipeline components to disable.
|
||||
cleanup (bool): If True, unneeded strings are freed to control memory
|
||||
use. Experimental.
|
||||
component_cfg (dict): An optional dictionary with extra keyword
|
||||
component_cfg (Dict[str, Dict]): An optional dictionary with extra keyword
|
||||
arguments for specific components.
|
||||
n_process (int): Number of processors to process texts, only supported
|
||||
in Python3. If -1, set `multiprocessing.cpu_count()`.
|
||||
n_process (int): Number of processors to process texts. If -1, set `multiprocessing.cpu_count()`.
|
||||
YIELDS (Doc): Documents in the order of the original text.
|
||||
|
||||
DOCS: https://spacy.io/api/language#pipe
|
||||
|
|
|
@ -58,12 +58,8 @@ class Pipe(object):
|
|||
Both __call__ and pipe should delegate to the `predict()`
|
||||
and `set_annotations()` methods.
|
||||
"""
|
||||
predictions = self.predict([doc])
|
||||
if isinstance(predictions, tuple) and len(predictions) == 2:
|
||||
scores, tensors = predictions
|
||||
self.set_annotations([doc], scores, tensors=tensors)
|
||||
else:
|
||||
self.set_annotations([doc], predictions)
|
||||
scores = self.predict([doc])
|
||||
self.set_annotations([doc], scores)
|
||||
return doc
|
||||
|
||||
def pipe(self, stream, batch_size=128):
|
||||
|
@ -73,12 +69,8 @@ class Pipe(object):
|
|||
and `set_annotations()` methods.
|
||||
"""
|
||||
for docs in util.minibatch(stream, size=batch_size):
|
||||
predictions = self.predict(docs)
|
||||
if isinstance(predictions, tuple) and len(tuple) == 2:
|
||||
scores, tensors = predictions
|
||||
self.set_annotations(docs, scores, tensors=tensors)
|
||||
else:
|
||||
self.set_annotations(docs, predictions)
|
||||
scores = self.predict(docs)
|
||||
self.set_annotations(docs, scores)
|
||||
yield from docs
|
||||
|
||||
def predict(self, docs):
|
||||
|
@ -87,7 +79,7 @@ class Pipe(object):
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def set_annotations(self, docs, scores, tensors=None):
|
||||
def set_annotations(self, docs, scores):
|
||||
"""Modify a batch of documents, using pre-computed scores."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -281,9 +273,10 @@ class Tagger(Pipe):
|
|||
idx += 1
|
||||
doc.is_tagged = True
|
||||
|
||||
def update(self, examples, drop=0., sgd=None, losses=None, set_annotations=False):
|
||||
if losses is not None and self.name not in losses:
|
||||
losses[self.name] = 0.
|
||||
def update(self, examples, *, drop=0., sgd=None, losses=None, set_annotations=False):
|
||||
if losses is None:
|
||||
losses = {}
|
||||
losses.setdefault(self.name, 0.0)
|
||||
|
||||
try:
|
||||
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
|
||||
|
@ -303,11 +296,11 @@ class Tagger(Pipe):
|
|||
if sgd not in (None, False):
|
||||
self.model.finish_update(sgd)
|
||||
|
||||
if losses is not None:
|
||||
losses[self.name] += loss
|
||||
losses[self.name] += loss
|
||||
if set_annotations:
|
||||
docs = [eg.predicted for eg in examples]
|
||||
self.set_annotations(docs, self._scores2guesses(tag_scores))
|
||||
return losses
|
||||
|
||||
def rehearse(self, examples, drop=0., sgd=None, losses=None):
|
||||
"""Perform a 'rehearsal' update, where we try to match the output of
|
||||
|
@ -635,7 +628,7 @@ class MultitaskObjective(Tagger):
|
|||
def labels(self, value):
|
||||
self.cfg["labels"] = value
|
||||
|
||||
def set_annotations(self, docs, dep_ids, tensors=None):
|
||||
def set_annotations(self, docs, dep_ids):
|
||||
pass
|
||||
|
||||
def begin_training(self, get_examples=lambda: [], pipeline=None,
|
||||
|
@ -732,7 +725,7 @@ class ClozeMultitask(Pipe):
|
|||
self.cfg = cfg
|
||||
self.distance = CosineDistance(ignore_zeros=True, normalize=False) # TODO: in config
|
||||
|
||||
def set_annotations(self, docs, dep_ids, tensors=None):
|
||||
def set_annotations(self, docs, dep_ids):
|
||||
pass
|
||||
|
||||
def begin_training(self, get_examples=lambda: [], pipeline=None,
|
||||
|
@ -761,7 +754,7 @@ class ClozeMultitask(Pipe):
|
|||
loss = self.distance.get_loss(prediction, target)
|
||||
return loss, gradient
|
||||
|
||||
def update(self, examples, drop=0., set_annotations=False, sgd=None, losses=None):
|
||||
def update(self, examples, *, drop=0., set_annotations=False, sgd=None, losses=None):
|
||||
pass
|
||||
|
||||
def rehearse(self, examples, drop=0., sgd=None, losses=None):
|
||||
|
@ -809,8 +802,8 @@ class TextCategorizer(Pipe):
|
|||
|
||||
def pipe(self, stream, batch_size=128):
|
||||
for docs in util.minibatch(stream, size=batch_size):
|
||||
scores, tensors = self.predict(docs)
|
||||
self.set_annotations(docs, scores, tensors=tensors)
|
||||
scores = self.predict(docs)
|
||||
self.set_annotations(docs, scores)
|
||||
yield from docs
|
||||
|
||||
def predict(self, docs):
|
||||
|
@ -820,22 +813,25 @@ class TextCategorizer(Pipe):
|
|||
# Handle cases where there are no tokens in any docs.
|
||||
xp = get_array_module(tensors)
|
||||
scores = xp.zeros((len(docs), len(self.labels)))
|
||||
return scores, tensors
|
||||
return scores
|
||||
|
||||
scores = self.model.predict(docs)
|
||||
scores = self.model.ops.asarray(scores)
|
||||
return scores, tensors
|
||||
return scores
|
||||
|
||||
def set_annotations(self, docs, scores, tensors=None):
|
||||
def set_annotations(self, docs, scores):
|
||||
for i, doc in enumerate(docs):
|
||||
for j, label in enumerate(self.labels):
|
||||
doc.cats[label] = float(scores[i, j])
|
||||
|
||||
def update(self, examples, state=None, drop=0., set_annotations=False, sgd=None, losses=None):
|
||||
def update(self, examples, *, drop=0., set_annotations=False, sgd=None, losses=None):
|
||||
if losses is None:
|
||||
losses = {}
|
||||
losses.setdefault(self.name, 0.0)
|
||||
try:
|
||||
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
|
||||
# Handle cases where there are no tokens in any docs.
|
||||
return
|
||||
return losses
|
||||
except AttributeError:
|
||||
types = set([type(eg) for eg in examples])
|
||||
raise TypeError(Errors.E978.format(name="TextCategorizer", method="update", types=types))
|
||||
|
@ -847,12 +843,11 @@ class TextCategorizer(Pipe):
|
|||
bp_scores(d_scores)
|
||||
if sgd is not None:
|
||||
self.model.finish_update(sgd)
|
||||
if losses is not None:
|
||||
losses.setdefault(self.name, 0.0)
|
||||
losses[self.name] += loss
|
||||
losses[self.name] += loss
|
||||
if set_annotations:
|
||||
docs = [eg.predicted for eg in examples]
|
||||
self.set_annotations(docs, scores=scores)
|
||||
return losses
|
||||
|
||||
def rehearse(self, examples, drop=0., sgd=None, losses=None):
|
||||
if self._rehearsal_model is None:
|
||||
|
@ -1076,12 +1071,13 @@ class EntityLinker(Pipe):
|
|||
sgd = self.create_optimizer()
|
||||
return sgd
|
||||
|
||||
def update(self, examples, state=None, set_annotations=False, drop=0.0, sgd=None, losses=None):
|
||||
def update(self, examples, *, set_annotations=False, drop=0.0, sgd=None, losses=None):
|
||||
self.require_kb()
|
||||
if losses is not None:
|
||||
losses.setdefault(self.name, 0.0)
|
||||
if losses is None:
|
||||
losses = {}
|
||||
losses.setdefault(self.name, 0.0)
|
||||
if not examples:
|
||||
return 0
|
||||
return losses
|
||||
sentence_docs = []
|
||||
try:
|
||||
docs = [eg.predicted for eg in examples]
|
||||
|
@ -1124,20 +1120,19 @@ class EntityLinker(Pipe):
|
|||
return 0.0
|
||||
sentence_encodings, bp_context = self.model.begin_update(sentence_docs)
|
||||
loss, d_scores = self.get_similarity_loss(
|
||||
scores=sentence_encodings,
|
||||
sentence_encodings=sentence_encodings,
|
||||
examples=examples
|
||||
)
|
||||
bp_context(d_scores)
|
||||
if sgd is not None:
|
||||
self.model.finish_update(sgd)
|
||||
|
||||
if losses is not None:
|
||||
losses[self.name] += loss
|
||||
losses[self.name] += loss
|
||||
if set_annotations:
|
||||
self.set_annotations(docs, predictions)
|
||||
return loss
|
||||
return losses
|
||||
|
||||
def get_similarity_loss(self, examples, scores):
|
||||
def get_similarity_loss(self, examples, sentence_encodings):
|
||||
entity_encodings = []
|
||||
for eg in examples:
|
||||
kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
|
||||
|
@ -1149,41 +1144,23 @@ class EntityLinker(Pipe):
|
|||
|
||||
entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
|
||||
|
||||
if scores.shape != entity_encodings.shape:
|
||||
if sentence_encodings.shape != entity_encodings.shape:
|
||||
raise RuntimeError(Errors.E147.format(method="get_similarity_loss", msg="gold entities do not match up"))
|
||||
|
||||
gradients = self.distance.get_grad(scores, entity_encodings)
|
||||
loss = self.distance.get_loss(scores, entity_encodings)
|
||||
gradients = self.distance.get_grad(sentence_encodings, entity_encodings)
|
||||
loss = self.distance.get_loss(sentence_encodings, entity_encodings)
|
||||
loss = loss / len(entity_encodings)
|
||||
return loss, gradients
|
||||
|
||||
def get_loss(self, examples, scores):
|
||||
cats = []
|
||||
for eg in examples:
|
||||
kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
|
||||
for ent in eg.predicted.ents:
|
||||
kb_id = kb_ids[ent.start]
|
||||
if kb_id:
|
||||
cats.append([1.0])
|
||||
|
||||
cats = self.model.ops.asarray(cats, dtype="float32")
|
||||
if len(scores) != len(cats):
|
||||
raise RuntimeError(Errors.E147.format(method="get_loss", msg="gold entities do not match up"))
|
||||
|
||||
d_scores = (scores - cats)
|
||||
loss = (d_scores ** 2).sum()
|
||||
loss = loss / len(cats)
|
||||
return loss, d_scores
|
||||
|
||||
def __call__(self, doc):
|
||||
kb_ids, tensors = self.predict([doc])
|
||||
self.set_annotations([doc], kb_ids, tensors=tensors)
|
||||
kb_ids = self.predict([doc])
|
||||
self.set_annotations([doc], kb_ids)
|
||||
return doc
|
||||
|
||||
def pipe(self, stream, batch_size=128):
|
||||
for docs in util.minibatch(stream, size=batch_size):
|
||||
kb_ids, tensors = self.predict(docs)
|
||||
self.set_annotations(docs, kb_ids, tensors=tensors)
|
||||
kb_ids = self.predict(docs)
|
||||
self.set_annotations(docs, kb_ids)
|
||||
yield from docs
|
||||
|
||||
def predict(self, docs):
|
||||
|
@ -1191,10 +1168,9 @@ class EntityLinker(Pipe):
|
|||
self.require_kb()
|
||||
entity_count = 0
|
||||
final_kb_ids = []
|
||||
final_tensors = []
|
||||
|
||||
if not docs:
|
||||
return final_kb_ids, final_tensors
|
||||
return final_kb_ids
|
||||
|
||||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
|
@ -1228,21 +1204,18 @@ class EntityLinker(Pipe):
|
|||
if to_discard and ent.label_ in to_discard:
|
||||
# ignoring this entity - setting to NIL
|
||||
final_kb_ids.append(self.NIL)
|
||||
final_tensors.append(sentence_encoding)
|
||||
|
||||
else:
|
||||
candidates = self.kb.get_candidates(ent.text)
|
||||
if not candidates:
|
||||
# no prediction possible for this entity - setting to NIL
|
||||
final_kb_ids.append(self.NIL)
|
||||
final_tensors.append(sentence_encoding)
|
||||
|
||||
elif len(candidates) == 1:
|
||||
# shortcut for efficiency reasons: take the 1 candidate
|
||||
|
||||
# TODO: thresholding
|
||||
final_kb_ids.append(candidates[0].entity_)
|
||||
final_tensors.append(sentence_encoding)
|
||||
|
||||
else:
|
||||
random.shuffle(candidates)
|
||||
|
@ -1271,14 +1244,13 @@ class EntityLinker(Pipe):
|
|||
best_index = scores.argmax().item()
|
||||
best_candidate = candidates[best_index]
|
||||
final_kb_ids.append(best_candidate.entity_)
|
||||
final_tensors.append(sentence_encoding)
|
||||
|
||||
if not (len(final_tensors) == len(final_kb_ids) == entity_count):
|
||||
if not (len(final_kb_ids) == entity_count):
|
||||
raise RuntimeError(Errors.E147.format(method="predict", msg="result variables not of equal length"))
|
||||
|
||||
return final_kb_ids, final_tensors
|
||||
return final_kb_ids
|
||||
|
||||
def set_annotations(self, docs, kb_ids, tensors=None):
|
||||
def set_annotations(self, docs, kb_ids):
|
||||
count_ents = len([ent for doc in docs for ent in doc.ents])
|
||||
if count_ents != len(kb_ids):
|
||||
raise ValueError(Errors.E148.format(ents=count_ents, ids=len(kb_ids)))
|
||||
|
@ -1394,11 +1366,7 @@ class Sentencizer(Pipe):
|
|||
def pipe(self, stream, batch_size=128):
|
||||
for docs in util.minibatch(stream, size=batch_size):
|
||||
predictions = self.predict(docs)
|
||||
if isinstance(predictions, tuple) and len(tuple) == 2:
|
||||
scores, tensors = predictions
|
||||
self.set_annotations(docs, scores, tensors=tensors)
|
||||
else:
|
||||
self.set_annotations(docs, predictions)
|
||||
self.set_annotations(docs, predictions)
|
||||
yield from docs
|
||||
|
||||
def predict(self, docs):
|
||||
|
@ -1429,7 +1397,7 @@ class Sentencizer(Pipe):
|
|||
guesses.append(doc_guesses)
|
||||
return guesses
|
||||
|
||||
def set_annotations(self, docs, batch_tag_ids, tensors=None):
|
||||
def set_annotations(self, docs, batch_tag_ids):
|
||||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
cdef Doc doc
|
||||
|
|
|
@ -57,7 +57,7 @@ class SimpleNER(Pipe):
|
|||
scores = self.model.predict(docs)
|
||||
return scores
|
||||
|
||||
def set_annotations(self, docs: List[Doc], scores: List[Floats2d], tensors=None):
|
||||
def set_annotations(self, docs: List[Doc], scores: List[Floats2d]):
|
||||
"""Set entities on a batch of documents from a batch of scores."""
|
||||
tag_names = self.get_tag_names()
|
||||
for i, doc in enumerate(docs):
|
||||
|
@ -67,9 +67,12 @@ class SimpleNER(Pipe):
|
|||
tags = iob_to_biluo(tags)
|
||||
doc.ents = spans_from_biluo_tags(doc, tags)
|
||||
|
||||
def update(self, examples, set_annotations=False, drop=0.0, sgd=None, losses=None):
|
||||
def update(self, examples, *, set_annotations=False, drop=0.0, sgd=None, losses=None):
|
||||
if losses is None:
|
||||
losses = {}
|
||||
losses.setdefault("ner", 0.0)
|
||||
if not any(_has_ner(eg) for eg in examples):
|
||||
return 0
|
||||
return losses
|
||||
docs = [eg.predicted for eg in examples]
|
||||
set_dropout_rate(self.model, drop)
|
||||
scores, bp_scores = self.model.begin_update(docs)
|
||||
|
@ -79,10 +82,8 @@ class SimpleNER(Pipe):
|
|||
self.set_annotations(docs, scores)
|
||||
if sgd is not None:
|
||||
self.model.finish_update(sgd)
|
||||
if losses is not None:
|
||||
losses.setdefault("ner", 0.0)
|
||||
losses["ner"] += loss
|
||||
return loss
|
||||
losses["ner"] += loss
|
||||
return losses
|
||||
|
||||
def get_loss(self, examples, scores):
|
||||
loss = 0
|
||||
|
|
|
@ -83,12 +83,14 @@ class Tok2Vec(Pipe):
|
|||
assert tokvecs.shape[0] == len(doc)
|
||||
doc.tensor = tokvecs
|
||||
|
||||
def update(self, examples, drop=0.0, sgd=None, losses=None, set_annotations=False):
|
||||
def update(self, examples, *, drop=0.0, sgd=None, losses=None, set_annotations=False):
|
||||
"""Update the model.
|
||||
examples (iterable): A batch of examples
|
||||
examples (Iterable[Example]): A batch of examples
|
||||
drop (float): The droput rate.
|
||||
sgd (callable): An optimizer.
|
||||
RETURNS (dict): Results from the update.
|
||||
sgd (Optimizer): An optimizer.
|
||||
losses (Dict[str, float]): Dictionary to update with the loss, keyed by component.
|
||||
set_annotations (bool): whether or not to update the examples with the predictions
|
||||
RETURNS (Dict[str, float]): The updated losses dictionary
|
||||
"""
|
||||
if losses is None:
|
||||
losses = {}
|
||||
|
@ -124,6 +126,7 @@ class Tok2Vec(Pipe):
|
|||
self.listeners[-1].receive(batch_id, tokvecs, backprop)
|
||||
if set_annotations:
|
||||
self.set_annotations(docs, tokvecs)
|
||||
return losses
|
||||
|
||||
def get_loss(self, docs, golds, scores):
|
||||
pass
|
||||
|
|
|
@ -153,7 +153,7 @@ cdef class Parser:
|
|||
doc (Doc): The document to be processed.
|
||||
"""
|
||||
states = self.predict([doc])
|
||||
self.set_annotations([doc], states, tensors=None)
|
||||
self.set_annotations([doc], states)
|
||||
return doc
|
||||
|
||||
def pipe(self, docs, int batch_size=256):
|
||||
|
@ -170,7 +170,7 @@ cdef class Parser:
|
|||
for subbatch in util.minibatch(by_length, size=max(batch_size//4, 2)):
|
||||
subbatch = list(subbatch)
|
||||
parse_states = self.predict(subbatch)
|
||||
self.set_annotations(subbatch, parse_states, tensors=None)
|
||||
self.set_annotations(subbatch, parse_states)
|
||||
yield from batch_in_order
|
||||
|
||||
def predict(self, docs):
|
||||
|
@ -222,7 +222,7 @@ cdef class Parser:
|
|||
unfinished.clear()
|
||||
free_activations(&activations)
|
||||
|
||||
def set_annotations(self, docs, states, tensors=None):
|
||||
def set_annotations(self, docs, states):
|
||||
cdef StateClass state
|
||||
cdef Doc doc
|
||||
for i, (state, doc) in enumerate(zip(states, docs)):
|
||||
|
@ -263,7 +263,7 @@ cdef class Parser:
|
|||
states[i].push_hist(guess)
|
||||
free(is_valid)
|
||||
|
||||
def update(self, examples, drop=0., set_annotations=False, sgd=None, losses=None):
|
||||
def update(self, examples, *, drop=0., set_annotations=False, sgd=None, losses=None):
|
||||
cdef StateClass state
|
||||
if losses is None:
|
||||
losses = {}
|
||||
|
|
|
@ -302,7 +302,7 @@ def test_multiple_predictions():
|
|||
def predict(self, docs):
|
||||
return ([1, 2, 3], [4, 5, 6])
|
||||
|
||||
def set_annotations(self, docs, scores, tensors=None):
|
||||
def set_annotations(self, docs, scores):
|
||||
return docs
|
||||
|
||||
nlp = Language()
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import numpy
|
||||
from spacy.errors import AlignmentError
|
||||
from spacy.gold import biluo_tags_from_offsets, offsets_from_biluo_tags
|
||||
from spacy.gold import spans_from_biluo_tags, iob_to_biluo
|
||||
|
@ -154,6 +155,27 @@ def test_gold_biluo_misalign(en_vocab):
|
|||
assert tags == ["O", "O", "O", "-", "-", "-"]
|
||||
|
||||
|
||||
def test_example_constructor(en_vocab):
|
||||
words = ["I", "like", "stuff"]
|
||||
tags = ["NOUN", "VERB", "NOUN"]
|
||||
tag_ids = [en_vocab.strings.add(tag) for tag in tags]
|
||||
predicted = Doc(en_vocab, words=words)
|
||||
reference = Doc(en_vocab, words=words)
|
||||
reference = reference.from_array("TAG", numpy.array(tag_ids, dtype="uint64"))
|
||||
example = Example(predicted, reference)
|
||||
tags = example.get_aligned("TAG", as_string=True)
|
||||
assert tags == ["NOUN", "VERB", "NOUN"]
|
||||
|
||||
|
||||
def test_example_from_dict_tags(en_vocab):
|
||||
words = ["I", "like", "stuff"]
|
||||
tags = ["NOUN", "VERB", "NOUN"]
|
||||
predicted = Doc(en_vocab, words=words)
|
||||
example = Example.from_dict(predicted, {"TAGS": tags})
|
||||
tags = example.get_aligned("TAG", as_string=True)
|
||||
assert tags == ["NOUN", "VERB", "NOUN"]
|
||||
|
||||
|
||||
def test_example_from_dict_no_ner(en_vocab):
|
||||
words = ["a", "b", "c", "d"]
|
||||
spaces = [True, True, False, True]
|
||||
|
|
|
@ -803,7 +803,7 @@ cdef class Doc:
|
|||
attrs = [(IDS[id_.upper()] if hasattr(id_, "upper") else id_)
|
||||
for id_ in attrs]
|
||||
if array.dtype != numpy.uint64:
|
||||
warnings.warn(Warnings.W101.format(type=array.dtype))
|
||||
warnings.warn(Warnings.W028.format(type=array.dtype))
|
||||
|
||||
if SENT_START in attrs and HEAD in attrs:
|
||||
raise ValueError(Errors.E032)
|
||||
|
|
Loading…
Reference in New Issue
Block a user