* Make GoldCorpus return dict, not Example

* Make Example require a Doc object (previously optional)

* Clarify methods in GoldCorpus

* WIP refactor Example

* Refactor Example.split_sents

* Fix test

* Fix augment

* Update test

* Update test

* Fix import

* Update test_scorer

* Update Example
Matthew Honnibal 2020-06-08 22:28:50 +02:00
parent 084271c9e9
commit d9289712ba
11 changed files with 176 additions and 111 deletions

View File

@@ -2,6 +2,7 @@ import re
 from ...gold import Example
 from ...gold import iob_to_biluo, spans_from_biluo_tags, biluo_tags_from_offsets
+from ...gold import TokenAnnotation
 from ...language import Language
 from ...tokens import Doc, Token
 from .conll_ner2json import n_sents_info
@@ -284,13 +285,8 @@ def example_from_conllu_sentence(
         spaces.append(t._.merged_spaceafter)
     ent_offsets = [(e.start_char, e.end_char, e.label_) for e in doc.ents]
     ents = biluo_tags_from_offsets(doc, ent_offsets)
-    raw = ""
-    for word, space in zip(words, spaces):
-        raw += word
-        if space:
-            raw += " "
-    example = Example(doc=raw)
-    example.set_token_annotation(
+    example = Example(doc=Doc(vocab, words=words, spaces=spaces))
+    example.token_annotation = TokenAnnotation(
         ids=ids,
         words=words,
         tags=tags,
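
Note: the converter no longer rebuilds a raw text string for Example to
tokenize; it constructs the Doc directly from the words and spaces. A minimal
sketch of the new pattern (hypothetical words and tags, assuming the Example
and TokenAnnotation API as refactored in this commit):

    from spacy.gold import Example, TokenAnnotation
    from spacy.tokens import Doc
    from spacy.vocab import Vocab

    vocab = Vocab()
    words = ["Hello", "world", "!"]
    spaces = [True, False, False]
    # Example now requires a real Doc; passing a raw string raises TypeError
    example = Example(doc=Doc(vocab, words=words, spaces=spaces))
    # token-level gold annotations are attached as a TokenAnnotation object
    example.token_annotation = TokenAnnotation(words=words, tags=["UH", "NN", "."])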

View File

@@ -1,3 +1,6 @@
+from .iob_utils import biluo_tags_from_offsets
+
+
 class TokenAnnotation:
     def __init__(
         self,

View File

@@ -1,6 +1,7 @@
 import random
 import itertools
 from .example import Example
+from .annotation import TokenAnnotation


 def make_orth_variants(nlp, example, orth_variant_level=0.0):
@@ -17,14 +18,14 @@ def make_orth_variants(nlp, example, orth_variant_level=0.0):
     ndsv = nlp.Defaults.single_orth_variants
     ndpv = nlp.Defaults.paired_orth_variants
     # modify words in paragraph_tuples
-    variant_example = Example(doc=raw)
+    variant_example = Example(doc=nlp.make_doc(raw))
     token_annotation = example.token_annotation
     words = token_annotation.words
     tags = token_annotation.tags
     if not words or not tags:
         # add the unmodified annotation
         token_dict = token_annotation.to_dict()
-        variant_example.set_token_annotation(**token_dict)
+        variant_example.token_annotation = TokenAnnotation(**token_dict)
     else:
         if lower:
             words = [w.lower() for w in words]
@@ -60,7 +61,7 @@ def make_orth_variants(nlp, example, orth_variant_level=0.0):
         token_dict = token_annotation.to_dict()
         token_dict["words"] = words
         token_dict["tags"] = tags
-        variant_example.set_token_annotation(**token_dict)
+        variant_example.token_annotation = TokenAnnotation(**token_dict)
     # modify raw to match variant_paragraph_tuples
     if raw is not None:
         variants = []
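
Note: the augmenter shows the same migration in miniature. A hedged sketch of
the two changed calls (hypothetical token_dict, assuming the refactored API):

    from spacy.gold import Example, TokenAnnotation
    from spacy.lang.en import English

    nlp = English()
    raw = "Hello world"
    token_dict = {"words": ["Hello", "world"], "tags": ["UH", "NN"]}
    # Example(doc=raw) is no longer valid: tokenize the raw text up front
    variant_example = Example(doc=nlp.make_doc(raw))
    # set_token_annotation() is removed; assign a TokenAnnotation instead
    variant_example.token_annotation = TokenAnnotation(**token_dict)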

View File

@@ -28,8 +28,8 @@ class GoldCorpus(object):
         """
         self.limit = limit
         if isinstance(train, str) or isinstance(train, Path):
-            train = self.read_examples(self.walk_corpus(train))
-            dev = self.read_examples(self.walk_corpus(dev))
+            train = self.read_annotations(self.walk_corpus(train))
+            dev = self.read_annotations(self.walk_corpus(dev))
         # Write temp directory with one doc per file, so we can shuffle and stream
         self.tmp_dir = Path(tempfile.mkdtemp())
         self.write_msgpack(self.tmp_dir / "train", train, limit=self.limit)
@@ -71,7 +71,7 @@ class GoldCorpus(object):
         return locs

     @staticmethod
-    def read_examples(locs, limit=0):
+    def read_annotations(locs, limit=0):
        """ Yield training examples """
        i = 0
        for loc in locs:
@@ -101,11 +101,11 @@ class GoldCorpus(object):
                    or isinstance(doc, str)
                ):
                    raise ValueError(Errors.E987.format(type=type(doc)))
-                examples.append(Example.from_dict(ex_dict, doc=doc))
+                examples.append(ex_dict)
            elif file_name.endswith("msg"):
                text, ex_dict = srsly.read_msgpack(loc)
-                examples = [Example.from_dict(ex_dict, doc=text)]
+                examples = [ex_dict]
            else:
                supported = ("json", "jsonl", "msg")
                raise ValueError(Errors.E124.format(path=loc, formats=supported))
@@ -123,21 +123,21 @@ class GoldCorpus(object):
                raise ValueError(Errors.E996.format(file=file_name, msg=msg))

     @property
-    def dev_examples(self):
+    def dev_annotations(self):
         locs = (self.tmp_dir / "dev").iterdir()
-        yield from self.read_examples(locs, limit=self.limit)
+        yield from self.read_annotations(locs, limit=self.limit)

     @property
-    def train_examples(self):
+    def train_annotations(self):
         locs = (self.tmp_dir / "train").iterdir()
-        yield from self.read_examples(locs, limit=self.limit)
+        yield from self.read_annotations(locs, limit=self.limit)

     def count_train(self):
         """Returns count of words in train examples"""
         n = 0
         i = 0
-        for example in self.train_examples:
-            n += len(example.token_annotation.words)
+        for eg_dict in self.train_annotations:
+            n += len(eg_dict["token_annotation"]["words"])
             if self.limit and i >= self.limit:
                 break
             i += 1
@@ -154,10 +154,10 @@
     ):
         locs = list((self.tmp_dir / "train").iterdir())
         random.shuffle(locs)
-        train_examples = self.read_examples(locs, limit=self.limit)
-        gold_examples = self.iter_gold_docs(
+        train_annotations = self.read_annotations(locs, limit=self.limit)
+        examples = self.iter_examples(
             nlp,
-            train_examples,
+            train_annotations,
             gold_preproc,
             max_length=max_length,
             noise_level=noise_level,
@@ -165,33 +165,33 @@
             make_projective=True,
             ignore_misaligned=ignore_misaligned,
         )
-        yield from gold_examples
+        yield from examples

     def train_dataset_without_preprocessing(
         self, nlp, gold_preproc=False, ignore_misaligned=False
     ):
-        examples = self.iter_gold_docs(
+        examples = self.iter_examples(
             nlp,
-            self.train_examples,
+            self.train_annotations,
             gold_preproc=gold_preproc,
             ignore_misaligned=ignore_misaligned,
         )
         yield from examples

     def dev_dataset(self, nlp, gold_preproc=False, ignore_misaligned=False):
-        examples = self.iter_gold_docs(
+        examples = self.iter_examples(
             nlp,
-            self.dev_examples,
+            self.dev_annotations,
             gold_preproc=gold_preproc,
             ignore_misaligned=ignore_misaligned,
         )
         yield from examples

     @classmethod
-    def iter_gold_docs(
+    def iter_examples(
         cls,
         nlp,
-        examples,
+        annotations,
         gold_preproc,
         max_length=None,
         noise_level=0.0,
@@ -200,7 +200,8 @@
         ignore_misaligned=False,
     ):
         """ Setting gold_preproc will result in creating a doc per sentence """
-        for example in examples:
+        for eg_dict in annotations:
+            example = Example.from_dict(eg_dict, doc=nlp.make_doc(eg_dict["text"]))
             example_docs = []
             if gold_preproc:
                 split_examples = example.split_sents()
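
Note: GoldCorpus now shuttles plain dicts through its temp files and only
materializes Example objects inside iter_examples(). A sketch of the new flow
(hypothetical corpus paths, assuming the JSON training format):

    from spacy.gold import GoldCorpus
    from spacy.lang.en import English

    nlp = English()
    corpus = GoldCorpus("train.json", "dev.json")
    # the *_annotations properties yield dicts, not Example objects
    for eg_dict in corpus.train_annotations:
        print(eg_dict["text"], eg_dict["token_annotation"]["words"])
    # Examples are built lazily, one per annotation dict, roughly as:
    #     Example.from_dict(eg_dict, doc=nlp.make_doc(eg_dict["text"]))
    examples = list(corpus.dev_dataset(nlp))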

View File

@@ -1,18 +1,69 @@
+import numpy
 from .annotation import TokenAnnotation, DocAnnotation
+from .iob_utils import spans_from_biluo_tags, biluo_tags_from_offsets
 from .align import Alignment
 from ..errors import Errors, AlignmentError
 from ..tokens import Doc


+def annotations2doc(doc, doc_annot, tok_annot):
+    # TODO: Improve and test this
+    words = tok_annot.words or [tok.text for tok in doc]
+    fields = {
+        "tags": "TAG",
+        "pos": "POS",
+        "lemmas": "LEMMA",
+        "deps": "DEP",
+    }
+    attrs = []
+    values = []
+    for field, attr in fields.items():
+        value = getattr(tok_annot, field)
+        # Unset fields will be empty lists.
+        if value:
+            attrs.append(attr)
+            values.append([doc.vocab.strings.add(v) for v in value])
+    if tok_annot.heads:
+        attrs.append("HEAD")
+        values.append([h - i for i, h in enumerate(tok_annot.heads)])
+    output = Doc(doc.vocab, words=words)
+    if values:
+        array = numpy.array(values, dtype="uint64")
+        output = output.from_array(attrs, array.T)
+    if tok_annot.entities:
+        output.ents = spans_from_biluo_tags(output, tok_annot.entities)
+    doc.cats = dict(doc_annot.cats)
+    # TODO: Calculate token.ent_kb_id from links.
+    # We need to fix this and the doc.ents thing, both should be doc
+    # annotations.
+    return doc
+
+
 class Example:
-    def __init__(self, doc=None, doc_annotation=None, token_annotation=None):
+    def __init__(self, doc, doc_annotation=None, token_annotation=None):
         """ Doc can either be text, or an actual Doc """
+        if not isinstance(doc, Doc):
+            raise TypeError("Must pass Doc instance")
+        self.predicted = doc
         self.doc = doc
         self.doc_annotation = doc_annotation if doc_annotation else DocAnnotation()
         self.token_annotation = (
             token_annotation if token_annotation else TokenAnnotation()
         )
         self._alignment = None
+        self.reference = annotations2doc(
+            self.doc,
+            self.doc_annotation,
+            self.token_annotation
+        )
+
+    @property
+    def x(self):
+        return self.predicted
+
+    @property
+    def y(self):
+        return self.reference

     def _deprecated_get_gold(self, make_projective=False):
         from ..syntax.gold_parse import get_parses_from_example
@@ -24,6 +75,8 @@ class Example:
     def from_dict(cls, example_dict, doc=None):
         if example_dict is None:
             raise ValueError("Example.from_dict expected dict, received None")
+        if doc is None:
+            raise ValueError("Must pass doc")
         # TODO: This is ridiculous...
         token_dict = example_dict.get("token_annotation", {})
         doc_dict = example_dict.get("doc_annotation", {})
@@ -34,6 +87,10 @@ class Example:
                 doc_dict[key] = value
             else:
                 token_dict[key] = value
+        if token_dict.get("entities"):
+            entities = token_dict["entities"]
+            if isinstance(entities[0], (list, tuple)):
+                token_dict["entities"] = biluo_tags_from_offsets(doc, entities)
         token_annotation = TokenAnnotation.from_dict(token_dict)
         doc_annotation = DocAnnotation.from_dict(doc_dict)
         return cls(
@@ -45,8 +102,8 @@ class Example:
         if self._alignment is None:
             if self.doc is None:
                 return None
-            spacy_words = [token.orth_ for token in self.doc]
-            gold_words = self.token_annotation.words
+            spacy_words = [token.orth_ for token in self.predicted]
+            gold_words = [token.orth_ for token in self.reference]
             if gold_words == []:
                 gold_words = spacy_words
             self._alignment = Alignment(spacy_words, gold_words)
@@ -92,34 +149,6 @@ class Example:
             output.append(gold_values[gold_i])
         return output

-    def set_token_annotation(
-        self,
-        ids=None,
-        words=None,
-        tags=None,
-        pos=None,
-        morphs=None,
-        lemmas=None,
-        heads=None,
-        deps=None,
-        entities=None,
-        sent_starts=None,
-        brackets=None,
-    ):
-        self.token_annotation = TokenAnnotation(
-            ids=ids,
-            words=words,
-            tags=tags,
-            pos=pos,
-            morphs=morphs,
-            lemmas=lemmas,
-            heads=heads,
-            deps=deps,
-            entities=entities,
-            sent_starts=sent_starts,
-            brackets=brackets,
-        )
-
     def set_doc_annotation(self, cats=None, links=None):
         if cats:
             self.doc_annotation.cats = cats
@@ -131,7 +160,6 @@ class Example:
         sent_starts and return a list of the new Examples"""
         if not self.token_annotation.words:
             return [self]
-        s_example = Example(doc=None, doc_annotation=self.doc_annotation)
         s_ids, s_words, s_tags, s_pos, s_morphs = [], [], [], [], []
         s_lemmas, s_heads, s_deps, s_ents, s_sent_starts = [], [], [], [], []
         s_brackets = []
@@ -140,21 +168,25 @@ class Example:
         split_examples = []
         for i in range(len(t.words)):
             if i > 0 and t.sent_starts[i] == 1:
-                s_example.set_token_annotation(
-                    ids=s_ids,
-                    words=s_words,
-                    tags=s_tags,
-                    pos=s_pos,
-                    morphs=s_morphs,
-                    lemmas=s_lemmas,
-                    heads=s_heads,
-                    deps=s_deps,
-                    entities=s_ents,
-                    sent_starts=s_sent_starts,
-                    brackets=s_brackets,
-                )
-                split_examples.append(s_example)
-                s_example = Example(doc=None, doc_annotation=self.doc_annotation)
+                split_examples.append(
+                    Example(
+                        doc=Doc(self.doc.vocab, words=s_words),
+                        token_annotation=TokenAnnotation(
+                            ids=s_ids,
+                            words=s_words,
+                            tags=s_tags,
+                            pos=s_pos,
+                            morphs=s_morphs,
+                            lemmas=s_lemmas,
+                            heads=s_heads,
+                            deps=s_deps,
+                            entities=s_ents,
+                            sent_starts=s_sent_starts,
+                            brackets=s_brackets,
+                        ),
+                        doc_annotation=self.doc_annotation
+                    )
+                )
                 s_ids, s_words, s_tags, s_pos, s_heads = [], [], [], [], []
                 s_deps, s_ents, s_morphs, s_lemmas = [], [], [], []
                 s_sent_starts, s_brackets = [], []
@@ -172,20 +204,25 @@ class Example:
             for b_end, b_label in t.brackets_by_start.get(i, []):
                 s_brackets.append((i - sent_start_i, b_end - sent_start_i, b_label))
             i += 1
-        s_example.set_token_annotation(
-            ids=s_ids,
-            words=s_words,
-            tags=s_tags,
-            pos=s_pos,
-            morphs=s_morphs,
-            lemmas=s_lemmas,
-            heads=s_heads,
-            deps=s_deps,
-            entities=s_ents,
-            sent_starts=s_sent_starts,
-            brackets=s_brackets,
-        )
-        split_examples.append(s_example)
+        split_examples.append(
+            Example(
+                doc=Doc(self.doc.vocab, words=s_words),
+                token_annotation=TokenAnnotation(
+                    ids=s_ids,
+                    words=s_words,
+                    tags=s_tags,
+                    pos=s_pos,
+                    morphs=s_morphs,
+                    lemmas=s_lemmas,
+                    heads=s_heads,
+                    deps=s_deps,
+                    entities=s_ents,
+                    sent_starts=s_sent_starts,
+                    brackets=s_brackets,
+                ),
+                doc_annotation=self.doc_annotation
+            )
+        )
         return split_examples

     @classmethod
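
Note: taken together, the Example changes give each instance an explicit
predicted/reference pair. A hedged sketch of the new surface (flat dict keys
as in the tests; the reference Doc construction is still WIP per the TODOs in
annotations2doc above):

    from spacy.gold import Example
    from spacy.lang.en import English

    nlp = English()
    doc = nlp.make_doc("London is big")
    example = Example.from_dict(
        {"words": ["London", "is", "big"], "tags": ["NNP", "VBZ", "JJ"]},
        doc=doc,  # now mandatory; doc=None raises ValueError
    )
    assert example.x is example.predicted  # the Doc to predict over
    assert example.y is example.reference  # intended to hold the gold-annotated Doc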

View File

@@ -76,12 +76,12 @@ def read_json_file(loc, docs_filter=None, limit=None):
            yield json_data


-def json_to_examples(doc):
+def json_to_annotations(doc):
    """Convert an item in the JSON-formatted training data to the format
    used by GoldParse.

    doc (dict): One entry in the training data.
-    YIELDS (Example): The reformatted data - one training example per paragraph
+    YIELDS (tuple): The reformatted data - one training example per paragraph
    """
    for paragraph in doc["paragraphs"]:
        example = {"text": paragraph.get("raw", None)}

View File

@@ -108,7 +108,7 @@ def preprocess_training_data(gold_data, label_freq_cutoff=30):
        proj_token_dict = example.token_annotation.to_dict()
        proj_token_dict["heads"] = proj_heads
        proj_token_dict["deps"] = deco_deps
-        new_example.set_token_annotation(**proj_token_dict)
+        new_example.token_annotation = TokenAnnotation(**proj_token_dict)
        preprocessed.append(new_example)
    if label_freq_cutoff > 0:
        return _filter_labels(preprocessed, label_freq_cutoff, freqs)
@@ -216,6 +216,6 @@ def _filter_labels(examples, cutoff, freqs):
                filtered_labels.append(label)
        filtered_token_dict = example.token_annotation.to_dict()
        filtered_token_dict["deps"] = filtered_labels
-        new_example.set_token_annotation(**filtered_token_dict)
+        new_example.token_annotation = TokenAnnotation(**filtered_token_dict)
        filtered.append(new_example)
    return filtered

View File

@@ -3,7 +3,7 @@ import gc
 import numpy
 import copy

-from spacy.gold import Example
+from spacy.gold import Example, TokenAnnotation
 from spacy.lang.en import English
 from spacy.lang.en.stop_words import STOP_WORDS
 from spacy.lang.lex_attrs import is_stop
@@ -271,9 +271,16 @@ def test_issue1963(en_tokenizer):
 @pytest.mark.parametrize("label", ["U-JOB-NAME"])
 def test_issue1967(label):
     ner = EntityRecognizer(Vocab(), default_ner())
-    example = Example(doc=None)
-    example.set_token_annotation(
-        ids=[0], words=["word"], tags=["tag"], heads=[0], deps=["dep"], entities=[label]
+    example = Example(
+        doc=Doc(ner.vocab, words=["word"]),
+        token_annotation=TokenAnnotation(
+            ids=[0],
+            words=["word"],
+            tags=["tag"],
+            heads=[0],
+            deps=["dep"],
+            entities=[label]
+        )
     )
     ner.moves.get_actions(gold_parses=[example])

View File

@@ -95,6 +95,12 @@ def merged_dict():
     }


+@pytest.fixture
+def vocab():
+    nlp = English()
+    return nlp.vocab
+
+
 def test_gold_biluo_U(en_vocab):
     words = ["I", "flew", "to", "London", "."]
     spaces = [True, True, True, False, True]
@@ -475,8 +481,10 @@ def _train(train_data):

 def test_split_sents(merged_dict):
     nlp = English()
-    example = Example()
-    example.set_token_annotation(**merged_dict)
+    example = Example.from_dict(
+        merged_dict,
+        doc=Doc(nlp.vocab, words=merged_dict["words"])
+    )
     assert len(get_parses_from_example(
         example,
         merge=False,
@@ -506,13 +514,15 @@ def test_split_sents(merged_dict):
     assert token_annotation_2.sent_starts == [1, 0, 0, 0]


-def test_tuples_to_example(merged_dict):
-    ex = Example()
-    ex.set_token_annotation(**merged_dict)
+def test_tuples_to_example(vocab, merged_dict):
     cats = {"TRAVEL": 1.0, "BAKING": 0.0}
-    ex.set_doc_annotation(cats=cats)
+    merged_dict = dict(merged_dict)
+    merged_dict["cats"] = cats
+    ex = Example.from_dict(
+        merged_dict,
+        doc=Doc(vocab, words=merged_dict["words"])
+    )
     ex_dict = ex.to_dict()
     assert ex_dict["token_annotation"]["ids"] == merged_dict["ids"]
     assert ex_dict["token_annotation"]["words"] == merged_dict["words"]
     assert ex_dict["token_annotation"]["tags"] == merged_dict["tags"]

View File

@@ -1,12 +1,14 @@
 from numpy.testing import assert_almost_equal, assert_array_almost_equal
 import pytest
 from pytest import approx
-from spacy.gold import Example, GoldParse
+from spacy.gold import Example, GoldParse, TokenAnnotation
+from spacy.gold.iob_utils import biluo_tags_from_offsets
 from spacy.scorer import Scorer, ROCAUCScore
 from spacy.scorer import _roc_auc_score, _roc_curve
 from .util import get_doc
 from spacy.lang.en import English

+
 test_las_apple = [
     [
         "Apple is looking at buying U.K. startup for $ 1 billion",
@@ -134,8 +136,11 @@ def test_ner_per_type(en_vocab):
             words=input_.split(" "),
             ents=[[0, 1, "CARDINAL"], [2, 3, "CARDINAL"]],
         )
-        ex = Example(doc=doc)
-        ex.set_token_annotation(entities=annot["entities"])
+        entities = biluo_tags_from_offsets(doc, annot["entities"])
+        ex = Example(
+            doc=doc,
+            token_annotation=TokenAnnotation(entities=entities)
+        )
         scorer.score(ex)

     results = scorer.scores
@@ -155,8 +160,11 @@ def test_ner_per_type(en_vocab):
             words=input_.split(" "),
             ents=[[0, 1, "ORG"], [5, 6, "GPE"], [6, 7, "ORG"]],
         )
-        ex = Example(doc=doc)
-        ex.set_token_annotation(entities=annot["entities"])
+        entities = biluo_tags_from_offsets(doc, annot["entities"])
+        ex = Example(
+            doc=doc,
+            token_annotation=TokenAnnotation(entities=entities)
+        )
         scorer.score(ex)

     results = scorer.scores
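
Note: code that constructs TokenAnnotation directly now converts character
offsets to BILUO tags first, as the updated tests do (Example.from_dict
performs this conversion automatically). A small sketch with hypothetical
text and offsets:

    from spacy.gold import Example, TokenAnnotation
    from spacy.gold.iob_utils import biluo_tags_from_offsets
    from spacy.lang.en import English

    nlp = English()
    doc = nlp.make_doc("100 - 200")
    offsets = [(0, 3, "CARDINAL"), (6, 9, "CARDINAL")]
    # -> ["U-CARDINAL", "O", "U-CARDINAL"]
    entities = biluo_tags_from_offsets(doc, offsets)
    ex = Example(doc=doc, token_annotation=TokenAnnotation(entities=entities))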

View File

@@ -799,6 +799,8 @@ cdef class Doc:
         cdef attr_id_t attr_id
         cdef TokenC* tokens = self.c
         cdef int length = len(array)
+        if length != len(self):
+            raise ValueError("Cannot set array values longer than the document.")
         # Get set up for fast loading
         cdef Pool mem = Pool()
         cdef int n_attrs = len(attrs)
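
Note: the new guard in Doc.from_array rejects any array whose first dimension
does not match the token count (despite the message mentioning only "longer").
A quick sketch, assuming a 3-token doc:

    import numpy
    from spacy.attrs import TAG
    from spacy.lang.en import English

    nlp = English()
    doc = nlp.make_doc("a b c")
    # one row per token: accepted
    doc.from_array([TAG], numpy.zeros((len(doc), 1), dtype="uint64"))
    # five rows for a three-token doc: now raises ValueError
    try:
        doc.from_array([TAG], numpy.zeros((5, 1), dtype="uint64"))
    except ValueError as err:
        print(err)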