Mirror of https://github.com/explosion/spaCy.git
💫 Rule-based NER component (#2513)
* Add helper function for reading in JSONL
* Add rule-based NER component
* Add component to factories
* Add tests
* Add option to disable indent on json_dumps compat (otherwise, reading JSONL back in line by line won't work)
* Fix whitespace
* Fix error code
Parent: d84b13e02c
Commit: e7b075565d
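For orientation before the diff: the new component matches entities from token patterns (lists of dicts) and phrase patterns (plain strings) and adds them to doc.ents. A minimal usage sketch, based on the patterns exercised in the new tests and the spaCy 2.x API introduced here:

    from spacy.lang.en import English
    from spacy.pipeline import EntityRuler

    nlp = English()  # a blank pipeline is enough; the ruler only needs the shared vocab
    patterns = [
        {"label": "ORG", "pattern": "Apple"},  # phrase pattern (string)
        {"label": "GPE", "pattern": [{"LOWER": "san"}, {"LOWER": "francisco"}]},  # token pattern
    ]
    ruler = EntityRuler(nlp, patterns=patterns)
    nlp.add_pipe(ruler)

    doc = nlp("Apple is opening an office in San Francisco")
    print([(ent.text, ent.label_) for ent in doc.ents])
    # expected something like: [('Apple', 'ORG'), ('San Francisco', 'GPE')]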
spacy/compat.py
@@ -54,7 +54,7 @@ if is_python2:
     unicode_ = unicode  # noqa: F821
     basestring_ = basestring  # noqa: F821
     input_ = raw_input  # noqa: F821
-    json_dumps = lambda data: ujson.dumps(data, indent=2, escape_forward_slashes=False).decode('utf8')
+    json_dumps = lambda data, indent=2: ujson.dumps(data, indent=indent, escape_forward_slashes=False).decode('utf8')
     path2str = lambda path: str(path).decode('utf8')

 elif is_python3:
@@ -62,7 +62,7 @@ elif is_python3:
     unicode_ = str
     basestring_ = str
     input_ = input
-    json_dumps = lambda data: ujson.dumps(data, indent=2, escape_forward_slashes=False)
+    json_dumps = lambda data, indent=2: ujson.dumps(data, indent=indent, escape_forward_slashes=False)
     path2str = lambda path: str(path)
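The new indent argument exists because ujson's pretty-printed output spans several lines, which breaks the one-object-per-line JSONL format used by the ruler's to_disk. A small sketch of the difference (plain ujson, nothing spaCy-specific assumed):

    import ujson

    entry = {"label": "ORG", "pattern": "Apple"}
    pretty = ujson.dumps(entry, indent=2, escape_forward_slashes=False)
    flat = ujson.dumps(entry, indent=0, escape_forward_slashes=False)

    print("\n" in pretty)  # True: multi-line, unusable as a single JSONL record
    print("\n" in flat)    # False: one compact line, safe to join with "\n"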
spacy/errors.py
@@ -259,6 +259,8 @@ class Errors(object):
             "error. Are you writing to a default function argument?")
     E096 = ("Invalid object passed to displaCy: Can only visualize Doc or "
             "Span objects, or dicts if set to manual=True.")
+    E097 = ("Invalid pattern: expected token pattern (list of dicts) or "
+            "phrase pattern (string) but got:\n{pattern}")


 @add_codes
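The new E097 code is what add_patterns (added below) raises when a pattern value is neither a string nor a list. A hypothetical trigger, assuming the EntityRuler API from this commit:

    from spacy.lang.en import English
    from spacy.pipeline import EntityRuler

    ruler = EntityRuler(English())
    try:
        # a dict is not a valid pattern value, so this should raise ValueError with E097
        ruler.add_patterns([{"label": "ORG", "pattern": {"ORTH": "Apple"}}])
    except ValueError as err:
        print(err)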
spacy/language.py
@@ -18,6 +18,7 @@ from .lemmatizer import Lemmatizer
 from .pipeline import DependencyParser, Tensorizer, Tagger, EntityRecognizer
 from .pipeline import SimilarityHook, TextCategorizer, SentenceSegmenter
 from .pipeline import merge_noun_chunks, merge_entities, merge_subtokens
+from .pipeline import EntityRuler
 from .compat import json_dumps, izip, basestring_
 from .gold import GoldParse
 from .scorer import Scorer
@@ -111,6 +112,7 @@ class Language(object):
         'merge_noun_chunks': lambda nlp, **cfg: merge_noun_chunks,
         'merge_entities': lambda nlp, **cfg: merge_entities,
         'merge_subtokens': lambda nlp, **cfg: merge_subtokens,
+        'entity_ruler': lambda nlp, **cfg: EntityRuler(nlp, **cfg)
     }

     def __init__(self, vocab=True, make_doc=True, max_length=10**6, meta={}, **kwargs):
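With the factory registered, the component can also be created by name through the existing Language.create_pipe machinery rather than importing EntityRuler directly. A sketch, assuming the spaCy 2.x pipeline API:

    from spacy.lang.en import English

    nlp = English()
    ruler = nlp.create_pipe('entity_ruler')  # looked up in Language.factories
    ruler.add_patterns([{"label": "ORG", "pattern": "Apple"}])
    nlp.add_pipe(ruler, first=True)
    print(nlp.pipe_names)  # ['entity_ruler']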
spacy/pipeline.pyx
@@ -6,7 +6,7 @@ from __future__ import unicode_literals
 import numpy
 cimport numpy as np
 import cytoolz
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
 import ujson

 from .util import msgpack
@@ -29,12 +29,15 @@ from .syntax import nonproj
 from .compat import json_dumps
 from .matcher import Matcher

+from .matcher import Matcher, PhraseMatcher
+from .tokens.span import Span
 from .attrs import POS
 from .parts_of_speech import X
 from ._ml import Tok2Vec, build_text_classifier, build_tagger_model
 from ._ml import link_vectors_to_models, zero_init, flatten
 from ._ml import create_default_optimizer
 from .errors import Errors, TempErrors
+from .compat import json_dumps, basestring_
 from . import util
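The two matchers imported here are what the new component delegates to: token patterns go to Matcher, phrase patterns (pre-processed with nlp()) go to PhraseMatcher. A minimal sketch of that split, assuming the spaCy 2.x matcher API:

    from spacy.lang.en import English
    from spacy.matcher import Matcher, PhraseMatcher

    nlp = English()
    matcher = Matcher(nlp.vocab)
    phrase_matcher = PhraseMatcher(nlp.vocab)

    matcher.add('GPE', None, [{'LOWER': 'san'}, {'LOWER': 'francisco'}])  # token pattern
    phrase_matcher.add('ORG', None, nlp('Apple'))                         # phrase pattern

    doc = nlp("Apple has an office in San Francisco")
    matches = list(matcher(doc)) + list(phrase_matcher(doc))
    print([(nlp.vocab.strings[m_id], start, end) for m_id, start, end in matches])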
@@ -112,6 +115,164 @@ def merge_subtokens(doc, label='subtok'):
     return doc


+class EntityRuler(object):
+    name = 'entity_ruler'
+
+    def __init__(self, nlp, **cfg):
+        """Initialise the entity ruler. If patterns are supplied here, they
+        need to be a list of dictionaries with a `"label"` and `"pattern"`
+        key. A pattern can either be a token pattern (list) or a phrase pattern
+        (string). For example: `{'label': 'ORG', 'pattern': 'Apple'}`.
+
+        nlp (Language): The shared nlp object to pass the vocab to the matchers
+            and process phrase patterns.
+        patterns (iterable): Optional patterns to load in.
+        overwrite_ents (bool): If existing entities are present, e.g. entities
+            added by the model, overwrite them by matches if necessary.
+        **cfg: Other config parameters. If pipeline component is loaded as part
+            of a model pipeline, this will include all keyword arguments passed
+            to `spacy.load`.
+        RETURNS (EntityRuler): The newly constructed object.
+        """
+        self.nlp = nlp
+        self.overwrite = cfg.get('overwrite_ents', False)
+        self.token_patterns = defaultdict(list)
+        self.phrase_patterns = defaultdict(list)
+        self.matcher = Matcher(nlp.vocab)
+        self.phrase_matcher = PhraseMatcher(nlp.vocab)
+        patterns = cfg.get('patterns')
+        if patterns is not None:
+            self.add_patterns(patterns)
+
+    def __len__(self):
+        """The number of all patterns added to the entity ruler."""
+        n_token_patterns = sum(len(p) for p in self.token_patterns.values())
+        n_phrase_patterns = sum(len(p) for p in self.phrase_patterns.values())
+        return n_token_patterns + n_phrase_patterns
+
+    def __contains__(self, label):
+        """Whether a label is present in the patterns."""
+        return label in self.token_patterns or label in self.phrase_patterns
+
+    def __call__(self, doc):
+        """Find matches in document and add them as entities.
+
+        doc (Doc): The Doc object in the pipeline.
+        RETURNS (Doc): The Doc with added entities, if available.
+        """
+        matches = list(self.matcher(doc)) + list(self.phrase_matcher(doc))
+        matches = set([(m_id, start, end) for m_id, start, end in matches
+                       if start != end])
+        get_sort_key = lambda m: (m[2] - m[1], m[1])
+        matches = sorted(matches, key=get_sort_key, reverse=True)
+        entities = list(doc.ents)
+        new_entities = []
+        seen_tokens = set()
+        for match_id, start, end in matches:
+            if any(t.ent_type for t in doc[start:end]) and not self.overwrite:
+                continue
+            # check for end - 1 here because boundaries are inclusive
+            if start not in seen_tokens and end - 1 not in seen_tokens:
+                new_entities.append(Span(doc, start, end, label=match_id))
+                entities = [e for e in entities
+                            if not (e.start < end and e.end > start)]
+                seen_tokens.update(range(start, end))
+        doc.ents = entities + new_entities
+        return doc
+
+    @property
+    def labels(self):
+        """All labels present in the match patterns.
+
+        RETURNS (set): The string labels.
+        """
+        all_labels = set(self.token_patterns.keys())
+        all_labels.update(self.phrase_patterns.keys())
+        return all_labels
+
+    @property
+    def patterns(self):
+        """Get all patterns that were added to the entity ruler.
+
+        RETURNS (list): The original patterns, one dictionary per pattern.
+        """
+        all_patterns = []
+        for label, patterns in self.token_patterns.items():
+            for pattern in patterns:
+                all_patterns.append({'label': label, 'pattern': pattern})
+        for label, patterns in self.phrase_patterns.items():
+            for pattern in patterns:
+                all_patterns.append({'label': label, 'pattern': pattern.text})
+        return all_patterns
+
+    def add_patterns(self, patterns):
+        """Add patterns to the entity ruler. A pattern can either be a token
+        pattern (list of dicts) or a phrase pattern (string). For example:
+        {'label': 'ORG', 'pattern': 'Apple'}
+        {'label': 'GPE', 'pattern': [{'lower': 'san'}, {'lower': 'francisco'}]}
+
+        patterns (list): The patterns to add.
+        """
+        for entry in patterns:
+            label = entry['label']
+            pattern = entry['pattern']
+            if isinstance(pattern, basestring_):
+                self.phrase_patterns[label].append(self.nlp(pattern))
+            elif isinstance(pattern, list):
+                self.token_patterns[label].append(pattern)
+            else:
+                raise ValueError(Errors.E097.format(pattern=pattern))
+        for label, patterns in self.token_patterns.items():
+            self.matcher.add(label, None, *patterns)
+        for label, patterns in self.phrase_patterns.items():
+            self.phrase_matcher.add(label, None, *patterns)
+
+    def from_bytes(self, patterns_bytes, **kwargs):
+        """Load the entity ruler from a bytestring.
+
+        patterns_bytes (bytes): The bytestring to load.
+        **kwargs: Other config parameters, mostly for consistency.
+        RETURNS (EntityRuler): The loaded entity ruler.
+        """
+        patterns = msgpack.loads(patterns_bytes)
+        self.add_patterns(patterns)
+        return self
+
+    def to_bytes(self, **kwargs):
+        """Serialize the entity ruler patterns to a bytestring.
+
+        RETURNS (bytes): The serialized patterns.
+        """
+        return msgpack.dumps(self.patterns)
+
+    def from_disk(self, path, **kwargs):
+        """Load the entity ruler from a file. Expects a file containing
+        newline-delimited JSON (JSONL) with one entry per line.
+
+        path (unicode / Path): The JSONL file to load.
+        **kwargs: Other config parameters, mostly for consistency.
+        RETURNS (EntityRuler): The loaded entity ruler.
+        """
+        path = util.ensure_path(path)
+        path = path.with_suffix('.jsonl')
+        patterns = util.read_jsonl(path)
+        self.add_patterns(patterns)
+        return self
+
+    def to_disk(self, path, **kwargs):
+        """Save the entity ruler patterns to a file. The patterns will be
+        saved as newline-delimited JSON (JSONL).
+
+        path (unicode / Path): The JSONL file to save to.
+        **kwargs: Other config parameters, mostly for consistency.
+        """
+        path = util.ensure_path(path)
+        path = path.with_suffix('.jsonl')
+        data = [json_dumps(line, indent=0) for line in self.patterns]
+        path.open('w').write('\n'.join(data))
+
+
 class Pipe(object):
     """This class is not instantiated directly. Components inherit from it, and
     it defines the interface that components should follow to function as
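Two details of the class above are easy to miss: matches are sorted longest-first and filtered through seen_tokens, so overlapping matches resolve to the longest span, and to_disk/from_disk round-trip the patterns through a JSONL file (the suffix is normalised to .jsonl). A small sketch, assuming a writable working directory:

    from spacy.lang.en import English
    from spacy.pipeline import EntityRuler

    nlp = English()
    ruler = EntityRuler(nlp, patterns=[
        {"label": "HELLO", "pattern": "hello world"},        # two-token phrase pattern
        {"label": "HELLO", "pattern": [{"ORTH": "hello"}]},  # one-token pattern
    ])
    nlp.add_pipe(ruler)

    doc = nlp("hello world")
    # the longer "hello world" match should win over the single-token "hello" match
    print([(ent.text, ent.label_) for ent in doc.ents])

    ruler.to_disk('patterns')                       # written to patterns.jsonl
    restored = EntityRuler(nlp).from_disk('patterns')
    print(len(restored) == len(ruler))              # patterns survive the round trip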
spacy/tests/pipeline/test_entity_ruler.py (new file, 89 lines)
@@ -0,0 +1,89 @@
+# coding: utf8
+from __future__ import unicode_literals
+
+import pytest
+
+from ...tokens import Span
+from ...language import Language
+from ...pipeline import EntityRuler
+
+
+@pytest.fixture
+def nlp():
+    return Language()
+
+
+@pytest.fixture
+def patterns():
+    return [
+        {"label": "HELLO", "pattern": "hello world"},
+        {"label": "BYE", "pattern": [{"LOWER": "bye"}, {"LOWER": "bye"}]},
+        {"label": "HELLO", "pattern": [{"ORTH": "HELLO"}]},
+        {"label": "COMPLEX", "pattern": [{"ORTH": "foo", "OP": "*"}]}
+    ]
+
+
+@pytest.fixture
+def add_ent():
+    def add_ent_component(doc):
+        doc.ents = [Span(doc, 0, 3, label=doc.vocab.strings['ORG'])]
+        return doc
+    return add_ent_component
+
+
+def test_entity_ruler_init(nlp, patterns):
+    ruler = EntityRuler(nlp, patterns=patterns)
+    assert len(ruler) == len(patterns)
+    assert len(ruler.labels) == 3
+    assert 'HELLO' in ruler
+    assert 'BYE' in ruler
+    nlp.add_pipe(ruler)
+    doc = nlp("hello world bye bye")
+    assert len(doc.ents) == 2
+    assert doc.ents[0].label_ == 'HELLO'
+    assert doc.ents[1].label_ == 'BYE'
+
+
+def test_entity_ruler_existing(nlp, patterns, add_ent):
+    ruler = EntityRuler(nlp, patterns=patterns)
+    nlp.add_pipe(add_ent)
+    nlp.add_pipe(ruler)
+    doc = nlp("OH HELLO WORLD bye bye")
+    assert len(doc.ents) == 2
+    assert doc.ents[0].label_ == 'ORG'
+    assert doc.ents[1].label_ == 'BYE'
+
+
+def test_entity_ruler_existing_overwrite(nlp, patterns, add_ent):
+    ruler = EntityRuler(nlp, patterns=patterns, overwrite_ents=True)
+    nlp.add_pipe(add_ent)
+    nlp.add_pipe(ruler)
+    doc = nlp("OH HELLO WORLD bye bye")
+    assert len(doc.ents) == 2
+    assert doc.ents[0].label_ == 'HELLO'
+    assert doc.ents[0].text == 'HELLO'
+    assert doc.ents[1].label_ == 'BYE'
+
+
+def test_entity_ruler_existing_complex(nlp, patterns, add_ent):
+    ruler = EntityRuler(nlp, patterns=patterns, overwrite_ents=True)
+    nlp.add_pipe(add_ent)
+    nlp.add_pipe(ruler)
+    doc = nlp("foo foo bye bye")
+    assert len(doc.ents) == 2
+    assert doc.ents[0].label_ == 'COMPLEX'
+    assert doc.ents[1].label_ == 'BYE'
+    assert len(doc.ents[0]) == 2
+    assert len(doc.ents[1]) == 2
+
+
+def test_entity_ruler_serialize_bytes(nlp, patterns):
+    ruler = EntityRuler(nlp, patterns=patterns)
+    assert len(ruler) == len(patterns)
+    assert len(ruler.labels) == 3
+    ruler_bytes = ruler.to_bytes()
+    new_ruler = EntityRuler(nlp)
+    assert len(new_ruler) == 0
+    assert len(new_ruler.labels) == 0
+    new_ruler = new_ruler.from_bytes(ruler_bytes)
+    assert len(ruler) == len(patterns)
+    assert len(ruler.labels) == 3
spacy/util.py
@@ -507,6 +507,20 @@ def read_json(location):
         return ujson.load(f)


+def read_jsonl(file_path):
+    """Read a .jsonl file and yield its contents line by line.
+
+    file_path (unicode / Path): The file path.
+    YIELDS: The loaded JSON contents of each line.
+    """
+    with Path(file_path).open('r', encoding='utf8') as f:
+        for line in f:
+            try:  # hack to handle broken jsonl
+                yield ujson.loads(line.strip())
+            except ValueError:
+                continue
+
+
 def get_raw_input(description, default=False):
     """Get user input from the command line via raw_input / input.
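read_jsonl is the counterpart to the json_dumps(indent=0) change above: each pattern is one compact JSON object per line, and reading yields parsed dicts lazily while skipping broken lines. A small round-trip sketch, assuming a scratch patterns.jsonl file:

    from pathlib import Path

    from spacy.compat import json_dumps
    from spacy.util import read_jsonl

    patterns = [
        {"label": "ORG", "pattern": "Apple"},
        {"label": "GPE", "pattern": [{"LOWER": "san"}, {"LOWER": "francisco"}]},
    ]

    # write one compact JSON object per line
    Path("patterns.jsonl").write_text("\n".join(json_dumps(p, indent=0) for p in patterns))

    # read the entries back as dicts
    restored = list(read_jsonl("patterns.jsonl"))
    print(restored == patterns)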