mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 00:46:28 +03:00
💫 Rule-based NER component (#2513)
* Add helper function for reading in JSONL * Add rule-based NER component * Fix whitespace * Add component to factories * Add tests * Add option to disable indent on json_dumps compat Otherwise, reading JSONL back in line by line won't work * Fix error code
This commit is contained in:
parent
d84b13e02c
commit
e7b075565d
|
@ -54,7 +54,7 @@ if is_python2:
|
|||
unicode_ = unicode # noqa: F821
|
||||
basestring_ = basestring # noqa: F821
|
||||
input_ = raw_input # noqa: F821
|
||||
json_dumps = lambda data: ujson.dumps(data, indent=2, escape_forward_slashes=False).decode('utf8')
|
||||
json_dumps = lambda data, indent=2: ujson.dumps(data, indent=indent, escape_forward_slashes=False).decode('utf8')
|
||||
path2str = lambda path: str(path).decode('utf8')
|
||||
|
||||
elif is_python3:
|
||||
|
@ -62,7 +62,7 @@ elif is_python3:
|
|||
unicode_ = str
|
||||
basestring_ = str
|
||||
input_ = input
|
||||
json_dumps = lambda data: ujson.dumps(data, indent=2, escape_forward_slashes=False)
|
||||
json_dumps = lambda data, indent=2: ujson.dumps(data, indent=indent, escape_forward_slashes=False)
|
||||
path2str = lambda path: str(path)
|
||||
|
||||
|
||||
|
|
|
@ -259,6 +259,8 @@ class Errors(object):
|
|||
"error. Are you writing to a default function argument?")
|
||||
E096 = ("Invalid object passed to displaCy: Can only visualize Doc or "
|
||||
"Span objects, or dicts if set to manual=True.")
|
||||
E097 = ("Invalid pattern: expected token pattern (list of dicts) or "
|
||||
"phrase pattern (string) but got:\n{pattern}")
|
||||
|
||||
|
||||
@add_codes
|
||||
|
|
|
@ -18,6 +18,7 @@ from .lemmatizer import Lemmatizer
|
|||
from .pipeline import DependencyParser, Tensorizer, Tagger, EntityRecognizer
|
||||
from .pipeline import SimilarityHook, TextCategorizer, SentenceSegmenter
|
||||
from .pipeline import merge_noun_chunks, merge_entities, merge_subtokens
|
||||
from .pipeline import EntityRuler
|
||||
from .compat import json_dumps, izip, basestring_
|
||||
from .gold import GoldParse
|
||||
from .scorer import Scorer
|
||||
|
@ -111,6 +112,7 @@ class Language(object):
|
|||
'merge_noun_chunks': lambda nlp, **cfg: merge_noun_chunks,
|
||||
'merge_entities': lambda nlp, **cfg: merge_entities,
|
||||
'merge_subtokens': lambda nlp, **cfg: merge_subtokens,
|
||||
'entity_ruler': lambda nlp, **cfg: EntityRuler(nlp, **cfg)
|
||||
}
|
||||
|
||||
def __init__(self, vocab=True, make_doc=True, max_length=10**6, meta={}, **kwargs):
|
||||
|
|
|
@ -6,7 +6,7 @@ from __future__ import unicode_literals
|
|||
import numpy
|
||||
cimport numpy as np
|
||||
import cytoolz
|
||||
from collections import OrderedDict
|
||||
from collections import OrderedDict, defaultdict
|
||||
import ujson
|
||||
|
||||
from .util import msgpack
|
||||
|
@ -29,12 +29,15 @@ from .syntax import nonproj
|
|||
from .compat import json_dumps
|
||||
from .matcher import Matcher
|
||||
|
||||
from .matcher import Matcher, PhraseMatcher
|
||||
from .tokens.span import Span
|
||||
from .attrs import POS
|
||||
from .parts_of_speech import X
|
||||
from ._ml import Tok2Vec, build_text_classifier, build_tagger_model
|
||||
from ._ml import link_vectors_to_models, zero_init, flatten
|
||||
from ._ml import create_default_optimizer
|
||||
from .errors import Errors, TempErrors
|
||||
from .compat import json_dumps, basestring_
|
||||
from . import util
|
||||
|
||||
|
||||
|
@ -110,7 +113,165 @@ def merge_subtokens(doc, label='subtok'):
|
|||
for start_char, end_char in offsets:
|
||||
doc.merge(start_char, end_char)
|
||||
return doc
|
||||
|
||||
|
||||
|
||||
class EntityRuler(object):
|
||||
name = 'entity_ruler'
|
||||
|
||||
def __init__(self, nlp, **cfg):
|
||||
"""Initialise the entitiy ruler. If patterns are supplied here, they
|
||||
need to be a list of dictionaries with a `"label"` and `"pattern"`
|
||||
key. A pattern can either be a token pattern (list) or a phrase pattern
|
||||
(string). For example: `{'label': 'ORG', 'pattern': 'Apple'}`.
|
||||
|
||||
nlp (Language): The shared nlp object to pass the vocab to the matchers
|
||||
and process phrase patterns.
|
||||
patterns (iterable): Optional patterns to load in.
|
||||
overwrite_ents (bool): If existing entities are present, e.g. entities
|
||||
added by the model, overwrite them by matches if necessary.
|
||||
**cfg: Other config parameters. If pipeline component is loaded as part
|
||||
of a model pipeline, this will include all keyword arguments passed
|
||||
to `spacy.load`.
|
||||
RETURNS (EntityRuler): The newly constructed object.
|
||||
"""
|
||||
self.nlp = nlp
|
||||
self.overwrite = cfg.get('overwrite_ents', False)
|
||||
self.token_patterns = defaultdict(list)
|
||||
self.phrase_patterns = defaultdict(list)
|
||||
self.matcher = Matcher(nlp.vocab)
|
||||
self.phrase_matcher = PhraseMatcher(nlp.vocab)
|
||||
patterns = cfg.get('patterns')
|
||||
if patterns is not None:
|
||||
self.add_patterns(patterns)
|
||||
|
||||
def __len__(self):
|
||||
"""The number of all patterns added to the entity ruler."""
|
||||
n_token_patterns = sum(len(p) for p in self.token_patterns.values())
|
||||
n_phrase_patterns = sum(len(p) for p in self.phrase_patterns.values())
|
||||
return n_token_patterns + n_phrase_patterns
|
||||
|
||||
def __contains__(self, label):
|
||||
"""Whether a label is present in the patterns."""
|
||||
return label in self.token_patterns or label in self.phrase_patterns
|
||||
|
||||
def __call__(self, doc):
|
||||
"""Find matches in document and add them as entities.
|
||||
|
||||
doc (Doc): The Doc object in the pipeline.
|
||||
RETURNS (Doc): The Doc with added entities, if available.
|
||||
"""
|
||||
matches = list(self.matcher(doc)) + list(self.phrase_matcher(doc))
|
||||
matches = set([(m_id, start, end) for m_id, start, end in matches
|
||||
if start != end])
|
||||
get_sort_key = lambda m: (m[2] - m[1], m[1])
|
||||
matches = sorted(matches, key=get_sort_key, reverse=True)
|
||||
entities = list(doc.ents)
|
||||
new_entities = []
|
||||
seen_tokens = set()
|
||||
for match_id, start, end in matches:
|
||||
if any(t.ent_type for t in doc[start:end]) and not self.overwrite:
|
||||
continue
|
||||
# check for end - 1 here because boundaries are inclusive
|
||||
if start not in seen_tokens and end - 1 not in seen_tokens:
|
||||
new_entities.append(Span(doc, start, end, label=match_id))
|
||||
entities = [e for e in entities
|
||||
if not (e.start < end and e.end > start)]
|
||||
seen_tokens.update(range(start, end))
|
||||
doc.ents = entities + new_entities
|
||||
return doc
|
||||
|
||||
@property
|
||||
def labels(self):
|
||||
"""All labels present in the match patterns.
|
||||
|
||||
RETURNS (set): The string labels.
|
||||
"""
|
||||
all_labels = set(self.token_patterns.keys())
|
||||
all_labels.update(self.phrase_patterns.keys())
|
||||
return all_labels
|
||||
|
||||
@property
|
||||
def patterns(self):
|
||||
"""Get all patterns that were added to the entity ruler.
|
||||
|
||||
RETURNS (list): The original patterns, one dictionary per pattern.
|
||||
"""
|
||||
all_patterns = []
|
||||
for label, patterns in self.token_patterns.items():
|
||||
for pattern in patterns:
|
||||
all_patterns.append({'label': label, 'pattern': pattern})
|
||||
for label, patterns in self.phrase_patterns.items():
|
||||
for pattern in patterns:
|
||||
all_patterns.append({'label': label, 'pattern': pattern.text})
|
||||
return all_patterns
|
||||
|
||||
def add_patterns(self, patterns):
|
||||
"""Add patterns to the entitiy ruler. A pattern can either be a token
|
||||
pattern (list of dicts) or a phrase pattern (string). For example:
|
||||
{'label': 'ORG', 'pattern': 'Apple'}
|
||||
{'label': 'GPE', 'pattern': [{'lower': 'san'}, {'lower': 'francisco'}]}
|
||||
|
||||
patterns (list): The patterns to add.
|
||||
"""
|
||||
for entry in patterns:
|
||||
label = entry['label']
|
||||
pattern = entry['pattern']
|
||||
if isinstance(pattern, basestring_):
|
||||
self.phrase_patterns[label].append(self.nlp(pattern))
|
||||
elif isinstance(pattern, list):
|
||||
self.token_patterns[label].append(pattern)
|
||||
else:
|
||||
raise ValueError(Errors.E097.format(pattern=pattern))
|
||||
for label, patterns in self.token_patterns.items():
|
||||
self.matcher.add(label, None, *patterns)
|
||||
for label, patterns in self.phrase_patterns.items():
|
||||
self.phrase_matcher.add(label, None, *patterns)
|
||||
|
||||
def from_bytes(self, patterns_bytes, **kwargs):
|
||||
"""Load the entity ruler from a bytestring.
|
||||
|
||||
patterns_bytes (bytes): The bytestring to load.
|
||||
**kwargs: Other config paramters, mostly for consistency.
|
||||
RETURNS (EntityRuler): The loaded entity ruler.
|
||||
"""
|
||||
patterns = msgpack.loads(patterns_bytes)
|
||||
self.add_patterns(patterns)
|
||||
return self
|
||||
|
||||
def to_bytes(self, **kwargs):
|
||||
"""Serialize the entity ruler patterns to a bytestring.
|
||||
|
||||
RETURNS (bytes): The serialized patterns.
|
||||
"""
|
||||
return msgpack.dumps(self.patterns)
|
||||
|
||||
def from_disk(self, path, **kwargs):
|
||||
"""Load the entity ruler from a file. Expects a file containing
|
||||
newline-delimited JSON (JSONL) with one entry per line.
|
||||
|
||||
path (unicode / Path): The JSONL file to load.
|
||||
**kwargs: Other config paramters, mostly for consistency.
|
||||
RETURNS (EntityRuler): The loaded entity ruler.
|
||||
"""
|
||||
path = util.ensure_path(path)
|
||||
path = path.with_suffix('.jsonl')
|
||||
patterns = util.read_jsonl(path)
|
||||
self.add_patterns(patterns)
|
||||
return self
|
||||
|
||||
def to_disk(self, path, **kwargs):
|
||||
"""Save the entity ruler patterns to a directory. The patterns will be
|
||||
saved as newline-delimited JSON (JSONL).
|
||||
|
||||
path (unicode / Path): The JSONL file to load.
|
||||
**kwargs: Other config paramters, mostly for consistency.
|
||||
RETURNS (EntityRuler): The loaded entity ruler.
|
||||
"""
|
||||
path = util.ensure_path(path)
|
||||
path = path.with_suffix('.jsonl')
|
||||
data = [json_dumps(line, indent=0) for line in self.patterns]
|
||||
path.open('w').write('\n'.join(data))
|
||||
|
||||
|
||||
class Pipe(object):
|
||||
"""This class is not instantiated directly. Components inherit from it, and
|
||||
|
@ -389,7 +550,7 @@ class Tensorizer(Pipe):
|
|||
vectors = self.model.ops.xp.vstack([w.vector for w in doc])
|
||||
target.append(vectors)
|
||||
target = self.model.ops.xp.vstack(target)
|
||||
d_scores = (prediction - target)
|
||||
d_scores = (prediction - target)
|
||||
loss = (d_scores**2).sum()
|
||||
return loss, d_scores
|
||||
|
||||
|
|
89
spacy/tests/pipeline/test_entity_ruler.py
Normal file
89
spacy/tests/pipeline/test_entity_ruler.py
Normal file
|
@ -0,0 +1,89 @@
|
|||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import pytest
|
||||
|
||||
from ...tokens import Span
|
||||
from ...language import Language
|
||||
from ...pipeline import EntityRuler
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def nlp():
|
||||
return Language()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patterns():
|
||||
return [
|
||||
{"label": "HELLO", "pattern": "hello world"},
|
||||
{"label": "BYE", "pattern": [{"LOWER": "bye"}, {"LOWER": "bye"}]},
|
||||
{"label": "HELLO", "pattern": [{"ORTH": "HELLO"}]},
|
||||
{"label": "COMPLEX", "pattern": [{"ORTH": "foo", "OP": "*"}]}
|
||||
]
|
||||
|
||||
@pytest.fixture
|
||||
def add_ent():
|
||||
def add_ent_component(doc):
|
||||
doc.ents = [Span(doc, 0, 3, label=doc.vocab.strings['ORG'])]
|
||||
return doc
|
||||
return add_ent_component
|
||||
|
||||
|
||||
def test_entity_ruler_init(nlp, patterns):
|
||||
ruler = EntityRuler(nlp, patterns=patterns)
|
||||
assert len(ruler) == len(patterns)
|
||||
assert len(ruler.labels) == 3
|
||||
assert 'HELLO' in ruler
|
||||
assert 'BYE' in ruler
|
||||
nlp.add_pipe(ruler)
|
||||
doc = nlp("hello world bye bye")
|
||||
assert len(doc.ents) == 2
|
||||
assert doc.ents[0].label_ == 'HELLO'
|
||||
assert doc.ents[1].label_ == 'BYE'
|
||||
|
||||
|
||||
def test_entity_ruler_existing(nlp, patterns, add_ent):
|
||||
ruler = EntityRuler(nlp, patterns=patterns)
|
||||
nlp.add_pipe(add_ent)
|
||||
nlp.add_pipe(ruler)
|
||||
doc = nlp("OH HELLO WORLD bye bye")
|
||||
assert len(doc.ents) == 2
|
||||
assert doc.ents[0].label_ == 'ORG'
|
||||
assert doc.ents[1].label_ == 'BYE'
|
||||
|
||||
|
||||
def test_entity_ruler_existing_overwrite(nlp, patterns, add_ent):
|
||||
ruler = EntityRuler(nlp, patterns=patterns, overwrite_ents=True)
|
||||
nlp.add_pipe(add_ent)
|
||||
nlp.add_pipe(ruler)
|
||||
doc = nlp("OH HELLO WORLD bye bye")
|
||||
assert len(doc.ents) == 2
|
||||
assert doc.ents[0].label_ == 'HELLO'
|
||||
assert doc.ents[0].text == 'HELLO'
|
||||
assert doc.ents[1].label_ == 'BYE'
|
||||
|
||||
|
||||
def test_entity_ruler_existing_complex(nlp, patterns, add_ent):
|
||||
ruler = EntityRuler(nlp, patterns=patterns, overwrite_ents=True)
|
||||
nlp.add_pipe(add_ent)
|
||||
nlp.add_pipe(ruler)
|
||||
doc = nlp("foo foo bye bye")
|
||||
assert len(doc.ents) == 2
|
||||
assert doc.ents[0].label_ == 'COMPLEX'
|
||||
assert doc.ents[1].label_ == 'BYE'
|
||||
assert len(doc.ents[0]) == 2
|
||||
assert len(doc.ents[1]) == 2
|
||||
|
||||
|
||||
def test_entity_ruler_serialize_bytes(nlp, patterns):
|
||||
ruler = EntityRuler(nlp, patterns=patterns)
|
||||
assert len(ruler) == len(patterns)
|
||||
assert len(ruler.labels) == 3
|
||||
ruler_bytes = ruler.to_bytes()
|
||||
new_ruler = EntityRuler(nlp)
|
||||
assert len(new_ruler) == 0
|
||||
assert len(new_ruler.labels) == 0
|
||||
new_ruler = new_ruler.from_bytes(ruler_bytes)
|
||||
assert len(ruler) == len(patterns)
|
||||
assert len(ruler.labels) == 3
|
|
@ -507,6 +507,20 @@ def read_json(location):
|
|||
return ujson.load(f)
|
||||
|
||||
|
||||
def read_jsonl(file_path):
|
||||
"""Read a .jsonl file and yield its contents line by line.
|
||||
|
||||
file_path (unicode / Path): The file path.
|
||||
YIELDS: The loaded JSON contents of each line.
|
||||
"""
|
||||
with Path(file_path).open('r', encoding='utf8') as f:
|
||||
for line in f:
|
||||
try: # hack to handle broken jsonl
|
||||
yield ujson.loads(line.strip())
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
|
||||
def get_raw_input(description, default=False):
|
||||
"""Get user input from the command line via raw_input / input.
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user