account for NER labels with a hyphen in the name (#10960)

* account for NER labels with a hyphen in the name

* cleanup

* fix docstring

* add return type to helper method

* shorter method and few more occurrences

* use helper method across repo

* fix circular import

* partial revert to avoid circular import
This commit is contained in:
Sofie Van Landeghem 2022-06-17 20:02:37 +01:00 committed by GitHub
parent 6313787fb6
commit eaeca5eb6a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 48 additions and 21 deletions

View File

@ -10,7 +10,7 @@ import math
from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
from ._util import import_code, debug_cli
from ..training import Example
from ..training import Example, remove_bilu_prefix
from ..training.initialize import get_sourced_components
from ..schemas import ConfigSchemaTraining
from ..pipeline._parser_internals import nonproj
@ -758,9 +758,9 @@ def _compile_gold(
# "Illegal" whitespace entity
data["ws_ents"] += 1
if label.startswith(("B-", "U-")):
combined_label = label.split("-")[1]
combined_label = remove_bilu_prefix(label)
data["ner"][combined_label] += 1
if sent_starts[i] == True and label.startswith(("I-", "L-")):
if sent_starts[i] and label.startswith(("I-", "L-")):
data["boundary_cross_ents"] += 1
elif label == "-":
data["ner"]["-"] += 1
@ -908,7 +908,7 @@ def _get_examples_without_label(
for eg in data:
if component == "ner":
labels = [
label.split("-")[1]
remove_bilu_prefix(label)
for label in eg.get_aligned_ner()
if label not in ("O", "-", None)
]

View File

@ -10,6 +10,7 @@ from ...strings cimport hash_string
from ...structs cimport TokenC
from ...tokens.doc cimport Doc, set_children_from_heads
from ...tokens.token cimport MISSING_DEP
from ...training import split_bilu_label
from ...training.example cimport Example
from .stateclass cimport StateClass
from ._state cimport StateC, ArcC
@ -687,7 +688,7 @@ cdef class ArcEager(TransitionSystem):
return self.c[name_or_id]
name = name_or_id
if '-' in name:
move_str, label_str = name.split('-', 1)
move_str, label_str = split_bilu_label(name)
label = self.strings[label_str]
else:
move_str = name

View File

@ -13,6 +13,7 @@ from ...typedefs cimport weight_t, attr_t
from ...lexeme cimport Lexeme
from ...attrs cimport IS_SPACE
from ...structs cimport TokenC, SpanC
from ...training import split_bilu_label
from ...training.example cimport Example
from .stateclass cimport StateClass
from ._state cimport StateC
@ -182,7 +183,7 @@ cdef class BiluoPushDown(TransitionSystem):
if name == '-' or name == '' or name is None:
return Transition(clas=0, move=MISSING, label=0, score=0)
elif '-' in name:
move_str, label_str = name.split('-', 1)
move_str, label_str = split_bilu_label(name)
# Deprecated, hacky way to denote 'not this entity'
if label_str.startswith('!'):
raise ValueError(Errors.E869.format(label=name))

View File

@ -12,6 +12,7 @@ from ..language import Language
from ._parser_internals import nonproj
from ._parser_internals.nonproj import DELIMITER
from ..scorer import Scorer
from ..training import remove_bilu_prefix
from ..util import registry
@ -314,7 +315,7 @@ cdef class DependencyParser(Parser):
# Get the labels from the model by looking at the available moves
for move in self.move_names:
if "-" in move:
label = move.split("-")[1]
label = remove_bilu_prefix(move)
if DELIMITER in label:
label = label.split(DELIMITER)[1]
labels.add(label)

View File

@ -6,10 +6,10 @@ from thinc.api import Model, Config
from ._parser_internals.transition_system import TransitionSystem
from .transition_parser cimport Parser
from ._parser_internals.ner cimport BiluoPushDown
from ..language import Language
from ..scorer import get_ner_prf, PRFScore
from ..util import registry
from ..training import remove_bilu_prefix
default_model_config = """
@ -242,7 +242,7 @@ cdef class EntityRecognizer(Parser):
def labels(self):
# Get the labels from the model by looking at the available moves, e.g.
# B-PERSON, I-PERSON, L-PERSON, U-PERSON
labels = set(move.split("-")[1] for move in self.move_names
labels = set(remove_bilu_prefix(move) for move in self.move_names
if move[0] in ("B", "I", "L", "U"))
return tuple(sorted(labels))

View File

@ -10,7 +10,7 @@ from spacy.lang.it import Italian
from spacy.language import Language
from spacy.lookups import Lookups
from spacy.pipeline._parser_internals.ner import BiluoPushDown
from spacy.training import Example, iob_to_biluo
from spacy.training import Example, iob_to_biluo, split_bilu_label
from spacy.tokens import Doc, Span
from spacy.vocab import Vocab
import logging
@ -110,6 +110,9 @@ def test_issue2385():
# maintain support for iob2 format
tags3 = ("B-PERSON", "I-PERSON", "B-PERSON")
assert iob_to_biluo(tags3) == ["B-PERSON", "L-PERSON", "U-PERSON"]
# ensure it works with hyphens in the name
tags4 = ("B-MULTI-PERSON", "I-MULTI-PERSON", "B-MULTI-PERSON")
assert iob_to_biluo(tags4) == ["B-MULTI-PERSON", "L-MULTI-PERSON", "U-MULTI-PERSON"]
@pytest.mark.issue(2800)
@ -154,6 +157,19 @@ def test_issue3209():
assert ner2.move_names == move_names
def test_labels_from_BILUO():
    """Check that the full label is recovered from the BILUO move names
    when the label itself contains a hyphen.
    """
    pipeline = English()
    recognizer = pipeline.add_pipe("ner")
    recognizer.add_label("LARGE-ANIMAL")
    pipeline.initialize()
    # One move per BILU prefix for the single label, plus the "O" move.
    expected_moves = [
        "O",
        "B-LARGE-ANIMAL",
        "I-LARGE-ANIMAL",
        "L-LARGE-ANIMAL",
        "U-LARGE-ANIMAL",
    ]
    assert recognizer.move_names == expected_moves
    # Only the prefix may be stripped; "LARGE-ANIMAL" must stay in one piece.
    assert set(recognizer.labels) == {"LARGE-ANIMAL"}
@pytest.mark.issue(4267)
def test_issue4267():
"""Test that running an entity_ruler after ner gives consistent results"""
@ -298,7 +314,7 @@ def test_oracle_moves_missing_B(en_vocab):
elif tag == "O":
moves.add_action(move_types.index("O"), "")
else:
action, label = tag.split("-")
action, label = split_bilu_label(tag)
moves.add_action(move_types.index("B"), label)
moves.add_action(move_types.index("I"), label)
moves.add_action(move_types.index("L"), label)
@ -324,7 +340,7 @@ def test_oracle_moves_whitespace(en_vocab):
elif tag == "O":
moves.add_action(move_types.index("O"), "")
else:
action, label = tag.split("-")
action, label = split_bilu_label(tag)
moves.add_action(move_types.index(action), label)
moves.get_oracle_sequence(example)

View File

@ -5,6 +5,7 @@ import srsly
from spacy.tokens import Doc
from spacy.vocab import Vocab
from spacy.util import make_tempdir # noqa: F401
from spacy.training import split_bilu_label
from thinc.api import get_current_ops
@ -40,7 +41,7 @@ def apply_transition_sequence(parser, doc, sequence):
desired state."""
for action_name in sequence:
if "-" in action_name:
move, label = action_name.split("-")
move, label = split_bilu_label(action_name)
parser.add_label(label)
with parser.step_through(doc) as stepwise:
for transition in sequence:

View File

@ -5,6 +5,7 @@ from .augment import dont_augment, orth_variants_augmenter # noqa: F401
from .iob_utils import iob_to_biluo, biluo_to_iob # noqa: F401
from .iob_utils import offsets_to_biluo_tags, biluo_tags_to_offsets # noqa: F401
from .iob_utils import biluo_tags_to_spans, tags_to_entities # noqa: F401
from .iob_utils import split_bilu_label, remove_bilu_prefix # noqa: F401
from .gold_io import docs_to_json, read_json_file # noqa: F401
from .batchers import minibatch_by_padded_size, minibatch_by_words # noqa: F401
from .loggers import console_logger # noqa: F401

View File

@ -3,10 +3,10 @@ from typing import Optional
import random
import itertools
from functools import partial
from pydantic import BaseModel, StrictStr
from ..util import registry
from .example import Example
from .iob_utils import split_bilu_label
if TYPE_CHECKING:
from ..language import Language # noqa: F401
@ -278,10 +278,8 @@ def make_whitespace_variant(
ent_prev = doc_dict["entities"][position - 1]
ent_next = doc_dict["entities"][position]
if "-" in ent_prev and "-" in ent_next:
ent_iob_prev = ent_prev.split("-")[0]
ent_type_prev = ent_prev.split("-", 1)[1]
ent_iob_next = ent_next.split("-")[0]
ent_type_next = ent_next.split("-", 1)[1]
ent_iob_prev, ent_type_prev = split_bilu_label(ent_prev)
ent_iob_next, ent_type_next = split_bilu_label(ent_next)
if (
ent_iob_prev in ("B", "I")
and ent_iob_next in ("I", "L")

View File

@ -9,7 +9,7 @@ from ..tokens.span import Span
from ..attrs import IDS
from .alignment import Alignment
from .iob_utils import biluo_to_iob, offsets_to_biluo_tags, doc_to_biluo_tags
from .iob_utils import biluo_tags_to_spans
from .iob_utils import biluo_tags_to_spans, remove_bilu_prefix
from ..errors import Errors, Warnings
from ..pipeline._parser_internals import nonproj
from ..tokens.token cimport MISSING_DEP
@ -519,7 +519,7 @@ def _parse_ner_tags(biluo_or_offsets, vocab, words, spaces):
else:
ent_iobs.append(iob_tag.split("-")[0])
if iob_tag.startswith("I") or iob_tag.startswith("B"):
ent_types.append(iob_tag.split("-", 1)[1])
ent_types.append(remove_bilu_prefix(iob_tag))
else:
ent_types.append("")
return ent_iobs, ent_types

View File

@ -1,4 +1,4 @@
from typing import List, Dict, Tuple, Iterable, Union, Iterator
from typing import List, Dict, Tuple, Iterable, Union, Iterator, cast
import warnings
from ..errors import Errors, Warnings
@ -218,6 +218,14 @@ def tags_to_entities(tags: Iterable[str]) -> List[Tuple[str, int, int]]:
return entities
def split_bilu_label(label: str) -> Tuple[str, str]:
    """Split a BILU-tagged label into its prefix and its label name.

    Only the first hyphen acts as the separator, so a label name that itself
    contains hyphens is kept whole: "B-MULTI-PERSON" gives
    ("B", "MULTI-PERSON").

    The label is expected to contain at least one hyphen. Without one, the
    result holds a single element and two-way unpacking by the caller raises
    ValueError.
    """
    # str.split returns a list; convert it so the runtime value really is the
    # tuple promised by the annotation instead of only being cast to one.
    return cast(Tuple[str, str], tuple(label.split("-", 1)))
def remove_bilu_prefix(label: str) -> str:
    """Return the label name that follows the first hyphen of a tagged label.

    Everything after the first hyphen is returned unchanged, so hyphens inside
    the label name are preserved: "B-MULTI-PERSON" gives "MULTI-PERSON".
    A label without any hyphen raises IndexError.
    """
    # Split once only: later hyphens belong to the label name itself.
    pieces = label.split("-", 1)
    return pieces[1]
# Fallbacks to make backwards-compat easier
offsets_from_biluo_tags = biluo_tags_to_offsets
spans_from_biluo_tags = biluo_tags_to_spans