mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
Fix bug in Parser.labels and add test (#4275)
This commit is contained in:
parent
af93997993
commit
8ebc3711dc
|
@ -1063,7 +1063,7 @@ cdef class DependencyParser(Parser):
|
||||||
@property
|
@property
|
||||||
def labels(self):
|
def labels(self):
|
||||||
# Get the labels from the model by looking at the available moves
|
# Get the labels from the model by looking at the available moves
|
||||||
return tuple(set(move.split("-")[1] for move in self.move_names))
|
return tuple(set(move.split("-")[1] for move in self.move_names if "-" in move))
|
||||||
|
|
||||||
|
|
||||||
cdef class EntityRecognizer(Parser):
|
cdef class EntityRecognizer(Parser):
|
||||||
|
|
|
@ -68,3 +68,20 @@ def test_add_label_deserializes_correctly():
|
||||||
assert ner1.moves.n_moves == ner2.moves.n_moves
|
assert ner1.moves.n_moves == ner2.moves.n_moves
|
||||||
for i in range(ner1.moves.n_moves):
|
for i in range(ner1.moves.n_moves):
|
||||||
assert ner1.moves.get_class_name(i) == ner2.moves.get_class_name(i)
|
assert ner1.moves.get_class_name(i) == ner2.moves.get_class_name(i)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"pipe_cls,n_moves", [(DependencyParser, 5), (EntityRecognizer, 4)]
|
||||||
|
)
|
||||||
|
def test_add_label_get_label(pipe_cls, n_moves):
|
||||||
|
"""Test that added labels are returned correctly. This test was added to
|
||||||
|
test for a bug in DependencyParser.labels that'd cause it to fail when
|
||||||
|
splitting the move names.
|
||||||
|
"""
|
||||||
|
labels = ["A", "B", "C"]
|
||||||
|
pipe = pipe_cls(Vocab())
|
||||||
|
for label in labels:
|
||||||
|
pipe.add_label(label)
|
||||||
|
assert len(pipe.move_names) == len(labels) * n_moves
|
||||||
|
pipe_labels = sorted(list(pipe.labels))
|
||||||
|
assert pipe_labels == labels
|
||||||
|
|
Loading…
Reference in New Issue
Block a user