diff --git a/spacy/compat.py b/spacy/compat.py index d5eee8431..f54797940 100644 --- a/spacy/compat.py +++ b/spacy/compat.py @@ -54,7 +54,7 @@ if is_python2: unicode_ = unicode # noqa: F821 basestring_ = basestring # noqa: F821 input_ = raw_input # noqa: F821 - json_dumps = lambda data: ujson.dumps(data, indent=2, escape_forward_slashes=False).decode('utf8') + json_dumps = lambda data, indent=2: ujson.dumps(data, indent=indent, escape_forward_slashes=False).decode('utf8') path2str = lambda path: str(path).decode('utf8') elif is_python3: @@ -62,7 +62,7 @@ elif is_python3: unicode_ = str basestring_ = str input_ = input - json_dumps = lambda data: ujson.dumps(data, indent=2, escape_forward_slashes=False) + json_dumps = lambda data, indent=2: ujson.dumps(data, indent=indent, escape_forward_slashes=False) path2str = lambda path: str(path) diff --git a/spacy/errors.py b/spacy/errors.py index 3d3207fbc..b17597110 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -259,6 +259,8 @@ class Errors(object): "error. Are you writing to a default function argument?") E096 = ("Invalid object passed to displaCy: Can only visualize Doc or " "Span objects, or dicts if set to manual=True.") + E097 = ("Invalid pattern: expected token pattern (list of dicts) or " + "phrase pattern (string) but got:\n{pattern}") @add_codes diff --git a/spacy/language.py b/spacy/language.py index 6b0ee6361..a993f7eb3 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -18,6 +18,7 @@ from .lemmatizer import Lemmatizer from .pipeline import DependencyParser, Tensorizer, Tagger, EntityRecognizer from .pipeline import SimilarityHook, TextCategorizer, SentenceSegmenter from .pipeline import merge_noun_chunks, merge_entities, merge_subtokens +from .pipeline import EntityRuler from .compat import json_dumps, izip, basestring_ from .gold import GoldParse from .scorer import Scorer @@ -111,6 +112,7 @@ class Language(object): 'merge_noun_chunks': lambda nlp, **cfg: merge_noun_chunks, 'merge_entities': lambda nlp, **cfg: merge_entities, 'merge_subtokens': lambda nlp, **cfg: merge_subtokens, + 'entity_ruler': lambda nlp, **cfg: EntityRuler(nlp, **cfg) } def __init__(self, vocab=True, make_doc=True, max_length=10**6, meta={}, **kwargs): diff --git a/spacy/pipeline.pyx b/spacy/pipeline.pyx index d70ae3054..e76fac3e4 100644 --- a/spacy/pipeline.pyx +++ b/spacy/pipeline.pyx @@ -6,7 +6,7 @@ from __future__ import unicode_literals import numpy cimport numpy as np import cytoolz -from collections import OrderedDict +from collections import OrderedDict, defaultdict import ujson from .util import msgpack @@ -29,12 +29,15 @@ from .syntax import nonproj from .compat import json_dumps from .matcher import Matcher +from .matcher import Matcher, PhraseMatcher +from .tokens.span import Span from .attrs import POS from .parts_of_speech import X from ._ml import Tok2Vec, build_text_classifier, build_tagger_model from ._ml import link_vectors_to_models, zero_init, flatten from ._ml import create_default_optimizer from .errors import Errors, TempErrors +from .compat import json_dumps, basestring_ from . import util @@ -110,7 +113,165 @@ def merge_subtokens(doc, label='subtok'): for start_char, end_char in offsets: doc.merge(start_char, end_char) return doc - + + +class EntityRuler(object): + name = 'entity_ruler' + + def __init__(self, nlp, **cfg): + """Initialise the entitiy ruler. If patterns are supplied here, they + need to be a list of dictionaries with a `"label"` and `"pattern"` + key. A pattern can either be a token pattern (list) or a phrase pattern + (string). For example: `{'label': 'ORG', 'pattern': 'Apple'}`. + + nlp (Language): The shared nlp object to pass the vocab to the matchers + and process phrase patterns. + patterns (iterable): Optional patterns to load in. + overwrite_ents (bool): If existing entities are present, e.g. entities + added by the model, overwrite them by matches if necessary. + **cfg: Other config parameters. If pipeline component is loaded as part + of a model pipeline, this will include all keyword arguments passed + to `spacy.load`. + RETURNS (EntityRuler): The newly constructed object. + """ + self.nlp = nlp + self.overwrite = cfg.get('overwrite_ents', False) + self.token_patterns = defaultdict(list) + self.phrase_patterns = defaultdict(list) + self.matcher = Matcher(nlp.vocab) + self.phrase_matcher = PhraseMatcher(nlp.vocab) + patterns = cfg.get('patterns') + if patterns is not None: + self.add_patterns(patterns) + + def __len__(self): + """The number of all patterns added to the entity ruler.""" + n_token_patterns = sum(len(p) for p in self.token_patterns.values()) + n_phrase_patterns = sum(len(p) for p in self.phrase_patterns.values()) + return n_token_patterns + n_phrase_patterns + + def __contains__(self, label): + """Whether a label is present in the patterns.""" + return label in self.token_patterns or label in self.phrase_patterns + + def __call__(self, doc): + """Find matches in document and add them as entities. + + doc (Doc): The Doc object in the pipeline. + RETURNS (Doc): The Doc with added entities, if available. + """ + matches = list(self.matcher(doc)) + list(self.phrase_matcher(doc)) + matches = set([(m_id, start, end) for m_id, start, end in matches + if start != end]) + get_sort_key = lambda m: (m[2] - m[1], m[1]) + matches = sorted(matches, key=get_sort_key, reverse=True) + entities = list(doc.ents) + new_entities = [] + seen_tokens = set() + for match_id, start, end in matches: + if any(t.ent_type for t in doc[start:end]) and not self.overwrite: + continue + # check for end - 1 here because boundaries are inclusive + if start not in seen_tokens and end - 1 not in seen_tokens: + new_entities.append(Span(doc, start, end, label=match_id)) + entities = [e for e in entities + if not (e.start < end and e.end > start)] + seen_tokens.update(range(start, end)) + doc.ents = entities + new_entities + return doc + + @property + def labels(self): + """All labels present in the match patterns. + + RETURNS (set): The string labels. + """ + all_labels = set(self.token_patterns.keys()) + all_labels.update(self.phrase_patterns.keys()) + return all_labels + + @property + def patterns(self): + """Get all patterns that were added to the entity ruler. + + RETURNS (list): The original patterns, one dictionary per pattern. + """ + all_patterns = [] + for label, patterns in self.token_patterns.items(): + for pattern in patterns: + all_patterns.append({'label': label, 'pattern': pattern}) + for label, patterns in self.phrase_patterns.items(): + for pattern in patterns: + all_patterns.append({'label': label, 'pattern': pattern.text}) + return all_patterns + + def add_patterns(self, patterns): + """Add patterns to the entitiy ruler. A pattern can either be a token + pattern (list of dicts) or a phrase pattern (string). For example: + {'label': 'ORG', 'pattern': 'Apple'} + {'label': 'GPE', 'pattern': [{'lower': 'san'}, {'lower': 'francisco'}]} + + patterns (list): The patterns to add. + """ + for entry in patterns: + label = entry['label'] + pattern = entry['pattern'] + if isinstance(pattern, basestring_): + self.phrase_patterns[label].append(self.nlp(pattern)) + elif isinstance(pattern, list): + self.token_patterns[label].append(pattern) + else: + raise ValueError(Errors.E097.format(pattern=pattern)) + for label, patterns in self.token_patterns.items(): + self.matcher.add(label, None, *patterns) + for label, patterns in self.phrase_patterns.items(): + self.phrase_matcher.add(label, None, *patterns) + + def from_bytes(self, patterns_bytes, **kwargs): + """Load the entity ruler from a bytestring. + + patterns_bytes (bytes): The bytestring to load. + **kwargs: Other config paramters, mostly for consistency. + RETURNS (EntityRuler): The loaded entity ruler. + """ + patterns = msgpack.loads(patterns_bytes) + self.add_patterns(patterns) + return self + + def to_bytes(self, **kwargs): + """Serialize the entity ruler patterns to a bytestring. + + RETURNS (bytes): The serialized patterns. + """ + return msgpack.dumps(self.patterns) + + def from_disk(self, path, **kwargs): + """Load the entity ruler from a file. Expects a file containing + newline-delimited JSON (JSONL) with one entry per line. + + path (unicode / Path): The JSONL file to load. + **kwargs: Other config paramters, mostly for consistency. + RETURNS (EntityRuler): The loaded entity ruler. + """ + path = util.ensure_path(path) + path = path.with_suffix('.jsonl') + patterns = util.read_jsonl(path) + self.add_patterns(patterns) + return self + + def to_disk(self, path, **kwargs): + """Save the entity ruler patterns to a directory. The patterns will be + saved as newline-delimited JSON (JSONL). + + path (unicode / Path): The JSONL file to load. + **kwargs: Other config paramters, mostly for consistency. + RETURNS (EntityRuler): The loaded entity ruler. + """ + path = util.ensure_path(path) + path = path.with_suffix('.jsonl') + data = [json_dumps(line, indent=0) for line in self.patterns] + path.open('w').write('\n'.join(data)) + class Pipe(object): """This class is not instantiated directly. Components inherit from it, and @@ -389,7 +550,7 @@ class Tensorizer(Pipe): vectors = self.model.ops.xp.vstack([w.vector for w in doc]) target.append(vectors) target = self.model.ops.xp.vstack(target) - d_scores = (prediction - target) + d_scores = (prediction - target) loss = (d_scores**2).sum() return loss, d_scores diff --git a/spacy/tests/pipeline/test_entity_ruler.py b/spacy/tests/pipeline/test_entity_ruler.py new file mode 100644 index 000000000..49f6cab61 --- /dev/null +++ b/spacy/tests/pipeline/test_entity_ruler.py @@ -0,0 +1,89 @@ +# coding: utf8 +from __future__ import unicode_literals + +import pytest + +from ...tokens import Span +from ...language import Language +from ...pipeline import EntityRuler + + +@pytest.fixture +def nlp(): + return Language() + + +@pytest.fixture +def patterns(): + return [ + {"label": "HELLO", "pattern": "hello world"}, + {"label": "BYE", "pattern": [{"LOWER": "bye"}, {"LOWER": "bye"}]}, + {"label": "HELLO", "pattern": [{"ORTH": "HELLO"}]}, + {"label": "COMPLEX", "pattern": [{"ORTH": "foo", "OP": "*"}]} + ] + +@pytest.fixture +def add_ent(): + def add_ent_component(doc): + doc.ents = [Span(doc, 0, 3, label=doc.vocab.strings['ORG'])] + return doc + return add_ent_component + + +def test_entity_ruler_init(nlp, patterns): + ruler = EntityRuler(nlp, patterns=patterns) + assert len(ruler) == len(patterns) + assert len(ruler.labels) == 3 + assert 'HELLO' in ruler + assert 'BYE' in ruler + nlp.add_pipe(ruler) + doc = nlp("hello world bye bye") + assert len(doc.ents) == 2 + assert doc.ents[0].label_ == 'HELLO' + assert doc.ents[1].label_ == 'BYE' + + +def test_entity_ruler_existing(nlp, patterns, add_ent): + ruler = EntityRuler(nlp, patterns=patterns) + nlp.add_pipe(add_ent) + nlp.add_pipe(ruler) + doc = nlp("OH HELLO WORLD bye bye") + assert len(doc.ents) == 2 + assert doc.ents[0].label_ == 'ORG' + assert doc.ents[1].label_ == 'BYE' + + +def test_entity_ruler_existing_overwrite(nlp, patterns, add_ent): + ruler = EntityRuler(nlp, patterns=patterns, overwrite_ents=True) + nlp.add_pipe(add_ent) + nlp.add_pipe(ruler) + doc = nlp("OH HELLO WORLD bye bye") + assert len(doc.ents) == 2 + assert doc.ents[0].label_ == 'HELLO' + assert doc.ents[0].text == 'HELLO' + assert doc.ents[1].label_ == 'BYE' + + +def test_entity_ruler_existing_complex(nlp, patterns, add_ent): + ruler = EntityRuler(nlp, patterns=patterns, overwrite_ents=True) + nlp.add_pipe(add_ent) + nlp.add_pipe(ruler) + doc = nlp("foo foo bye bye") + assert len(doc.ents) == 2 + assert doc.ents[0].label_ == 'COMPLEX' + assert doc.ents[1].label_ == 'BYE' + assert len(doc.ents[0]) == 2 + assert len(doc.ents[1]) == 2 + + +def test_entity_ruler_serialize_bytes(nlp, patterns): + ruler = EntityRuler(nlp, patterns=patterns) + assert len(ruler) == len(patterns) + assert len(ruler.labels) == 3 + ruler_bytes = ruler.to_bytes() + new_ruler = EntityRuler(nlp) + assert len(new_ruler) == 0 + assert len(new_ruler.labels) == 0 + new_ruler = new_ruler.from_bytes(ruler_bytes) + assert len(ruler) == len(patterns) + assert len(ruler.labels) == 3 diff --git a/spacy/util.py b/spacy/util.py index fbf35950c..af1ca64f6 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -507,6 +507,20 @@ def read_json(location): return ujson.load(f) +def read_jsonl(file_path): + """Read a .jsonl file and yield its contents line by line. + + file_path (unicode / Path): The file path. + YIELDS: The loaded JSON contents of each line. + """ + with Path(file_path).open('r', encoding='utf8') as f: + for line in f: + try: # hack to handle broken jsonl + yield ujson.loads(line.strip()) + except ValueError: + continue + + def get_raw_input(description, default=False): """Get user input from the command line via raw_input / input.