mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Fix bug in Parser.labels and add test (#4275)
This commit is contained in:
parent
af93997993
commit
8ebc3711dc
|
@ -1063,7 +1063,7 @@ cdef class DependencyParser(Parser):
|
|||
@property
|
||||
def labels(self):
|
||||
# Get the labels from the model by looking at the available moves
|
||||
return tuple(set(move.split("-")[1] for move in self.move_names))
|
||||
return tuple(set(move.split("-")[1] for move in self.move_names if "-" in move))
|
||||
|
||||
|
||||
cdef class EntityRecognizer(Parser):
|
||||
|
|
|
@ -68,3 +68,20 @@ def test_add_label_deserializes_correctly():
|
|||
assert ner1.moves.n_moves == ner2.moves.n_moves
|
||||
for i in range(ner1.moves.n_moves):
|
||||
assert ner1.moves.get_class_name(i) == ner2.moves.get_class_name(i)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"pipe_cls,n_moves", [(DependencyParser, 5), (EntityRecognizer, 4)]
|
||||
)
|
||||
def test_add_label_get_label(pipe_cls, n_moves):
|
||||
"""Test that added labels are returned correctly. This test was added to
|
||||
test for a bug in DependencyParser.labels that'd cause it to fail when
|
||||
splitting the move names.
|
||||
"""
|
||||
labels = ["A", "B", "C"]
|
||||
pipe = pipe_cls(Vocab())
|
||||
for label in labels:
|
||||
pipe.add_label(label)
|
||||
assert len(pipe.move_names) == len(labels) * n_moves
|
||||
pipe_labels = sorted(list(pipe.labels))
|
||||
assert pipe_labels == labels
|
||||
|
|
Loading…
Reference in New Issue
Block a user