diff --git a/spacy/en/__init__.py b/spacy/en/__init__.py
index 5bf83a253..c81630a72 100644
--- a/spacy/en/__init__.py
+++ b/spacy/en/__init__.py
@@ -11,6 +11,7 @@ from ..syntax.arc_eager import ArcEager
 from ..syntax.ner import BiluoPushDown
 from ..syntax.parser import ParserFactory
 from ..serialize.bits import BitArray
+from ..matcher import Matcher
 from ..tokens import Doc
 from ..multi_words import RegexMerger
 
@@ -75,6 +76,7 @@ class English(object):
         Tagger=EnPosTagger,
         Parser=ParserFactory(ParserTransitionSystem),
         Entity=ParserFactory(EntityTransitionSystem),
+        Matcher=Matcher.from_dir,
         Packer=None,
         load_vectors=True
     ):
@@ -113,6 +115,10 @@ class English(object):
             self.entity = Entity(self.vocab.strings, path.join(data_dir, 'ner'))
         else:
             self.entity = None
+        if Matcher:
+            self.matcher = Matcher(self.vocab, data_dir)
+        else:
+            self.matcher = None
         if Packer:
             self.packer = Packer(self.vocab, data_dir)
         else:
@@ -143,6 +149,8 @@ class English(object):
         tokens = self.tokenizer(text)
         if self.tagger and tag:
             self.tagger(tokens)
+        if self.matcher and entity:
+            self.matcher(tokens)
         if self.parser and parse:
             self.parser(tokens)
         if self.entity and entity:
diff --git a/spacy/matcher.pyx b/spacy/matcher.pyx
index b1b77e162..ab3ef354b 100644
--- a/spacy/matcher.pyx
+++ b/spacy/matcher.pyx
@@ -1,3 +1,5 @@
+from os import path
+
 from .typedefs cimport attr_t
 from .attrs cimport attr_id_t
 from .structs cimport TokenC
@@ -5,11 +7,16 @@ from .structs cimport TokenC
 from cymem.cymem cimport Pool
 from libcpp.vector cimport vector
 
-from .attrs cimport LENGTH, ENT_TYPE
+from .attrs cimport LENGTH, ENT_TYPE, ORTH, NORM, LEMMA, LOWER, SHAPE
 from .tokens.doc cimport get_token_attr
 from .tokens.doc cimport Doc
 from .vocab cimport Vocab
 
+try:
+    import ujson as json
+except ImportError:
+    import json
+
 
 cdef struct AttrValue:
     attr_id_t attr
@@ -58,18 +65,61 @@ cdef object get_entity(const Pattern* pattern, const TokenC* tokens, int i):
     return (pattern.spec[0].value, i - pattern.spec[1].value, i)
 
 
+def _convert_strings(token_specs, string_store):
+    converted = []
+    for spec in token_specs:
+        converted.append([])
+        for attr, value in spec.items():
+            if isinstance(attr, basestring):
+                attr = map_attr_name(attr)
+            if isinstance(value, basestring):
+                value = string_store[value]
+            converted[-1].append((attr, value))
+    return converted
+
+
+def map_attr_name(attr):
+    attr = attr.upper()
+    if attr == 'ORTH':
+        return ORTH
+    elif attr == 'LEMMA':
+        return LEMMA
+    elif attr == 'LOWER':
+        return LOWER
+    elif attr == 'SHAPE':
+        return SHAPE
+    elif attr == 'NORM':
+        return NORM
+    else:
+        raise Exception("TODO: Finish supporting attr mapping %s" % attr)
+
+
 cdef class Matcher:
     cdef Pool mem
     cdef Pattern** patterns
     cdef readonly int n_patterns
 
-    def __init__(self, patterns):
+    def __init__(self, vocab, patterns):
         self.mem = Pool()
         self.patterns = <Pattern**>self.mem.alloc(len(patterns), sizeof(Pattern*))
-        for i, (token_specs, entity_type) in enumerate(patterns):
-            self.patterns[i] = init_pattern(self.mem, token_specs, entity_type)
+        for i, (entity_key, (etype, attrs, specs)) in enumerate(sorted(patterns.items())):
+            if isinstance(entity_key, basestring):
+                entity_key = vocab.strings[entity_key]
+            if isinstance(etype, basestring):
+                etype = vocab.strings[etype]
+            specs = _convert_strings(specs, vocab.strings)
+            self.patterns[i] = init_pattern(self.mem, specs, etype)
        self.n_patterns = len(patterns)
 
+    @classmethod
+    def from_dir(cls, vocab, data_dir):
+        patterns_loc = path.join(data_dir, 'ner', 'patterns.json')
+        if path.exists(patterns_loc):
+            patterns = json.load(open(patterns_loc))
+            return cls(vocab, patterns)
+        else:
+            return cls(vocab, {})
+
     def __call__(self, Doc doc):
         cdef vector[Pattern*] partials
         cdef int n_partials = 0
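
Usage sketch (not part of the diff): the pattern dict below is shaped the way
Matcher.__init__ and from_dir consume it, as far as this diff shows -- each
entity key maps to a (etype, attrs, specs) triple, where specs holds one
attribute dict per token and string attribute names go through map_attr_name.
The 'GoogleNow' key, 'PRODUCT' label, and sample sentence are illustrative
assumptions, not shipped data.

    from spacy.en import English
    from spacy.matcher import Matcher

    nlp = English()

    patterns = {
        'GoogleNow': (                  # entity_key; interned via vocab.strings
            'PRODUCT',                  # etype: entity label for matches
            {},                         # attrs: accepted but unused by __init__
            [{'ORTH': 'Google'},        # specs: one {attr: value} dict per token
             {'ORTH': 'Now'}],
        ),
    }
    matcher = Matcher(nlp.vocab, patterns)

    doc = nlp(u'I like Google Now.')
    matcher(doc)

The same dict serialized to ner/patterns.json under the model data directory
would be picked up automatically by Matcher.from_dir, which the new
Matcher=Matcher.from_dir default on English wires into the pipeline.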