* Use language base class

This commit is contained in:
Matthew Honnibal 2015-08-25 15:37:30 +02:00
parent f2f699ac18
commit 8083a07c3e

View File

@ -1,183 +1,11 @@
from __future__ import unicode_literals from __future__ import unicode_literals, print_function
from os import path from os import path
import re
import struct
import json
from .. import orth from ..language import Language
from ..vocab import Vocab
from ..tokenizer import Tokenizer
from ..syntax.arc_eager import ArcEager
from ..syntax.ner import BiluoPushDown
from ..syntax.parser import ParserFactory
from ..serialize.bits import BitArray
from ..matcher import Matcher
from ..tokens import Doc
from ..multi_words import RegexMerger
from .pos import EnPosTagger
from .pos import POS_TAGS
from .attrs import get_flags
from . import regexes
from ..util import read_lang_data
from ..attrs import TAG, HEAD, DEP, ENT_TYPE, ENT_IOB
def get_lex_props(string, oov_prob=-30, is_oov=False): class English(Language):
return { @classmethod
'flags': get_flags(string, is_oov=is_oov), def default_data_dir(cls):
'length': len(string), return path.join(path.dirname(__file__), 'data')
'orth': string,
'lower': string.lower(),
'norm': string,
'shape': orth.word_shape(string),
'prefix': string[0],
'suffix': string[-3:],
'cluster': 0,
'prob': oov_prob,
'sentiment': 0
}
get_lex_attr = {}
if_model_present = -1
LOCAL_DATA_DIR = path.join(path.dirname(__file__), 'data')
class English(object):
"""The English NLP pipeline.
Example:
Load data from default directory:
>>> nlp = English()
>>> nlp = English(data_dir=u'')
Load data from specified directory:
>>> nlp = English(data_dir=u'path/to/data_directory')
Disable (and avoid loading) parts of the processing pipeline:
>>> nlp = English(vectors=False, parser=False, tagger=False, entity=False)
Start with nothing loaded:
>>> nlp = English(data_dir=None)
"""
ParserTransitionSystem = ArcEager
EntityTransitionSystem = BiluoPushDown
def __init__(self,
data_dir=LOCAL_DATA_DIR,
Tokenizer=Tokenizer.from_dir,
Tagger=EnPosTagger,
Parser=ParserFactory(ParserTransitionSystem),
Entity=ParserFactory(EntityTransitionSystem),
Matcher=Matcher.from_dir,
Packer=None,
load_vectors=True
):
self.data_dir = data_dir
if path.exists(path.join(data_dir, 'vocab', 'oov_prob')):
oov_prob = float(open(path.join(data_dir, 'vocab', 'oov_prob')).read())
else:
oov_prob = None
self.vocab = Vocab(data_dir=path.join(data_dir, 'vocab') if data_dir else None,
get_lex_props=get_lex_props, load_vectors=load_vectors,
pos_tags=POS_TAGS,
oov_prob=oov_prob)
if Tagger is True:
Tagger = EnPosTagger
if Parser is True:
transition_system = self.ParserTransitionSystem
Parser = lambda s, d: parser.Parser(s, d, transition_system)
if Entity is True:
transition_system = self.EntityTransitionSystem
Entity = lambda s, d: parser.Parser(s, d, transition_system)
self.tokenizer = Tokenizer(self.vocab, path.join(data_dir, 'tokenizer'))
if Tagger and path.exists(path.join(data_dir, 'pos')):
self.tagger = Tagger(self.vocab.strings, data_dir)
else:
self.tagger = None
if Parser and path.exists(path.join(data_dir, 'deps')):
self.parser = Parser(self.vocab.strings, path.join(data_dir, 'deps'))
else:
self.parser = None
if Entity and path.exists(path.join(data_dir, 'ner')):
self.entity = Entity(self.vocab.strings, path.join(data_dir, 'ner'))
else:
self.entity = None
if Matcher:
self.matcher = Matcher(self.vocab, data_dir)
else:
self.matcher = None
if Packer:
self.packer = Packer(self.vocab, data_dir)
else:
self.packer = None
self.mwe_merger = RegexMerger([
('IN', 'O', regexes.MW_PREPOSITIONS_RE),
('CD', 'TIME', regexes.TIME_RE),
('NNP', 'DATE', regexes.DAYS_RE),
('CD', 'MONEY', regexes.MONEY_RE)])
def __call__(self, text, tag=True, parse=True, entity=True, merge_mwes=False):
"""Apply the pipeline to some text. The text can span multiple sentences,
and can contain arbtrary whitespace. Alignment into the original string
is preserved.
Args:
text (unicode): The text to be processed.
Returns:
tokens (spacy.tokens.Doc):
>>> from spacy.en import English
>>> nlp = English()
>>> tokens = nlp('An example sentence. Another example sentence.')
>>> tokens[0].orth_, tokens[0].head.tag_
('An', 'NN')
"""
tokens = self.tokenizer(text)
if self.tagger and tag:
self.tagger(tokens)
if self.matcher and entity:
self.matcher(tokens)
if self.parser and parse:
self.parser(tokens)
if self.entity and entity:
self.entity(tokens)
if merge_mwes and self.mwe_merger is not None:
self.mwe_merger(tokens)
return tokens
def end_training(self, data_dir=None):
if data_dir is None:
data_dir = self.data_dir
self.parser.model.end_training()
self.entity.model.end_training()
self.tagger.model.end_training()
self.vocab.strings.dump(path.join(data_dir, 'vocab', 'strings.txt'))
with open(path.join(data_dir, 'vocab', 'serializer.json'), 'w') as file_:
file_.write(
json.dumps([
(TAG, list(self.tagger.freqs[TAG].items())),
(DEP, list(self.parser.moves.freqs[DEP].items())),
(ENT_IOB, list(self.entity.moves.freqs[ENT_IOB].items())),
(ENT_TYPE, list(self.entity.moves.freqs[ENT_TYPE].items())),
(HEAD, list(self.parser.moves.freqs[HEAD].items()))]))
@property
def tags(self):
"""Deprecated. List of part-of-speech tag names."""
return self.tagger.tag_names