mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-27 10:26:35 +03:00
f146121092
* Make handling of [Pipe].labels consistent * Un-xfail passing test * Update spacy/pipeline/pipes.pyx Co-Authored-By: ines <ines@ines.io> * Update spacy/pipeline/pipes.pyx Co-Authored-By: ines <ines@ines.io> * Update spacy/tests/pipeline/test_pipe_methods.py Co-Authored-By: ines <ines@ines.io> * Update spacy/pipeline/pipes.pyx Co-Authored-By: ines <ines@ines.io> * Move error message to spacy.errors * Fix textcat labels and test * Make EntityRuler.labels return tuple as well
171 lines
6.7 KiB
Python
171 lines
6.7 KiB
Python
# coding: utf8
|
|
from __future__ import unicode_literals
|
|
|
|
from collections import defaultdict
|
|
import srsly
|
|
|
|
from ..errors import Errors
|
|
from ..compat import basestring_
|
|
from ..util import ensure_path
|
|
from ..tokens import Span
|
|
from ..matcher import Matcher, PhraseMatcher
|
|
|
|
|
|
class EntityRuler(object):
    """Pipeline component that sets `doc.ents` from token and phrase patterns.

    Patterns are stored per label: token patterns (lists) are registered with
    a `Matcher`, phrase patterns (strings) are run through `nlp` and
    registered with a `PhraseMatcher`.
    """

    name = "entity_ruler"

    def __init__(self, nlp, **cfg):
        """Initialise the entity ruler. If patterns are supplied here, they
        need to be a list of dictionaries with a `"label"` and `"pattern"`
        key. A pattern can either be a token pattern (list) or a phrase pattern
        (string). For example: `{'label': 'ORG', 'pattern': 'Apple'}`.

        nlp (Language): The shared nlp object to pass the vocab to the matchers
            and process phrase patterns.
        patterns (iterable): Optional patterns to load in.
        overwrite_ents (bool): If existing entities are present, e.g. entities
            added by the model, overwrite them by matches if necessary.
        **cfg: Other config parameters. If pipeline component is loaded as part
            of a model pipeline, this will include all keyword arguments passed
            to `spacy.load`.
        RETURNS (EntityRuler): The newly constructed object.
        """
        self.nlp = nlp
        self.overwrite = cfg.get("overwrite_ents", False)
        # label -> list of original patterns, kept so that `patterns` can
        # return them for serialization.
        self.token_patterns = defaultdict(list)
        self.phrase_patterns = defaultdict(list)
        self.matcher = Matcher(nlp.vocab)
        self.phrase_matcher = PhraseMatcher(nlp.vocab)
        patterns = cfg.get("patterns")
        if patterns is not None:
            self.add_patterns(patterns)

    def __len__(self):
        """The number of all patterns added to the entity ruler."""
        n_token_patterns = sum(len(p) for p in self.token_patterns.values())
        n_phrase_patterns = sum(len(p) for p in self.phrase_patterns.values())
        return n_token_patterns + n_phrase_patterns

    def __contains__(self, label):
        """Whether a label is present in the patterns."""
        return label in self.token_patterns or label in self.phrase_patterns

    def __call__(self, doc):
        """Find matches in document and add them as entities.

        doc (Doc): The Doc object in the pipeline.
        RETURNS (Doc): The Doc with added entities, if available.
        """
        matches = list(self.matcher(doc)) + list(self.phrase_matcher(doc))
        # Deduplicate the combined matches and drop zero-length ones.
        matches = {
            (m_id, start, end) for m_id, start, end in matches if start != end
        }
        # Longest matches first; for equal lengths, the later start comes
        # first because the whole key is sorted in reverse.
        matches = sorted(matches, key=lambda m: (m[2] - m[1], m[1]), reverse=True)
        entities = list(doc.ents)
        new_entities = []
        seen_tokens = set()
        for match_id, start, end in matches:
            # Test the cheap flag first so the token scan is skipped entirely
            # when overwriting is enabled.
            if not self.overwrite and any(t.ent_type for t in doc[start:end]):
                continue
            # `end` is exclusive, so the last token of the match is `end - 1`.
            # Checking both ends is enough: matches are processed longest
            # first, so a later match can never strictly contain an accepted
            # one.
            if start not in seen_tokens and end - 1 not in seen_tokens:
                new_entities.append(Span(doc, start, end, label=match_id))
                # Drop existing entities that overlap the new span.
                entities = [
                    e for e in entities if not (e.start < end and e.end > start)
                ]
                seen_tokens.update(range(start, end))
        doc.ents = entities + new_entities
        return doc

    @property
    def labels(self):
        """All labels present in the match patterns.

        RETURNS (tuple): The string labels, in no particular order.
        """
        all_labels = set(self.token_patterns.keys())
        all_labels.update(self.phrase_patterns.keys())
        return tuple(all_labels)

    @property
    def patterns(self):
        """Get all patterns that were added to the entity ruler.

        RETURNS (list): The original patterns, one dictionary per pattern.
            Token patterns come first, followed by phrase patterns, which are
            returned as their text.
        """
        all_patterns = []
        for label, token_patterns in self.token_patterns.items():
            for pattern in token_patterns:
                all_patterns.append({"label": label, "pattern": pattern})
        for label, phrase_docs in self.phrase_patterns.items():
            for phrase_doc in phrase_docs:
                all_patterns.append({"label": label, "pattern": phrase_doc.text})
        return all_patterns

    def add_patterns(self, patterns):
        """Add patterns to the entity ruler. A pattern can either be a token
        pattern (list of dicts) or a phrase pattern (string). For example:
        {'label': 'ORG', 'pattern': 'Apple'}
        {'label': 'GPE', 'pattern': [{'lower': 'san'}, {'lower': 'francisco'}]}

        patterns (list): The patterns to add.
        RAISES (ValueError): If a pattern is neither a string nor a list.
        """
        for entry in patterns:
            label = entry["label"]
            pattern = entry["pattern"]
            if isinstance(pattern, basestring_):
                # Phrase patterns are stored as processed Doc objects.
                self.phrase_patterns[label].append(self.nlp(pattern))
            elif isinstance(pattern, list):
                self.token_patterns[label].append(pattern)
            else:
                raise ValueError(Errors.E097.format(pattern=pattern))
        # NOTE(review): this registers every stored pattern, including ones
        # already registered by earlier calls. Whether the matchers
        # deduplicate repeated additions is not visible here -- confirm.
        for label, token_patterns in self.token_patterns.items():
            self.matcher.add(label, None, *token_patterns)
        for label, phrase_docs in self.phrase_patterns.items():
            self.phrase_matcher.add(label, None, *phrase_docs)

    def from_bytes(self, patterns_bytes, **kwargs):
        """Load the entity ruler from a bytestring. The loaded patterns are
        added to any patterns already present.

        patterns_bytes (bytes): The bytestring to load.
        **kwargs: Other config parameters, mostly for consistency.
        RETURNS (EntityRuler): The loaded entity ruler.
        """
        patterns = srsly.msgpack_loads(patterns_bytes)
        self.add_patterns(patterns)
        return self

    def to_bytes(self, **kwargs):
        """Serialize the entity ruler patterns to a bytestring.

        **kwargs: Other config parameters, mostly for consistency.
        RETURNS (bytes): The serialized patterns.
        """
        return srsly.msgpack_dumps(self.patterns)

    def from_disk(self, path, **kwargs):
        """Load the entity ruler from a file. Expects a file containing
        newline-delimited JSON (JSONL) with one entry per line. The loaded
        patterns are added to any patterns already present.

        path (unicode / Path): The JSONL file to load. The suffix is always
            replaced with `.jsonl`.
        **kwargs: Other config parameters, mostly for consistency.
        RETURNS (EntityRuler): The loaded entity ruler.
        """
        path = ensure_path(path)
        path = path.with_suffix(".jsonl")
        patterns = srsly.read_jsonl(path)
        self.add_patterns(patterns)
        return self

    def to_disk(self, path, **kwargs):
        """Save the entity ruler patterns to a file. The patterns will be
        saved as newline-delimited JSON (JSONL).

        path (unicode / Path): The JSONL file to save to. The suffix is always
            replaced with `.jsonl`.
        **kwargs: Other config parameters, mostly for consistency.
        """
        path = ensure_path(path)
        path = path.with_suffix(".jsonl")
        srsly.write_jsonl(path, self.patterns)
|