Remove GoldParse

WIP on removing goldparse

Get ArcEager compiling after GoldParse excise

Update setup.py

Get spacy.syntax compiling after removing GoldParse

Rename NewExample -> Example and clean up

Clean html files

Start updating tests

Update Morphologizer
Matthew Honnibal 2020-06-14 17:45:46 +02:00
parent d53723aa4f
commit 98ca14f577
19 changed files with 124 additions and 1016 deletions

View File

@ -24,7 +24,7 @@ Options.docstrings = True
PACKAGES = find_packages()
MOD_NAMES = [
"spacy.gold.align",
"spacy.gold.new_example",
"spacy.gold.example",
"spacy.parts_of_speech",
"spacy.strings",
"spacy.lexeme",
@ -37,7 +37,6 @@ MOD_NAMES = [
"spacy.syntax.stateclass",
"spacy.syntax._state",
"spacy.tokenizer",
"spacy.syntax.gold_parse",
"spacy.syntax.nn_parser",
"spacy.syntax._parser_model",
"spacy.syntax._beam_utils",
@ -123,7 +122,7 @@ class build_ext_subclass(build_ext, build_ext_options):
def clean(path):
for path in path.glob("**/*"):
if path.is_file() and path.suffix in (".so", ".cpp"):
if path.is_file() and path.suffix in (".so", ".cpp", ".html"):
print(f"Deleting {path.name}")
path.unlink()
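The added ".html" suffix presumably targets the annotation reports Cython can emit next to each compiled module, matching the "Clean html files" note in the commit message. A minimal sketch of calling the helper above (the actual call site sits outside this hunk):

from pathlib import Path

clean(Path("spacy"))  # would now delete generated .so, .cpp and .html files under spacy/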

View File

@ -1,150 +0,0 @@
from .iob_utils import biluo_tags_from_offsets
class TokenAnnotation:
def __init__(
self,
ids=None,
words=None,
tags=None,
pos=None,
morphs=None,
lemmas=None,
heads=None,
deps=None,
entities=None,
sent_starts=None,
brackets=None,
):
self.ids = ids if ids else []
self.words = words if words else []
self.tags = tags if tags else []
self.pos = pos if pos else []
self.morphs = morphs if morphs else []
self.lemmas = lemmas if lemmas else []
self.heads = heads if heads else []
self.deps = deps if deps else []
self.entities = entities if entities else []
self.sent_starts = sent_starts if sent_starts else []
self.brackets_by_start = {}
if brackets:
for b_start, b_end, b_label in brackets:
self.brackets_by_start.setdefault(b_start, []).append((b_end, b_label))
def get_field(self, field):
if field == "id":
return self.ids
elif field == "word":
return self.words
elif field == "tag":
return self.tags
elif field == "pos":
return self.pos
elif field == "morph":
return self.morphs
elif field == "lemma":
return self.lemmas
elif field == "head":
return self.heads
elif field == "dep":
return self.deps
elif field == "ner":
return self.entities
elif field == "sent_start":
return self.sent_starts
else:
raise ValueError(f"Unknown field: {field}")
@property
def brackets(self):
brackets = []
for start, ends_labels in self.brackets_by_start.items():
for end, label in ends_labels:
brackets.append((start, end, label))
return brackets
@classmethod
def from_dict(cls, token_dict):
return cls(
ids=token_dict.get("ids", None),
words=token_dict.get("words", None),
tags=token_dict.get("tags", None),
pos=token_dict.get("pos", None),
morphs=token_dict.get("morphs", None),
lemmas=token_dict.get("lemmas", None),
heads=token_dict.get("heads", None),
deps=token_dict.get("deps", None),
entities=token_dict.get("entities", None),
sent_starts=token_dict.get("sent_starts", None),
brackets=token_dict.get("brackets", None),
)
def to_dict(self):
return {
"ids": self.ids,
"words": self.words,
"tags": self.tags,
"pos": self.pos,
"morphs": self.morphs,
"lemmas": self.lemmas,
"heads": self.heads,
"deps": self.deps,
"entities": self.entities,
"sent_starts": self.sent_starts,
"brackets": self.brackets,
}
def get_id(self, i):
return self.ids[i] if i < len(self.ids) else i
def get_word(self, i):
return self.words[i] if i < len(self.words) else ""
def get_tag(self, i):
return self.tags[i] if i < len(self.tags) else "-"
def get_pos(self, i):
return self.pos[i] if i < len(self.pos) else ""
def get_morph(self, i):
return self.morphs[i] if i < len(self.morphs) else ""
def get_lemma(self, i):
return self.lemmas[i] if i < len(self.lemmas) else ""
def get_head(self, i):
return self.heads[i] if i < len(self.heads) else i
def get_dep(self, i):
return self.deps[i] if i < len(self.deps) else ""
def get_entity(self, i):
return self.entities[i] if i < len(self.entities) else "-"
def get_sent_start(self, i):
return self.sent_starts[i] if i < len(self.sent_starts) else None
def __str__(self):
return str(self.to_dict())
def __repr__(self):
return self.__str__()
class DocAnnotation:
def __init__(self, cats=None, links=None):
self.cats = cats if cats else {}
self.links = links if links else {}
@classmethod
def from_dict(cls, doc_dict):
return cls(cats=doc_dict.get("cats", None), links=doc_dict.get("links", None))
def to_dict(self):
return {"cats": self.cats, "links": self.links}
def __str__(self):
return str(self.to_dict())
def __repr__(self):
return self.__str__()
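For reference, a short sketch of the round-trip behaviour the removed class provided, using only the methods defined above (the sample values are made up):

ta = TokenAnnotation.from_dict({"words": ["I", "like", "cats"], "tags": ["PRP", "VBP", "NNS"]})
assert ta.get_tag(1) == "VBP"
assert ta.get_tag(5) == "-"                      # out-of-range lookups fall back to a default
assert ta.to_dict()["words"] == ["I", "like", "cats"]
assert ta.brackets == []                         # no brackets were supplied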

View File

@ -9,7 +9,7 @@ from .. import util
from ..errors import Errors, AlignmentError
from .gold_io import read_json_file, json_to_annotations
from .augment import make_orth_variants, add_noise
from .new_example import NewExample as Example
from .example import Example
class GoldCorpus(object):

View File

@ -2,7 +2,7 @@ from ..tokens.doc cimport Doc
from .align cimport Alignment
cdef class NewExample:
cdef class Example:
cdef readonly Doc x
cdef readonly Doc y
cdef readonly Alignment _alignment

View File

@ -1,261 +0,0 @@
import numpy
from .annotation import TokenAnnotation, DocAnnotation
from .iob_utils import spans_from_biluo_tags, biluo_tags_from_offsets
from .align import Alignment
from ..errors import Errors, AlignmentError
from ..tokens import Doc
def annotations2doc(doc, doc_annot, tok_annot):
# TODO: Improve and test this
words = tok_annot.words or [tok.text for tok in doc]
fields = {
"tags": "TAG",
"pos": "POS",
"lemmas": "LEMMA",
"deps": "DEP",
}
attrs = []
values = []
for field, attr in fields.items():
value = getattr(tok_annot, field)
# Unset fields will be empty lists.
if value:
attrs.append(attr)
values.append([doc.vocab.strings.add(v) for v in value])
if tok_annot.heads:
attrs.append("HEAD")
values.append([h - i for i, h in enumerate(tok_annot.heads)])
output = Doc(doc.vocab, words=words)
if values:
array = numpy.array(values, dtype="uint64")
output = output.from_array(attrs, array.T)
if tok_annot.entities:
output.ents = spans_from_biluo_tags(output, tok_annot.entities)
doc.cats = dict(doc_annot.cats)
# TODO: Calculate token.ent_kb_id from links.
# We need to fix this and the doc.ents thing, both should be doc
# annotations.
return doc
class Example:
def __init__(self, doc, doc_annotation=None, token_annotation=None):
""" Doc can either be text, or an actual Doc """
if not isinstance(doc, Doc):
raise TypeError("Must pass Doc instance")
self.predicted = doc
self.doc = doc
self.doc_annotation = doc_annotation if doc_annotation else DocAnnotation()
self.token_annotation = (
token_annotation if token_annotation else TokenAnnotation()
)
self._alignment = None
self.reference = annotations2doc(
self.doc,
self.doc_annotation,
self.token_annotation
)
@property
def x(self):
return self.predicted
@property
def y(self):
return self.reference
def _deprecated_get_gold(self, make_projective=False):
from ..syntax.gold_parse import get_parses_from_example
_, gold = get_parses_from_example(self, make_projective=make_projective)[0]
return gold
@classmethod
def from_dict(cls, example_dict, doc=None):
if example_dict is None:
raise ValueError("Example.from_dict expected dict, received None")
if doc is None:
raise ValueError("Must pass doc")
# TODO: This is ridiculous...
token_dict = example_dict.get("token_annotation", {})
doc_dict = example_dict.get("doc_annotation", {})
for key, value in example_dict.items():
if key in ("token_annotation", "doc_annotation"):
pass
elif key in ("cats", "links"):
doc_dict[key] = value
else:
token_dict[key] = value
if token_dict.get("entities"):
entities = token_dict["entities"]
if isinstance(entities[0], (list, tuple)):
token_dict["entities"] = biluo_tags_from_offsets(doc, entities)
token_annotation = TokenAnnotation.from_dict(token_dict)
doc_annotation = DocAnnotation.from_dict(doc_dict)
return cls(
doc=doc, doc_annotation=doc_annotation, token_annotation=token_annotation
)
@property
def alignment(self):
if self._alignment is None:
if self.doc is None:
return None
spacy_words = [token.orth_ for token in self.predicted]
gold_words = [token.orth_ for token in self.reference]
if gold_words == []:
gold_words = spacy_words
self._alignment = Alignment(spacy_words, gold_words)
return self._alignment
def to_dict(self):
""" Note that this method does NOT export the doc, only the annotations ! """
token_dict = self.token_annotation.to_dict()
doc_dict = self.doc_annotation.to_dict()
return {"token_annotation": token_dict, "doc_annotation": doc_dict}
@property
def text(self):
if self.doc is None:
return None
if isinstance(self.doc, Doc):
return self.doc.text
return self.doc
def get_aligned(self, field):
"""Return an aligned array for a token annotation field."""
if self.doc is None:
return self.token_annotation.get_field(field)
doc = self.doc
if field == "word":
return [token.orth_ for token in doc]
gold_values = self.token_annotation.get_field(field)
alignment = self.alignment
i2j_multi = alignment.i2j_multi
gold_to_cand = alignment.gold_to_cand
cand_to_gold = alignment.cand_to_gold
output = []
for i, gold_i in enumerate(cand_to_gold):
if doc[i].text.isspace():
output.append(None)
elif gold_i is None:
if i in i2j_multi:
output.append(gold_values[i2j_multi[i]])
else:
output.append(None)
else:
output.append(gold_values[gold_i])
return output
def set_doc_annotation(self, cats=None, links=None):
if cats:
self.doc_annotation.cats = cats
if links:
self.doc_annotation.links = links
def split_sents(self):
""" Split the token annotations into multiple Examples based on
sent_starts and return a list of the new Examples"""
if not self.token_annotation.words:
return [self]
s_ids, s_words, s_tags, s_pos, s_morphs = [], [], [], [], []
s_lemmas, s_heads, s_deps, s_ents, s_sent_starts = [], [], [], [], []
s_brackets = []
sent_start_i = 0
t = self.token_annotation
split_examples = []
for i in range(len(t.words)):
if i > 0 and t.sent_starts[i] == 1:
split_examples.append(
Example(
doc=Doc(self.doc.vocab, words=s_words),
token_annotation=TokenAnnotation(
ids=s_ids,
words=s_words,
tags=s_tags,
pos=s_pos,
morphs=s_morphs,
lemmas=s_lemmas,
heads=s_heads,
deps=s_deps,
entities=s_ents,
sent_starts=s_sent_starts,
brackets=s_brackets,
),
doc_annotation=self.doc_annotation
)
)
s_ids, s_words, s_tags, s_pos, s_heads = [], [], [], [], []
s_deps, s_ents, s_morphs, s_lemmas = [], [], [], []
s_sent_starts, s_brackets = [], []
sent_start_i = i
s_ids.append(t.get_id(i))
s_words.append(t.get_word(i))
s_tags.append(t.get_tag(i))
s_pos.append(t.get_pos(i))
s_morphs.append(t.get_morph(i))
s_lemmas.append(t.get_lemma(i))
s_heads.append(t.get_head(i) - sent_start_i)
s_deps.append(t.get_dep(i))
s_ents.append(t.get_entity(i))
s_sent_starts.append(t.get_sent_start(i))
for b_end, b_label in t.brackets_by_start.get(i, []):
s_brackets.append((i - sent_start_i, b_end - sent_start_i, b_label))
i += 1
split_examples.append(
Example(
doc=Doc(self.doc.vocab, words=s_words),
token_annotation=TokenAnnotation(
ids=s_ids,
words=s_words,
tags=s_tags,
pos=s_pos,
morphs=s_morphs,
lemmas=s_lemmas,
heads=s_heads,
deps=s_deps,
entities=s_ents,
sent_starts=s_sent_starts,
brackets=s_brackets,
),
doc_annotation=self.doc_annotation
)
)
return split_examples
@classmethod
def to_example_objects(cls, examples, make_doc=None, keep_raw_text=False):
"""
Return a list of Example objects, from a variety of input formats.
make_doc needs to be provided when the examples contain text strings and keep_raw_text=False
"""
if isinstance(examples, Example):
return [examples]
if isinstance(examples, tuple):
examples = [examples]
converted_examples = []
for ex in examples:
if isinstance(ex, Example):
converted_examples.append(ex)
# convert string to Doc to Example
elif isinstance(ex, str):
if keep_raw_text:
converted_examples.append(Example(doc=ex))
else:
doc = make_doc(ex)
converted_examples.append(Example(doc=doc))
# convert tuples to Example
elif isinstance(ex, tuple) and len(ex) == 2:
doc, gold = ex
# convert string to Doc
if isinstance(doc, str) and not keep_raw_text:
doc = make_doc(doc)
converted_examples.append(Example.from_dict(gold, doc=doc))
# convert Doc to Example
elif isinstance(ex, Doc):
converted_examples.append(Example(doc=ex))
else:
converted_examples.append(ex)
return converted_examples
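A sketch of the mixed input formats the removed to_example_objects helper accepted; the English pipeline here is only for illustration:

from spacy.lang.en import English

nlp = English()
examples = Example.to_example_objects(
    [("I like cats", {"tags": ["PRP", "VBP", "NNS"]}), nlp.make_doc("Dogs too")],
    make_doc=nlp.make_doc,
)
# strings become Docs via make_doc, (doc, gold) tuples go through Example.from_dict,
# and Doc objects are wrapped directly.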

View File

@ -21,7 +21,7 @@ cpdef Doc annotations2doc(Doc predicted, tok_annot, doc_annot):
return output
cdef class NewExample:
cdef class Example:
def __init__(self, Doc predicted, Doc reference, *, Alignment alignment=None):
""" Doc can either be text, or an actual Doc """
msg = "Example.__init__ got None for '{arg}'. Requires Doc."
@ -55,7 +55,7 @@ cdef class NewExample:
raise TypeError(f"Argument 1 should be Doc. Got {type(predicted)}")
example_dict = _fix_legacy_dict_data(predicted, example_dict)
tok_dict, doc_dict = _parse_example_dict_data(example_dict)
return NewExample(
return Example(
predicted,
annotations2doc(predicted, tok_dict, doc_dict)
)
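A rough sketch of the two construction paths implied by the new signatures above; the vocab setup and dict keys are illustrative rather than taken from this diff:

from spacy.vocab import Vocab
from spacy.tokens import Doc
from spacy.gold.example import Example

vocab = Vocab()
predicted = Doc(vocab, words=["I", "like", "cats"])
reference = Doc(vocab, words=["I", "like", "cats"])
eg = Example(predicted, reference)
# or let from_dict build the reference Doc from a legacy-style annotation dict:
eg = Example.from_dict(predicted, {"words": ["I", "like", "cats"], "tags": ["PRP", "VBP", "NNS"]})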
@ -291,144 +291,3 @@ def _parse_links(vocab, words, links, entities):
ent_kb_ids[i] = true_kb_ids[0]
return ent_kb_ids
class Example:
def get_aligned(self, field):
"""Return an aligned array for a token annotation field."""
if self.doc is None:
return self.token_annotation.get_field(field)
doc = self.doc
if field == "word":
return [token.orth_ for token in doc]
gold_values = self.token_annotation.get_field(field)
alignment = self.alignment
i2j_multi = alignment.i2j_multi
gold_to_cand = alignment.gold_to_cand
cand_to_gold = alignment.cand_to_gold
output = []
for i, gold_i in enumerate(cand_to_gold):
if doc[i].text.isspace():
output.append(None)
elif gold_i is None:
if i in i2j_multi:
output.append(gold_values[i2j_multi[i]])
else:
output.append(None)
else:
output.append(gold_values[gold_i])
return output
def split_sents(self):
""" Split the token annotations into multiple Examples based on
sent_starts and return a list of the new Examples"""
if not self.token_annotation.words:
return [self]
s_ids, s_words, s_tags, s_pos, s_morphs = [], [], [], [], []
s_lemmas, s_heads, s_deps, s_ents, s_sent_starts = [], [], [], [], []
s_brackets = []
sent_start_i = 0
t = self.token_annotation
split_examples = []
for i in range(len(t.words)):
if i > 0 and t.sent_starts[i] == 1:
split_examples.append(
Example(
doc=Doc(self.doc.vocab, words=s_words),
token_annotation=TokenAnnotation(
ids=s_ids,
words=s_words,
tags=s_tags,
pos=s_pos,
morphs=s_morphs,
lemmas=s_lemmas,
heads=s_heads,
deps=s_deps,
entities=s_ents,
sent_starts=s_sent_starts,
brackets=s_brackets,
),
doc_annotation=self.doc_annotation
)
)
s_ids, s_words, s_tags, s_pos, s_heads = [], [], [], [], []
s_deps, s_ents, s_morphs, s_lemmas = [], [], [], []
s_sent_starts, s_brackets = [], []
sent_start_i = i
s_ids.append(t.get_id(i))
s_words.append(t.get_word(i))
s_tags.append(t.get_tag(i))
s_pos.append(t.get_pos(i))
s_morphs.append(t.get_morph(i))
s_lemmas.append(t.get_lemma(i))
s_heads.append(t.get_head(i) - sent_start_i)
s_deps.append(t.get_dep(i))
s_ents.append(t.get_entity(i))
s_sent_starts.append(t.get_sent_start(i))
for b_end, b_label in t.brackets_by_start.get(i, []):
s_brackets.append((i - sent_start_i, b_end - sent_start_i, b_label))
i += 1
split_examples.append(
Example(
doc=Doc(self.doc.vocab, words=s_words),
token_annotation=TokenAnnotation(
ids=s_ids,
words=s_words,
tags=s_tags,
pos=s_pos,
morphs=s_morphs,
lemmas=s_lemmas,
heads=s_heads,
deps=s_deps,
entities=s_ents,
sent_starts=s_sent_starts,
brackets=s_brackets,
),
doc_annotation=self.doc_annotation
)
)
return split_examples
@classmethod
def to_example_objects(cls, examples, make_doc=None, keep_raw_text=False):
"""
Return a list of Example objects, from a variety of input formats.
make_doc needs to be provided when the examples contain text strings and keep_raw_text=False
"""
if isinstance(examples, Example):
return [examples]
if isinstance(examples, tuple):
examples = [examples]
converted_examples = []
for ex in examples:
if isinstance(ex, Example):
converted_examples.append(ex)
# convert string to Doc to Example
elif isinstance(ex, str):
if keep_raw_text:
converted_examples.append(Example(doc=ex))
else:
doc = make_doc(ex)
converted_examples.append(Example(doc=doc))
# convert tuples to Example
elif isinstance(ex, tuple) and len(ex) == 2:
doc, gold = ex
# convert string to Doc
if isinstance(doc, str) and not keep_raw_text:
doc = make_doc(doc)
converted_examples.append(Example.from_dict(gold, doc=doc))
# convert Doc to Example
elif isinstance(ex, Doc):
converted_examples.append(Example(doc=ex))
else:
converted_examples.append(ex)
return converted_examples
def _deprecated_get_gold(self, make_projective=False):
from ..syntax.gold_parse import get_parses_from_example
_, gold = get_parses_from_example(self, make_projective=make_projective)[0]
return gold

View File

@ -92,10 +92,11 @@ class Morphologizer(Tagger):
guesses = scores.argmax(axis=1)
known_labels = numpy.ones((scores.shape[0], 1), dtype="f")
for ex in examples:
gold = ex._deprecated_get_gold()
for i in range(len(gold.morphs)):
pos = gold.pos[i] if i < len(gold.pos) else ""
morph = gold.morphs[i]
pos_tags = ex.get_aligned("POS")
morphs = ex.get_aligned("MORPH")
for i in range(len(morphs)):
pos = pos_tags[i]
morph = morphs[i]
feats = Morphology.feats_to_dict(morph)
if pos:
feats["POS"] = pos
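Roughly what the aligned accessors above return for one example; the values and the MORPH feature string are illustrative, and feats_to_dict is the Morphology helper this module already uses:

pos_tags = ex.get_aligned("POS")    # e.g. ["PRON", "VERB", "NOUN"]
morphs = ex.get_aligned("MORPH")    # e.g. ["Case=Nom", "", "Number=Plur"]
Morphology.feats_to_dict("Case=Nom|Number=Sing")
# -> {"Case": "Nom", "Number": "Sing"}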

View File

@ -20,7 +20,7 @@ from .defaults import default_nel, default_senter
from .functions import merge_subtokens
from ..language import Language, component
from ..syntax import nonproj
from ..gold.new_example import NewExample as Example
from ..gold.example import Example
from ..attrs import POS, ID
from ..util import link_vectors_to_models, create_default_optimizer
from ..parts_of_speech import X

View File

@ -10,6 +10,7 @@ import numpy
from ..typedefs cimport hash_t, class_t
from .transition_system cimport TransitionSystem, Transition
from .stateclass cimport StateC, StateClass
from ..gold.example cimport Example
from ..errors import Errors
@ -125,7 +126,7 @@ cdef class ParserBeam(object):
beam.scores[i][j] = 0
beam.costs[i][j] = 0
def _set_costs(self, Beam beam, NewExample example, int follow_gold=False):
def _set_costs(self, Beam beam, Example example, int follow_gold=False):
for i in range(beam.size):
state = StateClass.borrow(<StateC*>beam.at(i))
if not state.is_final():

View File

@ -3,12 +3,11 @@ from cymem.cymem cimport Pool
from .stateclass cimport StateClass
from ..typedefs cimport weight_t, attr_t
from .transition_system cimport TransitionSystem, Transition
from .gold_parse cimport GoldParseC
cdef class ArcEager(TransitionSystem):
pass
cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) nogil
cdef weight_t arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) nogil
cdef weight_t push_cost(StateClass stcls, const void* _gold, int target) nogil
cdef weight_t arc_cost(StateClass stcls, const void* _gold, int head, int child) nogil
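Changing these signatures from const GoldParseC* to const void* presumably keeps the shared move_cost/label_cost function-pointer types usable for any gold-state struct (the NER system gets its own GoldNERStateC later in this commit); each implementation just casts back, as in this sketch (not a line from the diff):

cdef weight_t example_cost(StateClass stcls, const void* _gold, int target) nogil:
    gold = <const GoldParseStateC*>_gold   # cast back to the parser's own gold state
    return gold.n_kids_in_buffer[target]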

View File

@ -2,6 +2,7 @@
from cpython.ref cimport Py_INCREF
from cymem.cymem cimport Pool
from thinc.extra.search cimport Beam
from libc.stdint cimport int32_t
from collections import defaultdict, Counter
import json
@ -13,6 +14,7 @@ from ..tokens.doc cimport Doc, set_children_from_heads
from .stateclass cimport StateClass
from ._state cimport StateC
from .transition_system cimport move_cost_func_t, label_cost_func_t
from ..gold.example cimport Example
from ..errors import Errors
from .nonproj import is_nonproj_tree
@ -49,7 +51,7 @@ MOVE_NAMES[BREAK] = 'B'
cdef enum:
HEAD_ON_STACK = 0
HEAD_IN_STACK = 0
HEAD_IN_BUFFER
IS_SENT_START
HEAD_UNKNOWN
@ -60,12 +62,15 @@ cdef struct GoldParseStateC:
attr_t* labels
int32_t* heads
int32_t* n_kids_in_buffer
int32_t* n_kids_on_stack
int32_t* n_kids_in_stack
int32_t length
int32_t stride
cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example example) except *:
cdef GoldParseStateC gs
return gs
cdef int check_state_flag(char state_bits, char flag) nogil:
cdef int check_state_gold(char state_bits, char flag) nogil:
cdef char one = 1
return state_bits & (one << flag)
@ -78,27 +83,28 @@ cdef int set_state_flag(char state_bits, char flag, int value) nogil:
return state_bits & ~(one << flag)
cdef int is_head_on_stack(GoldParseStateC gold, int i) nogil:
return check_state_gold(gold.state_bits[i], HEAD_ON_STACK)
cdef int is_head_in_stack(const GoldParseStateC* gold, int i) nogil:
return check_state_gold(gold.state_bits[i], HEAD_IN_STACK)
cdef int is_head_in_buffer(GoldParseStateC gold, int i) nogil:
cdef int is_head_in_buffer(const GoldParseStateC* gold, int i) nogil:
return check_state_gold(gold.state_bits[i], HEAD_IN_BUFFER)
cdef int is_sent_start(GoldParseStateC gold, int i) nogil:
cdef int is_sent_start(const GoldParseStateC* gold, int i) nogil:
return check_state_gold(gold.state_bits[i], IS_SENT_START)
cdef int is_head_unknown(GoldParseStateC gold, int i) nogil:
cdef int is_head_unknown(const GoldParseStateC* gold, int i) nogil:
return check_state_gold(gold.state_bits[i], HEAD_UNKNOWN)
# Helper functions for the arc-eager oracle
cdef weight_t push_cost(StateClass stcls, const GoldParseStateC* gold, int target) nogil:
cdef weight_t push_cost(StateClass stcls, const void* _gold, int target) nogil:
gold = <const GoldParseStateC*>_gold
cdef weight_t cost = 0
if is_head_in_stack(gold[0], target):
if is_head_in_stack(gold, target):
cost += 1
cost += gold.n_kids_in_buffer[target]
if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0:
@ -106,9 +112,10 @@ cdef weight_t push_cost(StateClass stcls, const GoldParseStateC* gold, int targe
return cost
cdef weight_t pop_cost(StateClass stcls, const GoldParseStateC* gold, int target) nogil:
cdef weight_t pop_cost(StateClass stcls, const void* _gold, int target) nogil:
gold = <const GoldParseStateC*>_gold
cdef weight_t cost = 0
if is_head_in_buffer(gold[0], target):
if is_head_in_buffer(gold, target):
cost += 1
cost += gold[0].n_kids_in_buffer[target]
if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0:
@ -116,7 +123,8 @@ cdef weight_t pop_cost(StateClass stcls, const GoldParseStateC* gold, int target
return cost
cdef weight_t arc_cost(StateClass stcls, const GoldParseStateC* gold, int head, int child) nogil:
cdef weight_t arc_cost(StateClass stcls, const void* _gold, int head, int child) nogil:
gold = <const GoldParseStateC*>_gold
if arc_is_gold(gold, head, child):
return 0
elif stcls.H(child) == gold.heads[child]:
@ -129,7 +137,7 @@ cdef weight_t arc_cost(StateClass stcls, const GoldParseStateC* gold, int head,
cdef bint arc_is_gold(const GoldParseStateC* gold, int head, int child) nogil:
if is_head_unknown(gold[0], child):
if is_head_unknown(gold, child):
return True
elif gold.heads[child] == head:
return True
@ -138,7 +146,7 @@ cdef bint arc_is_gold(const GoldParseStateC* gold, int head, int child) nogil:
cdef bint label_is_gold(const GoldParseStateC* gold, int head, int child, attr_t label) nogil:
if is_head_unknown(gold[0], child):
if is_head_unknown(gold, child):
return True
elif label == 0:
return True
@ -149,7 +157,7 @@ cdef bint label_is_gold(const GoldParseStateC* gold, int head, int child, attr_t
cdef bint _is_gold_root(const GoldParseStateC* gold, int word) nogil:
return gold.heads[word] == word or is_head_unknown(gold[0], word)
return gold.heads[word] == word or is_head_unknown(gold, word)
cdef class Shift:
@ -169,11 +177,12 @@ cdef class Shift:
return Shift.move_cost(st, gold) + Shift.label_cost(st, gold, label)
@staticmethod
cdef inline weight_t move_cost(StateClass s, const GoldParseStateC* gold) nogil:
cdef inline weight_t move_cost(StateClass s, const void* _gold) nogil:
gold = <const GoldParseStateC*>_gold
return push_cost(s, gold, s.B(0))
@staticmethod
cdef inline weight_t label_cost(StateClass s, const GoldParseStateC* gold, attr_t label) nogil:
cdef inline weight_t label_cost(StateClass s, const void* _gold, attr_t label) nogil:
return 0
@ -196,14 +205,15 @@ cdef class Reduce:
return Reduce.move_cost(s, gold) + Reduce.label_cost(s, gold, label)
@staticmethod
cdef inline weight_t move_cost(StateClass st, const GoldParseStateC* gold) nogil:
cdef inline weight_t move_cost(StateClass st, const void* _gold) nogil:
gold = <const GoldParseStateC*>_gold
s0 = st.S(0)
cost = pop_cost(st, gold, s0)
return_to_buffer = not st.has_head(s0)
if return_to_buffer:
# Decrement cost for the arcs we save, as we'll be putting this
# back to the buffer
if is_head_in_stack(gold[0], s0):
if is_head_in_stack(gold, s0):
cost -= 1
cost -= gold.n_kids_in_stack[s0]
if Break.is_valid(st.c, 0) and Break.move_cost(st, gold) == 0:
@ -211,7 +221,7 @@ cdef class Reduce:
return cost
@staticmethod
cdef inline weight_t label_cost(StateClass s, const GoldParseStateC* gold, attr_t label) nogil:
cdef inline weight_t label_cost(StateClass s, const void* gold, attr_t label) nogil:
return 0
@ -230,12 +240,13 @@ cdef class LeftArc:
st.fast_forward()
@staticmethod
cdef inline weight_t cost(StateClass s, const void* gold, attr_t label) nogil:
cdef inline weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
gold = <const GoldParseStateC*>_gold
return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label)
@staticmethod
cdef inline weight_t move_cost(StateClass s, const GoldParseStateC* gold) nogil:
cdef inline weight_t move_cost(StateClass s, const void* _gold) nogil:
gold = <const GoldParseStateC*>_gold
if arc_is_gold(gold, s.S(0), s.B(0)):
return 0
elif s.c.shifted[s.B(0)]:
@ -244,7 +255,44 @@ cdef class LeftArc:
return push_cost(s, gold, s.B(0)) + arc_cost(s, gold, s.S(0), s.B(0))
@staticmethod
cdef weight_t label_cost(StateClass s, const GoldParseStateC* gold, attr_t label) nogil:
cdef weight_t label_cost(StateClass s, const void* _gold, attr_t label) nogil:
gold = <const GoldParseStateC*>_gold
return arc_is_gold(gold, s.S(0), s.B(0)) and not label_is_gold(gold, s.S(0), s.B(0), label)
cdef class RightArc:
@staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil:
# If there's (perhaps partial) parse pre-set, don't allow cycle.
if label == SUBTOK_LABEL and st.S(0) != (st.B(0)-1):
return 0
sent_start = st._sent[st.B_(0).l_edge].sent_start
return sent_start != 1 and st.H(st.S(0)) != st.B(0)
@staticmethod
cdef int transition(StateC* st, attr_t label) nogil:
st.add_arc(st.S(0), st.B(0), label)
st.push()
st.fast_forward()
@staticmethod
cdef inline weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
gold = <const GoldParseStateC*>_gold
return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label)
@staticmethod
cdef inline weight_t move_cost(StateClass s, const void* _gold) nogil:
gold = <const GoldParseStateC*>_gold
if arc_is_gold(gold, s.S(0), s.B(0)):
return 0
elif s.c.shifted[s.B(0)]:
return push_cost(s, gold, s.B(0))
else:
return push_cost(s, gold, s.B(0)) + arc_cost(s, gold, s.S(0), s.B(0))
@staticmethod
cdef weight_t label_cost(StateClass s, const void* _gold, attr_t label) nogil:
gold = <const GoldParseStateC*>_gold
return arc_is_gold(gold, s.S(0), s.B(0)) and not label_is_gold(gold, s.S(0), s.B(0), label)
@ -276,13 +324,14 @@ cdef class Break:
return Break.move_cost(s, gold) + Break.label_cost(s, gold, label)
@staticmethod
cdef inline weight_t move_cost(StateClass s, const GoldParseStateC* gold) nogil:
cdef inline weight_t move_cost(StateClass s, const void* _gold) nogil:
gold = <const GoldParseStateC*>_gold
cdef weight_t cost = 0
cdef int i, j, S_i, B_i
for i in range(s.stack_depth()):
S_i = s.S(i)
cost += gold.n_kids_in_buffer[S_i]
if is_head_in_buffer(gold[0], S_i):
if is_head_in_buffer(gold, S_i):
cost += 1
# Check for sentence boundary --- if it's here, we can't have any deps
# between stack and buffer, so rest of action is irrelevant.
@ -294,15 +343,15 @@ cdef class Break:
return cost + 1
@staticmethod
cdef inline weight_t label_cost(StateClass s, const GoldParseStateC* gold, attr_t label) nogil:
cdef inline weight_t label_cost(StateClass s, const void* gold, attr_t label) nogil:
return 0
cdef int _get_root(int word, const GoldParseStateC* gold) nogil:
if is_head_unset(gold[0], word):
if is_head_unknown(gold, word):
return -1
while gold.heads[word] != word and word >= 0:
word = gold.heads[word]
if is_head_unset(gold[0], word):
if is_head_unknown(gold, word):
return -1
else:
return word
@ -378,7 +427,7 @@ cdef class ArcEager(TransitionSystem):
def action_types(self):
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK)
def get_cost(self, StateClass state, NewExample gold, action):
def get_cost(self, StateClass state, Example gold, action):
raise NotImplementedError
def transition(self, StateClass state, action):
@ -505,7 +554,7 @@ cdef class ArcEager(TransitionSystem):
output[i] = is_valid[self.c[i].move]
cdef int set_costs(self, int* is_valid, weight_t* costs,
StateClass stcls, NewExample example) except -1:
StateClass stcls, Example example) except -1:
cdef Pool mem = Pool()
gold_state = create_gold_state(mem, stcls, example)
cdef int i, move
@ -527,8 +576,8 @@ cdef class ArcEager(TransitionSystem):
label_cost_funcs[RIGHT] = RightArc.label_cost
label_cost_funcs[BREAK] = Break.label_cost
cdef attr_t* labels = gold.c.labels
cdef int* heads = gold.c.heads
cdef attr_t* labels = gold_state.labels
cdef int32_t* heads = gold_state.heads
n_gold = 0
for i in range(self.n_moves):
@ -537,18 +586,18 @@ cdef class ArcEager(TransitionSystem):
move = self.c[i].move
label = self.c[i].label
if move_costs[move] == 9000:
move_costs[move] = move_cost_funcs[move](stcls, gold_state)
costs[i] = move_costs[move] + label_cost_funcs[move](stcls, gold_state, label)
move_costs[move] = move_cost_funcs[move](stcls, &gold_state)
costs[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold_state, label)
n_gold += costs[i] <= 0
else:
is_valid[i] = False
costs[i] = 9000
if n_gold < 1:
# Check projectivity --- leading cause
if is_nonproj_tree(gold.heads):
if is_nonproj_tree(example.get_field("HEAD")):
raise ValueError(Errors.E020)
else:
failure_state = stcls.print_state(gold.words)
failure_state = stcls.print_state([t.text for t in example])
raise ValueError(Errors.E021.format(n_actions=self.n_moves,
state=failure_state))

View File

@ -1,39 +0,0 @@
from cymem.cymem cimport Pool
from .transition_system cimport Transition
from ..typedefs cimport attr_t
cdef struct GoldParseC:
int* tags
int* heads
int* has_dep
int* sent_start
attr_t* labels
int** brackets
Transition* ner
cdef class GoldParse:
cdef Pool mem
cdef GoldParseC c
cdef readonly object orig
cdef int length
cdef public int loss
cdef public list words
cdef public list tags
cdef public list pos
cdef public list morphs
cdef public list lemmas
cdef public list sent_starts
cdef public list heads
cdef public list labels
cdef public dict orths
cdef public list ner
cdef public dict brackets
cdef public dict cats
cdef public dict links
cdef readonly list cand_to_gold
cdef readonly list gold_to_cand

View File

@ -1,346 +0,0 @@
# cython: profile=True
import re
import random
import numpy
import tempfile
import shutil
import itertools
from pathlib import Path
import srsly
import warnings
from .. import util
from . import nonproj
from ..tokens import Doc, Span
from ..errors import Errors, AlignmentError, Warnings
from ..gold.annotation import TokenAnnotation
from ..gold.iob_utils import offsets_from_biluo_tags, biluo_tags_from_offsets
from ..gold.align import align
punct_re = re.compile(r"\W")
def is_punct_label(label):
return label == "P" or label.lower() == "punct"
def get_parses_from_example(
example, merge=True, vocab=None, make_projective=True, ignore_misaligned=False
):
"""Return a list of (doc, GoldParse) objects.
If merge is set to True, keep all Token annotations as one big list."""
# merge == do not modify Example
if merge:
examples = [example]
else:
# not merging: one GoldParse per sentence, defining docs with the words
# from each sentence
examples = example.split_sents()
outputs = []
for eg in examples:
eg_dict = eg.to_dict()
try:
gp = GoldParse.from_annotation(
eg.predicted,
eg_dict["doc_annotation"],
eg_dict["token_annotation"],
make_projective=make_projective
)
except AlignmentError:
if ignore_misaligned:
gp = None
else:
raise
outputs.append((eg.predicted, gp))
return outputs
cdef class GoldParse:
"""Collection for training annotations.
DOCS: https://spacy.io/api/goldparse
"""
@classmethod
def from_annotation(cls, doc, doc_annotation, token_annotation, make_projective=False):
return cls(
doc,
words=token_annotation["words"],
tags=token_annotation["tags"],
pos=token_annotation["pos"],
morphs=token_annotation["morphs"],
lemmas=token_annotation["lemmas"],
heads=token_annotation["heads"],
deps=token_annotation["deps"],
entities=token_annotation["entities"],
sent_starts=token_annotation["sent_starts"],
cats=doc_annotation["cats"],
links=doc_annotation["links"],
make_projective=make_projective
)
def get_token_annotation(self):
ids = None
if self.words:
ids = list(range(len(self.words)))
return TokenAnnotation(ids=ids, words=self.words, tags=self.tags,
pos=self.pos, morphs=self.morphs,
lemmas=self.lemmas, heads=self.heads,
deps=self.labels, entities=self.ner,
sent_starts=self.sent_starts)
def __init__(self, doc, words=None, tags=None, pos=None, morphs=None,
lemmas=None, heads=None, deps=None, entities=None,
sent_starts=None, make_projective=False, cats=None,
links=None):
"""Create a GoldParse. The fields will not be initialized if len(doc) is zero.
doc (Doc): The document the annotations refer to.
words (iterable): A sequence of unicode word strings.
tags (iterable): A sequence of strings, representing tag annotations.
pos (iterable): A sequence of strings, representing UPOS annotations.
morphs (iterable): A sequence of strings, representing morph
annotations.
lemmas (iterable): A sequence of strings, representing lemma
annotations.
heads (iterable): A sequence of integers, representing syntactic
head offsets.
deps (iterable): A sequence of strings, representing the syntactic
relation types.
entities (iterable): A sequence of named entity annotations, either as
BILUO tag strings, or as `(start_char, end_char, label)` tuples,
representing the entity positions.
sent_starts (iterable): A sequence of sentence position tags, 1 for
the first word in a sentence, 0 for all others.
cats (dict): Labels for text classification. Each key in the dictionary
may be a string or an int, or a `(start_char, end_char, label)`
tuple, indicating that the label is applied to only part of the
document (usually a sentence). Unlike entity annotations, label
annotations can overlap, i.e. a single word can be covered by
multiple labelled spans. The TextCategorizer component expects
true examples of a label to have the value 1.0, and negative
examples of a label to have the value 0.0. Labels not in the
dictionary are treated as missing - the gradient for those labels
will be zero.
links (dict): A dict with `(start_char, end_char)` keys,
and the values being dicts with kb_id:value entries,
representing the external IDs in a knowledge base (KB)
mapped to either 1.0 or 0.0, indicating positive and
negative examples respectively.
RETURNS (GoldParse): The newly constructed object.
"""
self.mem = Pool()
self.loss = 0
self.length = len(doc)
self.cats = {} if cats is None else dict(cats)
self.links = {} if links is None else dict(links)
# temporary doc for aligning entity annotation
entdoc = None
# avoid allocating memory if the doc does not contain any tokens
if self.length == 0:
self.words = []
self.tags = []
self.heads = []
self.labels = []
self.ner = []
self.morphs = []
# set a minimal orig so that the scorer can score an empty doc
self.orig = TokenAnnotation(ids=[])
else:
if not words:
words = [token.text for token in doc]
if not tags:
tags = [None for _ in words]
if not pos:
pos = [None for _ in words]
if not morphs:
morphs = [None for _ in words]
if not lemmas:
lemmas = [None for _ in words]
if not heads:
heads = [None for _ in words]
if not deps:
deps = [None for _ in words]
if not sent_starts:
sent_starts = [None for _ in words]
if entities is None:
entities = ["-" for _ in words]
elif len(entities) == 0:
entities = ["O" for _ in words]
else:
# Translate the None values to '-', to make processing easier.
# See Issue #2603
entities = [(ent if ent is not None else "-") for ent in entities]
if not isinstance(entities[0], str):
# Assume we have entities specified by character offset.
# Create a temporary Doc corresponding to provided words
# (to preserve gold tokenization) and text (to preserve
# character offsets).
entdoc_words, entdoc_spaces = util.get_words_and_spaces(words, doc.text)
entdoc = Doc(doc.vocab, words=entdoc_words, spaces=entdoc_spaces)
entdoc_entities = biluo_tags_from_offsets(entdoc, entities)
# There may be some additional whitespace tokens in the
# temporary doc, so check that the annotations align with
# the provided words while building a list of BILUO labels.
entities = []
words_offset = 0
for i in range(len(entdoc_words)):
if words[i + words_offset] == entdoc_words[i]:
entities.append(entdoc_entities[i])
else:
words_offset -= 1
if len(entities) != len(words):
warnings.warn(Warnings.W029.format(text=doc.text))
entities = ["-" for _ in words]
# These are filled by the tagger/parser/entity recogniser
self.c.tags = <int*>self.mem.alloc(len(doc), sizeof(int))
self.c.heads = <int*>self.mem.alloc(len(doc), sizeof(int))
self.c.labels = <attr_t*>self.mem.alloc(len(doc), sizeof(attr_t))
self.c.has_dep = <int*>self.mem.alloc(len(doc), sizeof(int))
self.c.sent_start = <int*>self.mem.alloc(len(doc), sizeof(int))
self.c.ner = <Transition*>self.mem.alloc(len(doc), sizeof(Transition))
self.words = [None] * len(doc)
self.tags = [None] * len(doc)
self.pos = [None] * len(doc)
self.morphs = [None] * len(doc)
self.lemmas = [None] * len(doc)
self.heads = [None] * len(doc)
self.labels = [None] * len(doc)
self.ner = [None] * len(doc)
self.sent_starts = [None] * len(doc)
# This needs to be done before we align the words
if make_projective and any(heads) and any(deps) :
heads, deps = nonproj.projectivize(heads, deps)
# Do many-to-one alignment for misaligned tokens.
# If we over-segment, we'll have one gold word that covers a sequence
# of predicted words
# If we under-segment, we'll have one predicted word that covers a
# sequence of gold words.
# If we "mis-segment", we'll have a sequence of predicted words covering
# a sequence of gold words. That's many-to-many -- we don't do that
# except for NER spans where the start and end can be aligned.
cost, i2j, j2i, i2j_multi, j2i_multi = align([t.orth_ for t in doc], words)
self.cand_to_gold = [(j if j >= 0 else None) for j in i2j]
self.gold_to_cand = [(i if i >= 0 else None) for i in j2i]
self.orig = TokenAnnotation(ids=list(range(len(words))),
words=words, tags=tags, pos=pos, morphs=morphs,
lemmas=lemmas, heads=heads, deps=deps, entities=entities,
sent_starts=sent_starts, brackets=[])
for i, gold_i in enumerate(self.cand_to_gold):
if doc[i].text.isspace():
self.words[i] = doc[i].text
self.tags[i] = "_SP"
self.pos[i] = "SPACE"
self.morphs[i] = None
self.lemmas[i] = None
self.heads[i] = None
self.labels[i] = None
self.ner[i] = None
self.sent_starts[i] = 0
if gold_i is None:
if i in i2j_multi:
self.words[i] = words[i2j_multi[i]]
self.tags[i] = tags[i2j_multi[i]]
self.pos[i] = pos[i2j_multi[i]]
self.morphs[i] = morphs[i2j_multi[i]]
self.lemmas[i] = lemmas[i2j_multi[i]]
self.sent_starts[i] = sent_starts[i2j_multi[i]]
is_last = i2j_multi[i] != i2j_multi.get(i+1)
# Set next word in multi-token span as head, until last
if not is_last:
self.heads[i] = i+1
self.labels[i] = "subtok"
else:
head_i = heads[i2j_multi[i]]
if head_i:
self.heads[i] = self.gold_to_cand[head_i]
self.labels[i] = deps[i2j_multi[i]]
ner_tag = entities[i2j_multi[i]]
# Assign O/- for many-to-one O/- NER tags
if ner_tag in ("O", "-"):
self.ner[i] = ner_tag
else:
self.words[i] = words[gold_i]
self.tags[i] = tags[gold_i]
self.pos[i] = pos[gold_i]
self.morphs[i] = morphs[gold_i]
self.lemmas[i] = lemmas[gold_i]
self.sent_starts[i] = sent_starts[gold_i]
if heads[gold_i] is None:
self.heads[i] = None
else:
self.heads[i] = self.gold_to_cand[heads[gold_i]]
self.labels[i] = deps[gold_i]
self.ner[i] = entities[gold_i]
# Assign O/- for one-to-many O/- NER tags
for j, cand_j in enumerate(self.gold_to_cand):
if cand_j is None:
if j in j2i_multi:
i = j2i_multi[j]
ner_tag = entities[j]
if ner_tag in ("O", "-"):
self.ner[i] = ner_tag
# If there is entity annotation and some tokens remain unaligned,
# align all entities at the character level to account for all
# possible token misalignments within the entity spans
if any([e not in ("O", "-") for e in entities]) and None in self.ner:
# If the temporary entdoc wasn't created above, initialize it
if not entdoc:
entdoc_words, entdoc_spaces = util.get_words_and_spaces(words, doc.text)
entdoc = Doc(doc.vocab, words=entdoc_words, spaces=entdoc_spaces)
# Get offsets based on gold words and BILUO entities
entdoc_offsets = offsets_from_biluo_tags(entdoc, entities)
aligned_offsets = []
aligned_spans = []
# Filter offsets to identify those that align with doc tokens
for offset in entdoc_offsets:
span = doc.char_span(offset[0], offset[1])
if span and not span.text.isspace():
aligned_offsets.append(offset)
aligned_spans.append(span)
# Convert back to BILUO for doc tokens and assign NER for all
# aligned spans
biluo_tags = biluo_tags_from_offsets(doc, aligned_offsets, missing=None)
for span in aligned_spans:
for i in range(span.start, span.end):
self.ner[i] = biluo_tags[i]
# Prevent whitespace that isn't within entities from being tagged as
# an entity.
for i in range(len(self.ner)):
if self.tags[i] == "_SP":
prev_ner = self.ner[i-1] if i >= 1 else None
next_ner = self.ner[i+1] if (i+1) < len(self.ner) else None
if prev_ner == "O" or next_ner == "O":
self.ner[i] = "O"
cycle = nonproj.contains_cycle(self.heads)
if cycle is not None:
raise ValueError(Errors.E069.format(cycle=cycle,
cycle_tokens=" ".join([f"'{self.words[tok_id]}'" for tok_id in cycle]),
doc_tokens=" ".join(words[:50])))
def __len__(self):
"""Get the number of gold-standard tokens.
RETURNS (int): The number of gold-standard tokens.
"""
return self.length
@property
def is_projective(self):
"""Whether the provided syntactic annotations form a projective
dependency tree.
"""
return not nonproj.is_nonproj_tree(self.heads)

View File

@ -11,7 +11,6 @@ from ..lexeme cimport Lexeme
from ..attrs cimport IS_SPACE
from ..errors import Errors
from .gold_parse cimport GoldParseC
cdef enum:
@ -35,6 +34,9 @@ MOVE_NAMES[OUT] = 'O'
MOVE_NAMES[ISNT] = 'x'
cdef struct GoldNERStateC:
Transition* ner
cdef do_func_t[N_MOVES] do_funcs
@ -293,7 +295,7 @@ cdef class Begin:
@staticmethod
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
gold = <GoldParseC*>_gold
gold = <GoldNERStateC*>_gold
cdef int g_act = gold.ner[s.B(0)].move
cdef attr_t g_tag = gold.ner[s.B(0)].label
@ -357,7 +359,7 @@ cdef class In:
@staticmethod
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
gold = <GoldParseC*>_gold
gold = <GoldNERStateC*>_gold
move = IN
cdef int next_act = gold.ner[s.B(1)].move if s.B(1) >= 0 else OUT
cdef int g_act = gold.ner[s.B(0)].move
@ -424,7 +426,7 @@ cdef class Last:
@staticmethod
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
gold = <GoldParseC*>_gold
gold = <GoldNERStateC*>_gold
move = LAST
cdef int g_act = gold.ner[s.B(0)].move
@ -493,7 +495,7 @@ cdef class Unit:
@staticmethod
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
gold = <GoldParseC*>_gold
gold = <GoldNERStateC*>_gold
cdef int g_act = gold.ner[s.B(0)].move
cdef attr_t g_tag = gold.ner[s.B(0)].label
@ -534,7 +536,7 @@ cdef class Out:
@staticmethod
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
gold = <GoldParseC*>_gold
gold = <GoldNERStateC*>_gold
cdef int g_act = gold.ner[s.B(0)].move
cdef attr_t g_tag = gold.ner[s.B(0)].label

View File

@ -29,8 +29,8 @@ from .stateclass cimport StateClass
from ._state cimport StateC
from .transition_system cimport Transition
from . cimport _beam_utils
from ..gold.example cimport Example
from ..gold import Example
from ..util import link_vectors_to_models, create_default_optimizer, registry
from ..compat import copy_array
from ..errors import Errors, Warnings
@ -39,6 +39,10 @@ from . import _beam_utils
from . import nonproj
def get_parses_from_example(example, merge=False, vocab=None):
# TODO: This is just a temporary shim to make the refactor easier.
return [(example.predicted, example)]
cdef class Parser:
"""
Base class of the DependencyParser and EntityRecognizer.
@ -572,7 +576,7 @@ cdef class Parser:
def get_batch_loss(self, states, examples, float[:, ::1] scores, losses):
cdef StateClass state
cdef NewExample example
cdef Example example
cdef Pool mem = Pool()
cdef int i

View File

@ -5,7 +5,7 @@ from ..structs cimport TokenC
from ..strings cimport StringStore
from .stateclass cimport StateClass
from ._state cimport StateC
from ..gold.new_example cimport NewExample
from ..gold.example cimport Example
cdef struct Transition:
@ -54,4 +54,4 @@ cdef class TransitionSystem:
cdef int set_valid(self, int* output, const StateC* st) nogil
cdef int set_costs(self, int* is_valid, weight_t* costs,
StateClass state, NewExample example) except -1
StateClass state, Example example) except -1

View File

@ -87,14 +87,14 @@ cdef class TransitionSystem:
beams.append(beam)
return beams
def get_oracle_sequence(self, NewExample example):
def get_oracle_sequence(self, Example example):
cdef Pool mem = Pool()
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
assert self.n_moves > 0
costs = <float*>mem.alloc(self.n_moves, sizeof(float))
is_valid = <int*>mem.alloc(self.n_moves, sizeof(int))
cdef StateClass state = StateClass(doc, offset=0)
cdef StateClass state = StateClass(example.predicted, offset=0)
self.initialize_state(state.c)
history = []
while not state.is_final():
@ -148,18 +148,8 @@ cdef class TransitionSystem:
is_valid[i] = self.c[i].is_valid(st, self.c[i].label)
cdef int set_costs(self, int* is_valid, weight_t* costs,
StateClass stcls, NewExample example) except -1:
cdef int i
self.set_valid(is_valid, stcls.c)
cdef int n_gold = 0
for i in range(self.n_moves):
if is_valid[i]:
costs[i] = self.c[i].get_cost(stcls, &gold.c, self.c[i].label)
n_gold += costs[i] <= 0
else:
costs[i] = 9000
if n_gold <= 0:
raise ValueError(Errors.E024)
StateClass stcls, Example example) except -1:
raise NotImplementedError
def get_class_name(self, int clas):
act = self.c[clas]
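A hedged sketch of the calling pattern the new signatures in this file imply; moves would be a concrete transition system such as the ArcEager shown earlier, since the base-class set_costs now simply raises NotImplementedError:

history = moves.get_oracle_sequence(example)   # example is a spacy.gold.example.Example
# each item in history is the class index of the gold-cost action taken at that state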

View File

@ -2,7 +2,7 @@ from spacy.errors import AlignmentError
from spacy.gold import biluo_tags_from_offsets, offsets_from_biluo_tags
from spacy.gold import spans_from_biluo_tags, iob_to_biluo, align
from spacy.gold import GoldCorpus, docs_to_json, DocAnnotation
from spacy.gold.new_example import NewExample as Example
from spacy.gold.example import Example
from spacy.lang.en import English
from spacy.syntax.nonproj import is_nonproj_tree
from spacy.syntax.gold_parse import GoldParse, get_parses_from_example

View File

@ -1,5 +1,5 @@
import pytest
from spacy.gold.new_example import NewExample as Example
from spacy.gold.example import Example
from spacy.tokens import Doc
from spacy.vocab import Vocab