* Hack on generic Language class. Still needs work for morphology, defaults, etc

This commit is contained in:
Matthew Honnibal 2015-08-26 19:16:09 +02:00
parent e2ef78b29c
commit 76996f4145

View File

@ -1,3 +1,19 @@
from os import path
from .tokenizer import Tokenizer
from .morphology import Morphology
from .vocab import Vocab
from .syntax.parser import Parser
from .tagger import Tagger
from .matcher import Matcher
from .serialize.packer import Packer
from ._ml import Model
from . import attrs
from . import orth
from .syntax.ner import BiluoPushDown
from .syntax.arc_eager import ArcEager
class Language(object):
@staticmethod
def lower(string):
@ -21,7 +37,7 @@ class Language(object):
@staticmethod
def prob(string):
return self.oov_prob
return -30
@staticmethod
def cluster(string):
@ -29,29 +45,50 @@ class Language(object):
@staticmethod
def is_alpha(string):
return orths.is_alpha(string)
return orth.is_alpha(string)
@staticmethod
def is_ascii(string):
return orth.is_ascii(string)
@staticmethod
def is_digit(string):
return string.isdigit()
@staticmethod
def is_lower(string):
return orths.is_lower(string)
return orth.is_lower(string)
@staticmethod
def is_punct(string):
return orth.is_punct(string)
@staticmethod
def is_space(string):
return string.isspace()
@staticmethod
def is_title(string):
return orth.is_title(string)
@staticmethod
def is_upper(string):
return orths.is_upper(string)
return orth.is_upper(string)
@staticmethod
def like_url(string):
return orths.like_url(string)
return orth.like_url(string)
@staticmethod
def like_number(string):
return orths.like_number(string)
return orth.like_number(string)
@staticmethod
def like_email(string):
return orths.like_email(string)
return orth.like_email(string)
def default_lex_attrs(cls, data_dir):
@classmethod
def default_lex_attrs(cls, data_dir=None):
return {
attrs.LOWER: cls.lower,
attrs.NORM: cls.norm,
@ -59,12 +96,15 @@ class Language(object):
attrs.PREFIX: cls.prefix,
attrs.SUFFIX: cls.suffix,
attrs.CLUSTER: cls.cluster,
attrs.PROB: cls.prob,
attrs.PROB: lambda string: -10.0,
attrs.IS_ALPHA: cls.is_alpha,
attrs.IS_ASCII: cls.is_ascii,
attrs.IS_DIGIT: cls.is_digit,
attrs.IS_LOWER: cls.is_lower,
attrs.IS_PUNCT: cls.is_punct,
attrs.IS_SPACE: cls.is_space,
attrs.IS_TITLE: cls.is_title,
attrs.IS_UPPER: cls.is_upper,
attrs.LIKE_URL: cls.like_url,
attrs.LIKE_NUM: cls.like_number,
@ -73,12 +113,36 @@ class Language(object):
attrs.IS_OOV: lambda string: True
}
@classmethod
def default_dep_templates(cls):
return []
@classmethod
def default_ner_templates(cls):
return []
@classmethod
def default_dep_labels(cls):
return {0: {'ROOT': True}}
@classmethod
def default_ner_labels(cls):
return {0: {'PER': True, 'LOC': True, 'ORG': True, 'MISC': True}}
@classmethod
def default_data_dir(cls):
return path.join(path.dirname(__file__), 'data')
@classmethod
def default_vocab(cls, get_lex_attr=None, vectors=None, morphology=None, data_dir=None):
def default_morphology(cls, data_dir):
return Morphology.from_dir(data_dir)
@classmethod
def default_vectors(cls, data_dir):
return None
@classmethod
def default_vocab(cls, data_dir=None, get_lex_attr=None, vectors=None, morphology=None):
if data_dir is None:
data_dir = cls.default_data_dir()
if vectors is None:
@ -86,70 +150,71 @@ class Language(object):
if get_lex_attr is None:
get_lex_attr = cls.default_lex_attrs(data_dir)
if morphology is None:
morphology = cls.default_morphology(data_dir)
return vocab = Vocab.from_dir(data_dir, get_lex_attr, vectors, morphology)
morphology = cls.default_morphology(path.join(data_dir, 'vocab'))
return Vocab.from_dir(
path.join(data_dir, 'vocab'),
get_lex_attr=get_lex_attr,
vectors=vectors,
morphology=morphology)
@classmethod
def default_tokenizer(cls, vocab, data_dir=None):
if data_dir is None:
data_dir = cls.default_data_dir()
return Tokenizer.from_dir(data_dir, vocab)
def default_tokenizer(cls, vocab, data_dir):
if path.exists(data_dir):
return Tokenizer.from_dir(vocab, data_dir)
else:
return Tokenizer(vocab, {}, None, None, None)
@classmethod
def default_tagger(cls, vocab, data_dir=None):
return Tagger.from_dir(data_dir, vocab)
def default_tagger(cls, vocab, data_dir):
if path.exists(data_dir):
return Tagger.from_dir(data_dir, vocab)
else:
return None
@classmethod
def default_parser(cls, vocab, transition_system=None, data_dir=None):
if transition_system is None:
transition_system = ArcEager()
return Parser.from_dir(data_dir, vocab, transition_system)
def default_parser(cls, vocab, data_dir):
if path.exists(data_dir):
return Parser.from_dir(data_dir, vocab.strings, ArcEager)
else:
return None
@classmethod
def default_entity(cls, vocab, transition_system=None, data_dir=None):
if transition_system is None:
transition_system = BiluoPushDown()
return Parser.from_dir(data_dir, vocab, transition_system)
def default_entity(cls, vocab, data_dir):
if path.exists(data_dir):
return Parser.from_dir(data_dir, vocab.strings, BiluoPushDown)
else:
return None
@classmethod
def default_matcher(cls, vocab, data_dir=None):
if data_dir is None:
data_dir = cls.default_data_dir()
return Matcher(data_dir, vocab)
return Matcher.from_dir(data_dir, vocab)
@classmethod
def default_serializer(cls, vocab, data_dir=None):
if data_dir is None:
data_dir = cls.default_data_dir()
return Packer(data_dir, vocab)
def __init__(self, vocab=None, tokenizer=None, tagger=None, parser=None,
entity=None, matcher=None, serializer=None):
def __init__(self, data_dir=None, vocab=None, tokenizer=None, tagger=None,
parser=None, entity=None, matcher=None, serializer=None):
if data_dir is None:
data_dir = self.default_data_dir()
if vocab is None:
vocab = self.default_vocab(data_dir)
if tokenizer is None:
tokenizer = self.default_tokenizer(vocab, data_dir)
tokenizer = self.default_tokenizer(vocab, data_dir=path.join(data_dir, 'tokenizer'))
if tagger is None:
tagger = self.default_tagger(vocab, data_dir)
tagger = self.default_tagger(vocab, data_dir=path.join(data_dir, 'pos'))
if entity is None:
entity = self.default_entity(vocab, data_dir)
entity = self.default_entity(vocab, data_dir=path.join(data_dir, 'ner'))
if parser is None:
parser = self.default_parser(vocab, data_dir)
parser = self.default_parser(vocab, data_dir=path.join(data_dir, 'deps'))
if matcher is None:
matcher = self.default_matcher(vocab, data_dir)
if serializer is None:
serializer = self.default_serializer(vocab, data_dir)
matcher = self.default_matcher(vocab, data_dir=data_dir)
self.vocab = vocab
self.tokenizer = tokenizer
self.tagger = tagger
self.parser = parser
self.entity = entity
self.matcher = matcher
self.serializer = serializer
def __call__(self, text, tag=True, parse=True, entity=True):
def __call__(self, text, tag=True, parse=True, entity=True, merge_mwes=False):
"""Apply the pipeline to some text. The text can span multiple sentences,
and can contain arbtrary whitespace. Alignment into the original string
is preserved.