mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
account for NER labels with a hyphen in the name (#10960)
* account for NER labels with a hyphen in the name * cleanup * fix docstring * add return type to helper method * shorter method and few more occurrences * user helper method across repo * fix circular import * partial revert to avoid circular import
This commit is contained in:
parent
6313787fb6
commit
eaeca5eb6a
|
@ -10,7 +10,7 @@ import math
|
||||||
|
|
||||||
from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
|
from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
|
||||||
from ._util import import_code, debug_cli
|
from ._util import import_code, debug_cli
|
||||||
from ..training import Example
|
from ..training import Example, remove_bilu_prefix
|
||||||
from ..training.initialize import get_sourced_components
|
from ..training.initialize import get_sourced_components
|
||||||
from ..schemas import ConfigSchemaTraining
|
from ..schemas import ConfigSchemaTraining
|
||||||
from ..pipeline._parser_internals import nonproj
|
from ..pipeline._parser_internals import nonproj
|
||||||
|
@ -758,9 +758,9 @@ def _compile_gold(
|
||||||
# "Illegal" whitespace entity
|
# "Illegal" whitespace entity
|
||||||
data["ws_ents"] += 1
|
data["ws_ents"] += 1
|
||||||
if label.startswith(("B-", "U-")):
|
if label.startswith(("B-", "U-")):
|
||||||
combined_label = label.split("-")[1]
|
combined_label = remove_bilu_prefix(label)
|
||||||
data["ner"][combined_label] += 1
|
data["ner"][combined_label] += 1
|
||||||
if sent_starts[i] == True and label.startswith(("I-", "L-")):
|
if sent_starts[i] and label.startswith(("I-", "L-")):
|
||||||
data["boundary_cross_ents"] += 1
|
data["boundary_cross_ents"] += 1
|
||||||
elif label == "-":
|
elif label == "-":
|
||||||
data["ner"]["-"] += 1
|
data["ner"]["-"] += 1
|
||||||
|
@ -908,7 +908,7 @@ def _get_examples_without_label(
|
||||||
for eg in data:
|
for eg in data:
|
||||||
if component == "ner":
|
if component == "ner":
|
||||||
labels = [
|
labels = [
|
||||||
label.split("-")[1]
|
remove_bilu_prefix(label)
|
||||||
for label in eg.get_aligned_ner()
|
for label in eg.get_aligned_ner()
|
||||||
if label not in ("O", "-", None)
|
if label not in ("O", "-", None)
|
||||||
]
|
]
|
||||||
|
|
|
@ -10,6 +10,7 @@ from ...strings cimport hash_string
|
||||||
from ...structs cimport TokenC
|
from ...structs cimport TokenC
|
||||||
from ...tokens.doc cimport Doc, set_children_from_heads
|
from ...tokens.doc cimport Doc, set_children_from_heads
|
||||||
from ...tokens.token cimport MISSING_DEP
|
from ...tokens.token cimport MISSING_DEP
|
||||||
|
from ...training import split_bilu_label
|
||||||
from ...training.example cimport Example
|
from ...training.example cimport Example
|
||||||
from .stateclass cimport StateClass
|
from .stateclass cimport StateClass
|
||||||
from ._state cimport StateC, ArcC
|
from ._state cimport StateC, ArcC
|
||||||
|
@ -687,7 +688,7 @@ cdef class ArcEager(TransitionSystem):
|
||||||
return self.c[name_or_id]
|
return self.c[name_or_id]
|
||||||
name = name_or_id
|
name = name_or_id
|
||||||
if '-' in name:
|
if '-' in name:
|
||||||
move_str, label_str = name.split('-', 1)
|
move_str, label_str = split_bilu_label(name)
|
||||||
label = self.strings[label_str]
|
label = self.strings[label_str]
|
||||||
else:
|
else:
|
||||||
move_str = name
|
move_str = name
|
||||||
|
|
|
@ -13,6 +13,7 @@ from ...typedefs cimport weight_t, attr_t
|
||||||
from ...lexeme cimport Lexeme
|
from ...lexeme cimport Lexeme
|
||||||
from ...attrs cimport IS_SPACE
|
from ...attrs cimport IS_SPACE
|
||||||
from ...structs cimport TokenC, SpanC
|
from ...structs cimport TokenC, SpanC
|
||||||
|
from ...training import split_bilu_label
|
||||||
from ...training.example cimport Example
|
from ...training.example cimport Example
|
||||||
from .stateclass cimport StateClass
|
from .stateclass cimport StateClass
|
||||||
from ._state cimport StateC
|
from ._state cimport StateC
|
||||||
|
@ -182,7 +183,7 @@ cdef class BiluoPushDown(TransitionSystem):
|
||||||
if name == '-' or name == '' or name is None:
|
if name == '-' or name == '' or name is None:
|
||||||
return Transition(clas=0, move=MISSING, label=0, score=0)
|
return Transition(clas=0, move=MISSING, label=0, score=0)
|
||||||
elif '-' in name:
|
elif '-' in name:
|
||||||
move_str, label_str = name.split('-', 1)
|
move_str, label_str = split_bilu_label(name)
|
||||||
# Deprecated, hacky way to denote 'not this entity'
|
# Deprecated, hacky way to denote 'not this entity'
|
||||||
if label_str.startswith('!'):
|
if label_str.startswith('!'):
|
||||||
raise ValueError(Errors.E869.format(label=name))
|
raise ValueError(Errors.E869.format(label=name))
|
||||||
|
|
|
@ -12,6 +12,7 @@ from ..language import Language
|
||||||
from ._parser_internals import nonproj
|
from ._parser_internals import nonproj
|
||||||
from ._parser_internals.nonproj import DELIMITER
|
from ._parser_internals.nonproj import DELIMITER
|
||||||
from ..scorer import Scorer
|
from ..scorer import Scorer
|
||||||
|
from ..training import remove_bilu_prefix
|
||||||
from ..util import registry
|
from ..util import registry
|
||||||
|
|
||||||
|
|
||||||
|
@ -314,7 +315,7 @@ cdef class DependencyParser(Parser):
|
||||||
# Get the labels from the model by looking at the available moves
|
# Get the labels from the model by looking at the available moves
|
||||||
for move in self.move_names:
|
for move in self.move_names:
|
||||||
if "-" in move:
|
if "-" in move:
|
||||||
label = move.split("-")[1]
|
label = remove_bilu_prefix(move)
|
||||||
if DELIMITER in label:
|
if DELIMITER in label:
|
||||||
label = label.split(DELIMITER)[1]
|
label = label.split(DELIMITER)[1]
|
||||||
labels.add(label)
|
labels.add(label)
|
||||||
|
|
|
@ -6,10 +6,10 @@ from thinc.api import Model, Config
|
||||||
from ._parser_internals.transition_system import TransitionSystem
|
from ._parser_internals.transition_system import TransitionSystem
|
||||||
from .transition_parser cimport Parser
|
from .transition_parser cimport Parser
|
||||||
from ._parser_internals.ner cimport BiluoPushDown
|
from ._parser_internals.ner cimport BiluoPushDown
|
||||||
|
|
||||||
from ..language import Language
|
from ..language import Language
|
||||||
from ..scorer import get_ner_prf, PRFScore
|
from ..scorer import get_ner_prf, PRFScore
|
||||||
from ..util import registry
|
from ..util import registry
|
||||||
|
from ..training import remove_bilu_prefix
|
||||||
|
|
||||||
|
|
||||||
default_model_config = """
|
default_model_config = """
|
||||||
|
@ -242,7 +242,7 @@ cdef class EntityRecognizer(Parser):
|
||||||
def labels(self):
|
def labels(self):
|
||||||
# Get the labels from the model by looking at the available moves, e.g.
|
# Get the labels from the model by looking at the available moves, e.g.
|
||||||
# B-PERSON, I-PERSON, L-PERSON, U-PERSON
|
# B-PERSON, I-PERSON, L-PERSON, U-PERSON
|
||||||
labels = set(move.split("-")[1] for move in self.move_names
|
labels = set(remove_bilu_prefix(move) for move in self.move_names
|
||||||
if move[0] in ("B", "I", "L", "U"))
|
if move[0] in ("B", "I", "L", "U"))
|
||||||
return tuple(sorted(labels))
|
return tuple(sorted(labels))
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,7 @@ from spacy.lang.it import Italian
|
||||||
from spacy.language import Language
|
from spacy.language import Language
|
||||||
from spacy.lookups import Lookups
|
from spacy.lookups import Lookups
|
||||||
from spacy.pipeline._parser_internals.ner import BiluoPushDown
|
from spacy.pipeline._parser_internals.ner import BiluoPushDown
|
||||||
from spacy.training import Example, iob_to_biluo
|
from spacy.training import Example, iob_to_biluo, split_bilu_label
|
||||||
from spacy.tokens import Doc, Span
|
from spacy.tokens import Doc, Span
|
||||||
from spacy.vocab import Vocab
|
from spacy.vocab import Vocab
|
||||||
import logging
|
import logging
|
||||||
|
@ -110,6 +110,9 @@ def test_issue2385():
|
||||||
# maintain support for iob2 format
|
# maintain support for iob2 format
|
||||||
tags3 = ("B-PERSON", "I-PERSON", "B-PERSON")
|
tags3 = ("B-PERSON", "I-PERSON", "B-PERSON")
|
||||||
assert iob_to_biluo(tags3) == ["B-PERSON", "L-PERSON", "U-PERSON"]
|
assert iob_to_biluo(tags3) == ["B-PERSON", "L-PERSON", "U-PERSON"]
|
||||||
|
# ensure it works with hyphens in the name
|
||||||
|
tags4 = ("B-MULTI-PERSON", "I-MULTI-PERSON", "B-MULTI-PERSON")
|
||||||
|
assert iob_to_biluo(tags4) == ["B-MULTI-PERSON", "L-MULTI-PERSON", "U-MULTI-PERSON"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.issue(2800)
|
@pytest.mark.issue(2800)
|
||||||
|
@ -154,6 +157,19 @@ def test_issue3209():
|
||||||
assert ner2.move_names == move_names
|
assert ner2.move_names == move_names
|
||||||
|
|
||||||
|
|
||||||
|
def test_labels_from_BILUO():
|
||||||
|
"""Test that labels are inferred correctly when there's a - in label.
|
||||||
|
"""
|
||||||
|
nlp = English()
|
||||||
|
ner = nlp.add_pipe("ner")
|
||||||
|
ner.add_label("LARGE-ANIMAL")
|
||||||
|
nlp.initialize()
|
||||||
|
move_names = ["O", "B-LARGE-ANIMAL", "I-LARGE-ANIMAL", "L-LARGE-ANIMAL", "U-LARGE-ANIMAL"]
|
||||||
|
labels = {"LARGE-ANIMAL"}
|
||||||
|
assert ner.move_names == move_names
|
||||||
|
assert set(ner.labels) == labels
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.issue(4267)
|
@pytest.mark.issue(4267)
|
||||||
def test_issue4267():
|
def test_issue4267():
|
||||||
"""Test that running an entity_ruler after ner gives consistent results"""
|
"""Test that running an entity_ruler after ner gives consistent results"""
|
||||||
|
@ -298,7 +314,7 @@ def test_oracle_moves_missing_B(en_vocab):
|
||||||
elif tag == "O":
|
elif tag == "O":
|
||||||
moves.add_action(move_types.index("O"), "")
|
moves.add_action(move_types.index("O"), "")
|
||||||
else:
|
else:
|
||||||
action, label = tag.split("-")
|
action, label = split_bilu_label(tag)
|
||||||
moves.add_action(move_types.index("B"), label)
|
moves.add_action(move_types.index("B"), label)
|
||||||
moves.add_action(move_types.index("I"), label)
|
moves.add_action(move_types.index("I"), label)
|
||||||
moves.add_action(move_types.index("L"), label)
|
moves.add_action(move_types.index("L"), label)
|
||||||
|
@ -324,7 +340,7 @@ def test_oracle_moves_whitespace(en_vocab):
|
||||||
elif tag == "O":
|
elif tag == "O":
|
||||||
moves.add_action(move_types.index("O"), "")
|
moves.add_action(move_types.index("O"), "")
|
||||||
else:
|
else:
|
||||||
action, label = tag.split("-")
|
action, label = split_bilu_label(tag)
|
||||||
moves.add_action(move_types.index(action), label)
|
moves.add_action(move_types.index(action), label)
|
||||||
moves.get_oracle_sequence(example)
|
moves.get_oracle_sequence(example)
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ import srsly
|
||||||
from spacy.tokens import Doc
|
from spacy.tokens import Doc
|
||||||
from spacy.vocab import Vocab
|
from spacy.vocab import Vocab
|
||||||
from spacy.util import make_tempdir # noqa: F401
|
from spacy.util import make_tempdir # noqa: F401
|
||||||
|
from spacy.training import split_bilu_label
|
||||||
from thinc.api import get_current_ops
|
from thinc.api import get_current_ops
|
||||||
|
|
||||||
|
|
||||||
|
@ -40,7 +41,7 @@ def apply_transition_sequence(parser, doc, sequence):
|
||||||
desired state."""
|
desired state."""
|
||||||
for action_name in sequence:
|
for action_name in sequence:
|
||||||
if "-" in action_name:
|
if "-" in action_name:
|
||||||
move, label = action_name.split("-")
|
move, label = split_bilu_label(action_name)
|
||||||
parser.add_label(label)
|
parser.add_label(label)
|
||||||
with parser.step_through(doc) as stepwise:
|
with parser.step_through(doc) as stepwise:
|
||||||
for transition in sequence:
|
for transition in sequence:
|
||||||
|
|
|
@ -5,6 +5,7 @@ from .augment import dont_augment, orth_variants_augmenter # noqa: F401
|
||||||
from .iob_utils import iob_to_biluo, biluo_to_iob # noqa: F401
|
from .iob_utils import iob_to_biluo, biluo_to_iob # noqa: F401
|
||||||
from .iob_utils import offsets_to_biluo_tags, biluo_tags_to_offsets # noqa: F401
|
from .iob_utils import offsets_to_biluo_tags, biluo_tags_to_offsets # noqa: F401
|
||||||
from .iob_utils import biluo_tags_to_spans, tags_to_entities # noqa: F401
|
from .iob_utils import biluo_tags_to_spans, tags_to_entities # noqa: F401
|
||||||
|
from .iob_utils import split_bilu_label, remove_bilu_prefix # noqa: F401
|
||||||
from .gold_io import docs_to_json, read_json_file # noqa: F401
|
from .gold_io import docs_to_json, read_json_file # noqa: F401
|
||||||
from .batchers import minibatch_by_padded_size, minibatch_by_words # noqa: F401
|
from .batchers import minibatch_by_padded_size, minibatch_by_words # noqa: F401
|
||||||
from .loggers import console_logger # noqa: F401
|
from .loggers import console_logger # noqa: F401
|
||||||
|
|
|
@ -3,10 +3,10 @@ from typing import Optional
|
||||||
import random
|
import random
|
||||||
import itertools
|
import itertools
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pydantic import BaseModel, StrictStr
|
|
||||||
|
|
||||||
from ..util import registry
|
from ..util import registry
|
||||||
from .example import Example
|
from .example import Example
|
||||||
|
from .iob_utils import split_bilu_label
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..language import Language # noqa: F401
|
from ..language import Language # noqa: F401
|
||||||
|
@ -278,10 +278,8 @@ def make_whitespace_variant(
|
||||||
ent_prev = doc_dict["entities"][position - 1]
|
ent_prev = doc_dict["entities"][position - 1]
|
||||||
ent_next = doc_dict["entities"][position]
|
ent_next = doc_dict["entities"][position]
|
||||||
if "-" in ent_prev and "-" in ent_next:
|
if "-" in ent_prev and "-" in ent_next:
|
||||||
ent_iob_prev = ent_prev.split("-")[0]
|
ent_iob_prev, ent_type_prev = split_bilu_label(ent_prev)
|
||||||
ent_type_prev = ent_prev.split("-", 1)[1]
|
ent_iob_next, ent_type_next = split_bilu_label(ent_next)
|
||||||
ent_iob_next = ent_next.split("-")[0]
|
|
||||||
ent_type_next = ent_next.split("-", 1)[1]
|
|
||||||
if (
|
if (
|
||||||
ent_iob_prev in ("B", "I")
|
ent_iob_prev in ("B", "I")
|
||||||
and ent_iob_next in ("I", "L")
|
and ent_iob_next in ("I", "L")
|
||||||
|
|
|
@ -9,7 +9,7 @@ from ..tokens.span import Span
|
||||||
from ..attrs import IDS
|
from ..attrs import IDS
|
||||||
from .alignment import Alignment
|
from .alignment import Alignment
|
||||||
from .iob_utils import biluo_to_iob, offsets_to_biluo_tags, doc_to_biluo_tags
|
from .iob_utils import biluo_to_iob, offsets_to_biluo_tags, doc_to_biluo_tags
|
||||||
from .iob_utils import biluo_tags_to_spans
|
from .iob_utils import biluo_tags_to_spans, remove_bilu_prefix
|
||||||
from ..errors import Errors, Warnings
|
from ..errors import Errors, Warnings
|
||||||
from ..pipeline._parser_internals import nonproj
|
from ..pipeline._parser_internals import nonproj
|
||||||
from ..tokens.token cimport MISSING_DEP
|
from ..tokens.token cimport MISSING_DEP
|
||||||
|
@ -519,7 +519,7 @@ def _parse_ner_tags(biluo_or_offsets, vocab, words, spaces):
|
||||||
else:
|
else:
|
||||||
ent_iobs.append(iob_tag.split("-")[0])
|
ent_iobs.append(iob_tag.split("-")[0])
|
||||||
if iob_tag.startswith("I") or iob_tag.startswith("B"):
|
if iob_tag.startswith("I") or iob_tag.startswith("B"):
|
||||||
ent_types.append(iob_tag.split("-", 1)[1])
|
ent_types.append(remove_bilu_prefix(iob_tag))
|
||||||
else:
|
else:
|
||||||
ent_types.append("")
|
ent_types.append("")
|
||||||
return ent_iobs, ent_types
|
return ent_iobs, ent_types
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import List, Dict, Tuple, Iterable, Union, Iterator
|
from typing import List, Dict, Tuple, Iterable, Union, Iterator, cast
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from ..errors import Errors, Warnings
|
from ..errors import Errors, Warnings
|
||||||
|
@ -218,6 +218,14 @@ def tags_to_entities(tags: Iterable[str]) -> List[Tuple[str, int, int]]:
|
||||||
return entities
|
return entities
|
||||||
|
|
||||||
|
|
||||||
|
def split_bilu_label(label: str) -> Tuple[str, str]:
|
||||||
|
return cast(Tuple[str, str], label.split("-", 1))
|
||||||
|
|
||||||
|
|
||||||
|
def remove_bilu_prefix(label: str) -> str:
|
||||||
|
return label.split("-", 1)[1]
|
||||||
|
|
||||||
|
|
||||||
# Fallbacks to make backwards-compat easier
|
# Fallbacks to make backwards-compat easier
|
||||||
offsets_from_biluo_tags = biluo_tags_to_offsets
|
offsets_from_biluo_tags = biluo_tags_to_offsets
|
||||||
spans_from_biluo_tags = biluo_tags_to_spans
|
spans_from_biluo_tags = biluo_tags_to_spans
|
||||||
|
|
Loading…
Reference in New Issue
Block a user