formatting

This commit is contained in:
svlandeg 2021-01-12 17:28:41 +01:00
parent a581d82f33
commit 5b598bd1d5
3 changed files with 7 additions and 5 deletions

View File

@ -9,6 +9,7 @@ from ...typedefs cimport hash_t, attr_t
from ...strings cimport hash_string from ...strings cimport hash_string
from ...structs cimport TokenC from ...structs cimport TokenC
from ...tokens.doc cimport Doc, set_children_from_heads from ...tokens.doc cimport Doc, set_children_from_heads
from ...tokens.token import MISSING_DEP_
from ...training.example cimport Example from ...training.example cimport Example
from .stateclass cimport StateClass from .stateclass cimport StateClass
from ._state cimport StateC, ArcC from ._state cimport StateC, ArcC
@ -195,7 +196,8 @@ cdef class ArcEagerGold:
def __init__(self, ArcEager moves, StateClass stcls, Example example): def __init__(self, ArcEager moves, StateClass stcls, Example example):
self.mem = Pool() self.mem = Pool()
heads, labels = example.get_aligned_parse(projectivize=True) heads, labels = example.get_aligned_parse(projectivize=True)
labels = [example.x.vocab.strings.add(label) if label is not None else 0 for label in labels] labels = [label if label is not None else MISSING_DEP_ for label in labels]
labels = [example.x.vocab.strings.add(label) for label in labels]
sent_starts = example.get_aligned_sent_starts() sent_starts = example.get_aligned_sent_starts()
assert len(heads) == len(labels) == len(sent_starts), (len(heads), len(labels), len(sent_starts)) assert len(heads) == len(labels) == len(sent_starts), (len(heads), len(labels), len(sent_starts))
self.c = create_gold_state(self.mem, stcls.c, heads, labels, sent_starts) self.c = create_gold_state(self.mem, stcls.c, heads, labels, sent_starts)

View File

@ -259,8 +259,8 @@ def test_missing_head_dep(en_vocab):
deps = ["nsubj", "ROOT", "dobj", "cc", "conj", None] deps = ["nsubj", "ROOT", "dobj", "cc", "conj", None]
words = ["I", "like", "London", "and", "Berlin", "."] words = ["I", "like", "London", "and", "Berlin", "."]
doc = Doc(en_vocab, words=words, heads=heads, deps=deps) doc = Doc(en_vocab, words=words, heads=heads, deps=deps)
pred_has_heads = [t.has_head() for t in doc] pred_has_heads = [t.has_head() for t in doc]
pred_deps = [t.dep_ for t in doc] pred_deps = [t.dep_ for t in doc]
assert pred_has_heads == [True, True, True, True, True, False] assert pred_has_heads == [True, True, True, True, True, False]
assert pred_deps == ["nsubj", "ROOT", "dobj", "cc", "conj", MISSING_DEP_] assert pred_deps == ["nsubj", "ROOT", "dobj", "cc", "conj", MISSING_DEP_]
example = Example.from_dict(doc, {"heads": heads, "deps": deps}) example = Example.from_dict(doc, {"heads": heads, "deps": deps})
@ -271,4 +271,4 @@ def test_missing_head_dep(en_vocab):
assert ref_has_heads == [True, True, True, True, True, False] assert ref_has_heads == [True, True, True, True, True, False]
aligned_heads, aligned_deps = example.get_aligned_parse(projectivize=True) aligned_heads, aligned_deps = example.get_aligned_parse(projectivize=True)
assert aligned_heads[5] == ref_heads[5] assert aligned_heads[5] == ref_heads[5]
assert aligned_deps[5] == MISSING_DEP_ assert aligned_deps[5] == MISSING_DEP_

View File

@ -265,7 +265,7 @@ def test_Example_from_dict_sentences():
assert len(list(ex.reference.sents)) == 1 assert len(list(ex.reference.sents)) == 1
def test_Example_from_dict_with_parse(): def test_Example_missing_deps():
vocab = Vocab() vocab = Vocab()
words = ["I", "like", "London", "and", "Berlin", "."] words = ["I", "like", "London", "and", "Berlin", "."]
deps = ["nsubj", "ROOT", "dobj", "cc", "conj", "punct"] deps = ["nsubj", "ROOT", "dobj", "cc", "conj", "punct"]