* Hack on generic Language class. Still needs work for morphology, defaults, etc

This commit is contained in:
Matthew Honnibal 2015-08-26 19:16:09 +02:00
parent e2ef78b29c
commit 76996f4145

View File

@@ -1,3 +1,19 @@
from os import path
from .tokenizer import Tokenizer
from .morphology import Morphology
from .vocab import Vocab
from .syntax.parser import Parser
from .tagger import Tagger
from .matcher import Matcher
from .serialize.packer import Packer
from ._ml import Model
from . import attrs
from . import orth
from .syntax.ner import BiluoPushDown
from .syntax.arc_eager import ArcEager
class Language(object): class Language(object):
@staticmethod @staticmethod
def lower(string): def lower(string):
@@ -21,7 +37,7 @@ class Language(object):
@staticmethod @staticmethod
def prob(string): def prob(string):
return self.oov_prob return -30
@staticmethod @staticmethod
def cluster(string): def cluster(string):
@@ -29,29 +45,50 @@ class Language(object):
@staticmethod @staticmethod
def is_alpha(string): def is_alpha(string):
return orths.is_alpha(string) return orth.is_alpha(string)
@staticmethod
def is_ascii(string):
return orth.is_ascii(string)
@staticmethod
def is_digit(string):
return string.isdigit()
@staticmethod @staticmethod
def is_lower(string): def is_lower(string):
return orths.is_lower(string) return orth.is_lower(string)
@staticmethod
def is_punct(string):
return orth.is_punct(string)
@staticmethod
def is_space(string):
return string.isspace()
@staticmethod
def is_title(string):
return orth.is_title(string)
@staticmethod @staticmethod
def is_upper(string): def is_upper(string):
return orths.is_upper(string) return orth.is_upper(string)
@staticmethod @staticmethod
def like_url(string): def like_url(string):
return orths.like_url(string) return orth.like_url(string)
@staticmethod @staticmethod
def like_number(string): def like_number(string):
return orths.like_number(string) return orth.like_number(string)
@staticmethod @staticmethod
def like_email(string): def like_email(string):
return orths.like_email(string) return orth.like_email(string)
def default_lex_attrs(cls, data_dir): @classmethod
def default_lex_attrs(cls, data_dir=None):
return { return {
attrs.LOWER: cls.lower, attrs.LOWER: cls.lower,
attrs.NORM: cls.norm, attrs.NORM: cls.norm,
@@ -59,12 +96,15 @@ class Language(object):
attrs.PREFIX: cls.prefix, attrs.PREFIX: cls.prefix,
attrs.SUFFIX: cls.suffix, attrs.SUFFIX: cls.suffix,
attrs.CLUSTER: cls.cluster, attrs.CLUSTER: cls.cluster,
attrs.PROB: cls.prob, attrs.PROB: lambda string: -10.0,
attrs.IS_ALPHA: cls.is_alpha, attrs.IS_ALPHA: cls.is_alpha,
attrs.IS_ASCII: cls.is_ascii, attrs.IS_ASCII: cls.is_ascii,
attrs.IS_DIGIT: cls.is_digit, attrs.IS_DIGIT: cls.is_digit,
attrs.IS_LOWER: cls.is_lower, attrs.IS_LOWER: cls.is_lower,
attrs.IS_PUNCT: cls.is_punct,
attrs.IS_SPACE: cls.is_space,
attrs.IS_TITLE: cls.is_title,
attrs.IS_UPPER: cls.is_upper, attrs.IS_UPPER: cls.is_upper,
attrs.LIKE_URL: cls.like_url, attrs.LIKE_URL: cls.like_url,
attrs.LIKE_NUM: cls.like_number, attrs.LIKE_NUM: cls.like_number,
@@ -73,12 +113,36 @@ class Language(object):
attrs.IS_OOV: lambda string: True attrs.IS_OOV: lambda string: True
} }
@classmethod
def default_dep_templates(cls):
return []
@classmethod
def default_ner_templates(cls):
return []
@classmethod
def default_dep_labels(cls):
return {0: {'ROOT': True}}
@classmethod
def default_ner_labels(cls):
return {0: {'PER': True, 'LOC': True, 'ORG': True, 'MISC': True}}
@classmethod @classmethod
def default_data_dir(cls): def default_data_dir(cls):
return path.join(path.dirname(__file__), 'data') return path.join(path.dirname(__file__), 'data')
@classmethod @classmethod
def default_vocab(cls, get_lex_attr=None, vectors=None, morphology=None, data_dir=None): def default_morphology(cls, data_dir):
return Morphology.from_dir(data_dir)
@classmethod
def default_vectors(cls, data_dir):
return None
@classmethod
def default_vocab(cls, data_dir=None, get_lex_attr=None, vectors=None, morphology=None):
if data_dir is None: if data_dir is None:
data_dir = cls.default_data_dir() data_dir = cls.default_data_dir()
if vectors is None: if vectors is None:
@@ -86,70 +150,71 @@ class Language(object):
if get_lex_attr is None: if get_lex_attr is None:
get_lex_attr = cls.default_lex_attrs(data_dir) get_lex_attr = cls.default_lex_attrs(data_dir)
if morphology is None: if morphology is None:
morphology = cls.default_morphology(data_dir) morphology = cls.default_morphology(path.join(data_dir, 'vocab'))
return vocab = Vocab.from_dir(data_dir, get_lex_attr, vectors, morphology) return Vocab.from_dir(
path.join(data_dir, 'vocab'),
get_lex_attr=get_lex_attr,
vectors=vectors,
morphology=morphology)
@classmethod @classmethod
def default_tokenizer(cls, vocab, data_dir=None): def default_tokenizer(cls, vocab, data_dir):
if data_dir is None: if path.exists(data_dir):
data_dir = cls.default_data_dir() return Tokenizer.from_dir(vocab, data_dir)
return Tokenizer.from_dir(data_dir, vocab) else:
return Tokenizer(vocab, {}, None, None, None)
@classmethod @classmethod
def default_tagger(cls, vocab, data_dir=None): def default_tagger(cls, vocab, data_dir):
return Tagger.from_dir(data_dir, vocab) if path.exists(data_dir):
return Tagger.from_dir(data_dir, vocab)
else:
return None
@classmethod @classmethod
def default_parser(cls, vocab, transition_system=None, data_dir=None): def default_parser(cls, vocab, data_dir):
if transition_system is None: if path.exists(data_dir):
transition_system = ArcEager() return Parser.from_dir(data_dir, vocab.strings, ArcEager)
return Parser.from_dir(data_dir, vocab, transition_system) else:
return None
@classmethod @classmethod
def default_entity(cls, vocab, transition_system=None, data_dir=None): def default_entity(cls, vocab, data_dir):
if transition_system is None: if path.exists(data_dir):
transition_system = BiluoPushDown() return Parser.from_dir(data_dir, vocab.strings, BiluoPushDown)
return Parser.from_dir(data_dir, vocab, transition_system) else:
return None
@classmethod @classmethod
def default_matcher(cls, vocab, data_dir=None): def default_matcher(cls, vocab, data_dir=None):
if data_dir is None: if data_dir is None:
data_dir = cls.default_data_dir() data_dir = cls.default_data_dir()
return Matcher(data_dir, vocab) return Matcher.from_dir(data_dir, vocab)
@classmethod def __init__(self, data_dir=None, vocab=None, tokenizer=None, tagger=None,
def default_serializer(cls, vocab, data_dir=None): parser=None, entity=None, matcher=None, serializer=None):
if data_dir is None:
data_dir = cls.default_data_dir()
return Packer(data_dir, vocab)
def __init__(self, vocab=None, tokenizer=None, tagger=None, parser=None,
entity=None, matcher=None, serializer=None):
if data_dir is None: if data_dir is None:
data_dir = self.default_data_dir() data_dir = self.default_data_dir()
if vocab is None: if vocab is None:
vocab = self.default_vocab(data_dir) vocab = self.default_vocab(data_dir)
if tokenizer is None: if tokenizer is None:
tokenizer = self.default_tokenizer(vocab, data_dir) tokenizer = self.default_tokenizer(vocab, data_dir=path.join(data_dir, 'tokenizer'))
if tagger is None: if tagger is None:
tagger = self.default_tagger(vocab, data_dir) tagger = self.default_tagger(vocab, data_dir=path.join(data_dir, 'pos'))
if entity is None: if entity is None:
entity = self.default_entity(vocab, data_dir) entity = self.default_entity(vocab, data_dir=path.join(data_dir, 'ner'))
if parser is None: if parser is None:
parser = self.default_parser(vocab, data_dir) parser = self.default_parser(vocab, data_dir=path.join(data_dir, 'deps'))
if matcher is None: if matcher is None:
matcher = self.default_matcher(vocab, data_dir) matcher = self.default_matcher(vocab, data_dir=data_dir)
if serializer is None:
serializer = self.default_serializer(vocab, data_dir)
self.vocab = vocab self.vocab = vocab
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.tagger = tagger self.tagger = tagger
self.parser = parser self.parser = parser
self.entity = entity self.entity = entity
self.matcher = matcher self.matcher = matcher
self.serializer = serializer
def __call__(self, text, tag=True, parse=True, entity=True): def __call__(self, text, tag=True, parse=True, entity=True, merge_mwes=False):
"""Apply the pipeline to some text. The text can span multiple sentences, """Apply the pipeline to some text. The text can span multiple sentences,
and can contain arbtrary whitespace. Alignment into the original string and can contain arbtrary whitespace. Alignment into the original string
is preserved. is preserved.