mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 09:57:26 +03:00 
			
		
		
		
	💫 Rule-based NER component (#2513)
* Add helper function for reading in JSONL * Add rule-based NER component * Fix whitespace * Add component to factories * Add tests * Add option to disable indent on json_dumps compat Otherwise, reading JSONL back in line by line won't work * Fix error code
This commit is contained in:
		
							parent
							
								
									d84b13e02c
								
							
						
					
					
						commit
						e7b075565d
					
				| 
						 | 
				
			
			@ -54,7 +54,7 @@ if is_python2:
 | 
			
		|||
    unicode_ = unicode  # noqa: F821
 | 
			
		||||
    basestring_ = basestring  # noqa: F821
 | 
			
		||||
    input_ = raw_input  # noqa: F821
 | 
			
		||||
    json_dumps = lambda data: ujson.dumps(data, indent=2, escape_forward_slashes=False).decode('utf8')
 | 
			
		||||
    json_dumps = lambda data, indent=2: ujson.dumps(data, indent=indent, escape_forward_slashes=False).decode('utf8')
 | 
			
		||||
    path2str = lambda path: str(path).decode('utf8')
 | 
			
		||||
 | 
			
		||||
elif is_python3:
 | 
			
		||||
| 
						 | 
				
			
			@ -62,7 +62,7 @@ elif is_python3:
 | 
			
		|||
    unicode_ = str
 | 
			
		||||
    basestring_ = str
 | 
			
		||||
    input_ = input
 | 
			
		||||
    json_dumps = lambda data: ujson.dumps(data, indent=2, escape_forward_slashes=False)
 | 
			
		||||
    json_dumps = lambda data, indent=2: ujson.dumps(data, indent=indent, escape_forward_slashes=False)
 | 
			
		||||
    path2str = lambda path: str(path)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -259,6 +259,8 @@ class Errors(object):
 | 
			
		|||
            "error. Are you writing to a default function argument?")
 | 
			
		||||
    E096 = ("Invalid object passed to displaCy: Can only visualize Doc or "
 | 
			
		||||
             "Span objects, or dicts if set to manual=True.")
 | 
			
		||||
    E097 = ("Invalid pattern: expected token pattern (list of dicts) or "
 | 
			
		||||
            "phrase pattern (string) but got:\n{pattern}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@add_codes
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -18,6 +18,7 @@ from .lemmatizer import Lemmatizer
 | 
			
		|||
from .pipeline import DependencyParser, Tensorizer, Tagger, EntityRecognizer
 | 
			
		||||
from .pipeline import SimilarityHook, TextCategorizer, SentenceSegmenter
 | 
			
		||||
from .pipeline import merge_noun_chunks, merge_entities, merge_subtokens
 | 
			
		||||
from .pipeline import EntityRuler
 | 
			
		||||
from .compat import json_dumps, izip, basestring_
 | 
			
		||||
from .gold import GoldParse
 | 
			
		||||
from .scorer import Scorer
 | 
			
		||||
| 
						 | 
				
			
			@ -111,6 +112,7 @@ class Language(object):
 | 
			
		|||
        'merge_noun_chunks': lambda nlp, **cfg: merge_noun_chunks,
 | 
			
		||||
        'merge_entities': lambda nlp, **cfg: merge_entities,
 | 
			
		||||
        'merge_subtokens': lambda nlp, **cfg: merge_subtokens,
 | 
			
		||||
        'entity_ruler': lambda nlp, **cfg: EntityRuler(nlp, **cfg)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    def __init__(self, vocab=True, make_doc=True, max_length=10**6, meta={}, **kwargs):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -6,7 +6,7 @@ from __future__ import unicode_literals
 | 
			
		|||
import numpy
 | 
			
		||||
cimport numpy as np
 | 
			
		||||
import cytoolz
 | 
			
		||||
from collections import OrderedDict
 | 
			
		||||
from collections import OrderedDict, defaultdict
 | 
			
		||||
import ujson
 | 
			
		||||
 | 
			
		||||
from .util import msgpack
 | 
			
		||||
| 
						 | 
				
			
			@ -29,12 +29,15 @@ from .syntax import nonproj
 | 
			
		|||
from .compat import json_dumps
 | 
			
		||||
from .matcher import Matcher
 | 
			
		||||
 | 
			
		||||
from .matcher import Matcher, PhraseMatcher
 | 
			
		||||
from .tokens.span import Span
 | 
			
		||||
from .attrs import POS
 | 
			
		||||
from .parts_of_speech import X
 | 
			
		||||
from ._ml import Tok2Vec, build_text_classifier, build_tagger_model
 | 
			
		||||
from ._ml import link_vectors_to_models, zero_init, flatten
 | 
			
		||||
from ._ml import create_default_optimizer
 | 
			
		||||
from .errors import Errors, TempErrors
 | 
			
		||||
from .compat import json_dumps, basestring_
 | 
			
		||||
from . import util
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -112,6 +115,164 @@ def merge_subtokens(doc, label='subtok'):
 | 
			
		|||
    return doc
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class EntityRuler(object):
 | 
			
		||||
    name = 'entity_ruler'
 | 
			
		||||
 | 
			
		||||
    def __init__(self, nlp, **cfg):
 | 
			
		||||
        """Initialise the entitiy ruler. If patterns are supplied here, they
 | 
			
		||||
        need to be a list of dictionaries with a `"label"` and `"pattern"`
 | 
			
		||||
        key. A pattern can either be a token pattern (list) or a phrase pattern
 | 
			
		||||
        (string). For example: `{'label': 'ORG', 'pattern': 'Apple'}`.
 | 
			
		||||
 | 
			
		||||
        nlp (Language): The shared nlp object to pass the vocab to the matchers
 | 
			
		||||
            and process phrase patterns.
 | 
			
		||||
        patterns (iterable): Optional patterns to load in.
 | 
			
		||||
        overwrite_ents (bool): If existing entities are present, e.g. entities
 | 
			
		||||
            added by the model, overwrite them by matches if necessary.
 | 
			
		||||
        **cfg: Other config parameters. If pipeline component is loaded as part
 | 
			
		||||
            of a model pipeline, this will include all keyword arguments passed
 | 
			
		||||
            to `spacy.load`.
 | 
			
		||||
        RETURNS (EntityRuler): The newly constructed object.
 | 
			
		||||
        """
 | 
			
		||||
        self.nlp = nlp
 | 
			
		||||
        self.overwrite = cfg.get('overwrite_ents', False)
 | 
			
		||||
        self.token_patterns = defaultdict(list)
 | 
			
		||||
        self.phrase_patterns = defaultdict(list)
 | 
			
		||||
        self.matcher = Matcher(nlp.vocab)
 | 
			
		||||
        self.phrase_matcher = PhraseMatcher(nlp.vocab)
 | 
			
		||||
        patterns = cfg.get('patterns')
 | 
			
		||||
        if patterns is not None:
 | 
			
		||||
            self.add_patterns(patterns)
 | 
			
		||||
 | 
			
		||||
    def __len__(self):
 | 
			
		||||
        """The number of all patterns added to the entity ruler."""
 | 
			
		||||
        n_token_patterns = sum(len(p) for p in self.token_patterns.values())
 | 
			
		||||
        n_phrase_patterns = sum(len(p) for p in self.phrase_patterns.values())
 | 
			
		||||
        return n_token_patterns + n_phrase_patterns
 | 
			
		||||
 | 
			
		||||
    def __contains__(self, label):
 | 
			
		||||
        """Whether a label is present in the patterns."""
 | 
			
		||||
        return label in self.token_patterns or label in self.phrase_patterns
 | 
			
		||||
 | 
			
		||||
    def __call__(self, doc):
 | 
			
		||||
        """Find matches in document and add them as entities.
 | 
			
		||||
 | 
			
		||||
        doc (Doc): The Doc object in the pipeline.
 | 
			
		||||
        RETURNS (Doc): The Doc with added entities, if available.
 | 
			
		||||
        """
 | 
			
		||||
        matches = list(self.matcher(doc)) + list(self.phrase_matcher(doc))
 | 
			
		||||
        matches = set([(m_id, start, end) for m_id, start, end in matches
 | 
			
		||||
                       if start != end])
 | 
			
		||||
        get_sort_key = lambda m: (m[2] - m[1], m[1])
 | 
			
		||||
        matches = sorted(matches, key=get_sort_key, reverse=True)
 | 
			
		||||
        entities = list(doc.ents)
 | 
			
		||||
        new_entities = []
 | 
			
		||||
        seen_tokens = set()
 | 
			
		||||
        for match_id, start, end in matches:
 | 
			
		||||
            if any(t.ent_type for t in doc[start:end]) and not self.overwrite:
 | 
			
		||||
                continue
 | 
			
		||||
            # check for end - 1 here because boundaries are inclusive
 | 
			
		||||
            if start not in seen_tokens and end - 1 not in seen_tokens:
 | 
			
		||||
                new_entities.append(Span(doc, start, end, label=match_id))
 | 
			
		||||
                entities = [e for e in entities
 | 
			
		||||
                            if not (e.start < end and e.end > start)]
 | 
			
		||||
                seen_tokens.update(range(start, end))
 | 
			
		||||
        doc.ents = entities + new_entities
 | 
			
		||||
        return doc
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def labels(self):
 | 
			
		||||
        """All labels present in the match patterns.
 | 
			
		||||
 | 
			
		||||
        RETURNS (set): The string labels.
 | 
			
		||||
        """
 | 
			
		||||
        all_labels = set(self.token_patterns.keys())
 | 
			
		||||
        all_labels.update(self.phrase_patterns.keys())
 | 
			
		||||
        return all_labels
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def patterns(self):
 | 
			
		||||
        """Get all patterns that were added to the entity ruler.
 | 
			
		||||
 | 
			
		||||
        RETURNS (list): The original patterns, one dictionary per pattern.
 | 
			
		||||
        """
 | 
			
		||||
        all_patterns = []
 | 
			
		||||
        for label, patterns in self.token_patterns.items():
 | 
			
		||||
            for pattern in patterns:
 | 
			
		||||
                all_patterns.append({'label': label, 'pattern': pattern})
 | 
			
		||||
        for label, patterns in self.phrase_patterns.items():
 | 
			
		||||
            for pattern in patterns:
 | 
			
		||||
                all_patterns.append({'label': label, 'pattern': pattern.text})
 | 
			
		||||
        return all_patterns
 | 
			
		||||
 | 
			
		||||
    def add_patterns(self, patterns):
 | 
			
		||||
        """Add patterns to the entitiy ruler. A pattern can either be a token
 | 
			
		||||
        pattern (list of dicts) or a phrase pattern (string). For example:
 | 
			
		||||
        {'label': 'ORG', 'pattern': 'Apple'}
 | 
			
		||||
        {'label': 'GPE', 'pattern': [{'lower': 'san'}, {'lower': 'francisco'}]}
 | 
			
		||||
 | 
			
		||||
        patterns (list): The patterns to add.
 | 
			
		||||
        """
 | 
			
		||||
        for entry in patterns:
 | 
			
		||||
            label = entry['label']
 | 
			
		||||
            pattern = entry['pattern']
 | 
			
		||||
            if isinstance(pattern, basestring_):
 | 
			
		||||
                self.phrase_patterns[label].append(self.nlp(pattern))
 | 
			
		||||
            elif isinstance(pattern, list):
 | 
			
		||||
                self.token_patterns[label].append(pattern)
 | 
			
		||||
            else:
 | 
			
		||||
                raise ValueError(Errors.E097.format(pattern=pattern))
 | 
			
		||||
        for label, patterns in self.token_patterns.items():
 | 
			
		||||
            self.matcher.add(label, None, *patterns)
 | 
			
		||||
        for label, patterns in self.phrase_patterns.items():
 | 
			
		||||
            self.phrase_matcher.add(label, None, *patterns)
 | 
			
		||||
 | 
			
		||||
    def from_bytes(self, patterns_bytes, **kwargs):
 | 
			
		||||
        """Load the entity ruler from a bytestring.
 | 
			
		||||
 | 
			
		||||
        patterns_bytes (bytes): The bytestring to load.
 | 
			
		||||
        **kwargs: Other config paramters, mostly for consistency.
 | 
			
		||||
        RETURNS (EntityRuler): The loaded entity ruler.
 | 
			
		||||
        """
 | 
			
		||||
        patterns = msgpack.loads(patterns_bytes)
 | 
			
		||||
        self.add_patterns(patterns)
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    def to_bytes(self, **kwargs):
 | 
			
		||||
        """Serialize the entity ruler patterns to a bytestring.
 | 
			
		||||
 | 
			
		||||
        RETURNS (bytes): The serialized patterns.
 | 
			
		||||
        """
 | 
			
		||||
        return msgpack.dumps(self.patterns)
 | 
			
		||||
 | 
			
		||||
    def from_disk(self, path, **kwargs):
 | 
			
		||||
        """Load the entity ruler from a file. Expects a file containing
 | 
			
		||||
        newline-delimited JSON (JSONL) with one entry per line.
 | 
			
		||||
 | 
			
		||||
        path (unicode / Path): The JSONL file to load.
 | 
			
		||||
        **kwargs: Other config paramters, mostly for consistency.
 | 
			
		||||
        RETURNS (EntityRuler): The loaded entity ruler.
 | 
			
		||||
        """
 | 
			
		||||
        path = util.ensure_path(path)
 | 
			
		||||
        path = path.with_suffix('.jsonl')
 | 
			
		||||
        patterns = util.read_jsonl(path)
 | 
			
		||||
        self.add_patterns(patterns)
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    def to_disk(self, path, **kwargs):
 | 
			
		||||
        """Save the entity ruler patterns to a directory. The patterns will be
 | 
			
		||||
        saved as newline-delimited JSON (JSONL).
 | 
			
		||||
 | 
			
		||||
        path (unicode / Path): The JSONL file to load.
 | 
			
		||||
        **kwargs: Other config paramters, mostly for consistency.
 | 
			
		||||
        RETURNS (EntityRuler): The loaded entity ruler.
 | 
			
		||||
        """
 | 
			
		||||
        path = util.ensure_path(path)
 | 
			
		||||
        path = path.with_suffix('.jsonl')
 | 
			
		||||
        data = [json_dumps(line, indent=0) for line in self.patterns]
 | 
			
		||||
        path.open('w').write('\n'.join(data))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Pipe(object):
 | 
			
		||||
    """This class is not instantiated directly. Components inherit from it, and
 | 
			
		||||
    it defines the interface that components should follow to function as
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										89
									
								
								spacy/tests/pipeline/test_entity_ruler.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										89
									
								
								spacy/tests/pipeline/test_entity_ruler.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,89 @@
 | 
			
		|||
# coding: utf8
 | 
			
		||||
from __future__ import unicode_literals
 | 
			
		||||
 | 
			
		||||
import pytest
 | 
			
		||||
 | 
			
		||||
from ...tokens import Span
 | 
			
		||||
from ...language import Language
 | 
			
		||||
from ...pipeline import EntityRuler
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
def nlp():
 | 
			
		||||
    return Language()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
def patterns():
 | 
			
		||||
    return [
 | 
			
		||||
        {"label": "HELLO", "pattern": "hello world"},
 | 
			
		||||
        {"label": "BYE", "pattern": [{"LOWER": "bye"}, {"LOWER": "bye"}]},
 | 
			
		||||
        {"label": "HELLO", "pattern": [{"ORTH": "HELLO"}]},
 | 
			
		||||
        {"label": "COMPLEX", "pattern": [{"ORTH": "foo", "OP": "*"}]}
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
def add_ent():
 | 
			
		||||
    def add_ent_component(doc):
 | 
			
		||||
        doc.ents = [Span(doc, 0, 3, label=doc.vocab.strings['ORG'])]
 | 
			
		||||
        return doc
 | 
			
		||||
    return add_ent_component
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_entity_ruler_init(nlp, patterns):
 | 
			
		||||
    ruler = EntityRuler(nlp, patterns=patterns)
 | 
			
		||||
    assert len(ruler) == len(patterns)
 | 
			
		||||
    assert len(ruler.labels) == 3
 | 
			
		||||
    assert 'HELLO' in ruler
 | 
			
		||||
    assert 'BYE' in ruler
 | 
			
		||||
    nlp.add_pipe(ruler)
 | 
			
		||||
    doc = nlp("hello world bye bye")
 | 
			
		||||
    assert len(doc.ents) == 2
 | 
			
		||||
    assert doc.ents[0].label_ == 'HELLO'
 | 
			
		||||
    assert doc.ents[1].label_ == 'BYE'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_entity_ruler_existing(nlp, patterns, add_ent):
 | 
			
		||||
    ruler = EntityRuler(nlp, patterns=patterns)
 | 
			
		||||
    nlp.add_pipe(add_ent)
 | 
			
		||||
    nlp.add_pipe(ruler)
 | 
			
		||||
    doc = nlp("OH HELLO WORLD bye bye")
 | 
			
		||||
    assert len(doc.ents) == 2
 | 
			
		||||
    assert doc.ents[0].label_ == 'ORG'
 | 
			
		||||
    assert doc.ents[1].label_ == 'BYE'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_entity_ruler_existing_overwrite(nlp, patterns, add_ent):
 | 
			
		||||
    ruler = EntityRuler(nlp, patterns=patterns, overwrite_ents=True)
 | 
			
		||||
    nlp.add_pipe(add_ent)
 | 
			
		||||
    nlp.add_pipe(ruler)
 | 
			
		||||
    doc = nlp("OH HELLO WORLD bye bye")
 | 
			
		||||
    assert len(doc.ents) == 2
 | 
			
		||||
    assert doc.ents[0].label_ == 'HELLO'
 | 
			
		||||
    assert doc.ents[0].text == 'HELLO'
 | 
			
		||||
    assert doc.ents[1].label_ == 'BYE'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_entity_ruler_existing_complex(nlp, patterns, add_ent):
 | 
			
		||||
    ruler = EntityRuler(nlp, patterns=patterns, overwrite_ents=True)
 | 
			
		||||
    nlp.add_pipe(add_ent)
 | 
			
		||||
    nlp.add_pipe(ruler)
 | 
			
		||||
    doc = nlp("foo foo bye bye")
 | 
			
		||||
    assert len(doc.ents) == 2
 | 
			
		||||
    assert doc.ents[0].label_ == 'COMPLEX'
 | 
			
		||||
    assert doc.ents[1].label_ == 'BYE'
 | 
			
		||||
    assert len(doc.ents[0]) == 2
 | 
			
		||||
    assert len(doc.ents[1]) == 2
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_entity_ruler_serialize_bytes(nlp, patterns):
 | 
			
		||||
    ruler = EntityRuler(nlp, patterns=patterns)
 | 
			
		||||
    assert len(ruler) == len(patterns)
 | 
			
		||||
    assert len(ruler.labels) == 3
 | 
			
		||||
    ruler_bytes = ruler.to_bytes()
 | 
			
		||||
    new_ruler = EntityRuler(nlp)
 | 
			
		||||
    assert len(new_ruler) == 0
 | 
			
		||||
    assert len(new_ruler.labels) == 0
 | 
			
		||||
    new_ruler = new_ruler.from_bytes(ruler_bytes)
 | 
			
		||||
    assert len(ruler) == len(patterns)
 | 
			
		||||
    assert len(ruler.labels) == 3
 | 
			
		||||
| 
						 | 
				
			
			@ -507,6 +507,20 @@ def read_json(location):
 | 
			
		|||
        return ujson.load(f)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def read_jsonl(file_path):
 | 
			
		||||
    """Read a .jsonl file and yield its contents line by line.
 | 
			
		||||
 | 
			
		||||
    file_path (unicode / Path): The file path.
 | 
			
		||||
    YIELDS: The loaded JSON contents of each line.
 | 
			
		||||
    """
 | 
			
		||||
    with Path(file_path).open('r', encoding='utf8') as f:
 | 
			
		||||
        for line in f:
 | 
			
		||||
            try:  # hack to handle broken jsonl
 | 
			
		||||
                yield ujson.loads(line.strip())
 | 
			
		||||
            except ValueError:
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_raw_input(description, default=False):
 | 
			
		||||
    """Get user input from the command line via raw_input / input.
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user