This commit is contained in:
svlandeg 2020-06-19 11:31:01 +02:00
parent c705a28438
commit 25b0674320
4 changed files with 7 additions and 19 deletions

View File

@ -646,20 +646,6 @@ class Language(object):
sgd(W, dW, key=key)
return losses
def preprocess_gold(self, examples):
"""Can be called before training to pre-process gold data. By default,
it handles nonprojectivity and adds missing tags to the tag map.
examples (iterable): `Example` objects.
YIELDS (tuple): `Example` objects.
"""
# TODO: This is deprecated right?
for name, proc in self.pipeline:
if hasattr(proc, "preprocess_gold"):
examples = proc.preprocess_gold(examples)
for eg in examples:
yield eg
def begin_training(self, get_examples=None, sgd=None, component_cfg=None, **cfg):
"""Allocate models, pre-process training data and acquire a trainer and
optimizer. Used as a contextmanager.

View File

@ -459,9 +459,9 @@ cdef class ArcEager(TransitionSystem):
actions[RIGHT][label] = 1
actions[REDUCE][label] = 1
for example in kwargs.get('gold_parses', []):
heads, labels = nonproj.projectivize(example.token_annotation.heads,
example.token_annotation.deps)
for child, head, label in zip(example.token_annotation.ids, heads, labels):
heads, labels = nonproj.projectivize(example.get_aligned("HEAD"),
example.get_aligned("DEP"))
for child, head, label in zip(example.get_aligned("ID"), heads, labels):
if label.upper() == 'ROOT' :
label = 'ROOT'
if head == child:

View File

@ -78,8 +78,8 @@ def is_decorated(label):
def count_decorated_labels(gold_data):
freqs = {}
for example in gold_data:
proj_heads, deco_deps = projectivize(example.token_annotation.heads,
example.token_annotation.deps)
proj_heads, deco_deps = projectivize(example.get_aligned("HEAD"),
example.get_aligned("DEP"))
# set the label to ROOT for each root dependent
deco_deps = ['ROOT' if head == i else deco_deps[i]
for i, head in enumerate(proj_heads)]

View File

@ -497,6 +497,8 @@ def test_split_sents(merged_dict):
Doc(nlp.vocab, words=merged_dict["words"], spaces=merged_dict["spaces"]),
merged_dict
)
assert example.text == "Hi there everyone It is just me"
assert len(get_parses_from_example(
example,
merge=False,