Mostly finished loading refactoring. Design is in place, but doesn't work yet.

Matthew Honnibal 2016-09-24 15:42:01 +02:00
parent 9dc8043a7e
commit 83e364188c
14 changed files with 265 additions and 295 deletions

View File

@@ -1,23 +1,38 @@
-from .util import set_lang_class, get_lang_class, get_package, get_package_by_name
+import pathlib
+from .util import set_lang_class, get_lang_class
 from . import en
 from . import de
 from . import zh
+_data_path = pathlib.Path(__file__).parent / 'data'
 set_lang_class(en.English.lang, en.English)
 set_lang_class(de.German.lang, de.German)
 set_lang_class(zh.Chinese.lang, zh.Chinese)
+def get_data_path():
+    return _data_path
+def set_data_path(path):
+    global _data_path
+    if isinstance(path, basestring):
+        path = pathlib.Path(path)
+    _data_path = path
 def load(name, vocab=None, tokenizer=None, parser=None, tagger=None, entity=None,
          matcher=None, serializer=None, vectors=None, via=None):
-    package = get_package_by_name(name, via=via)
-    vectors_package = get_package_by_name(vectors, via=via)
+    if via is None:
+        via = get_data_path()
     cls = get_lang_class(name)
     return cls(
-        package=package,
-        vectors_package=vectors_package,
+        via,
+        vectors=vectors,
         vocab=vocab,
         tokenizer=tokenizer,
         tagger=tagger,
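The new module-level functions above replace the sputnik package lookup with a plain data directory. A minimal usage sketch, assuming a model has been unpacked under the data path (the commit message notes the loader does not actually work yet, and the path below is purely illustrative):

    # Sketch only: exercises get_data_path / set_data_path / load as defined above.
    import spacy

    print(spacy.get_data_path())            # defaults to <spacy package>/data
    spacy.set_data_path('/tmp/spacy-data')  # accepts a string or a pathlib.Path
    nlp = spacy.load('en')                  # via=None, so the data path above is used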

spacy/deprecated.py Normal file (99 additions)
View File

@@ -0,0 +1,99 @@
from sputnik.dir_package import DirPackage
from sputnik.package_list import (PackageNotFoundException,
                                  CompatiblePackageNotFoundException)

import sputnik

from . import about


def get_package(data_dir):
    if not isinstance(data_dir, six.string_types):
        raise RuntimeError('data_dir must be a string')
    return DirPackage(data_dir)


def get_package_by_name(name=None, via=None):
    if name is None:
        return
    lang = get_lang_class(name)
    try:
        return sputnik.package(about.__title__, about.__version__,
                               name, data_path=via)
    except PackageNotFoundException as e:
        raise RuntimeError("Model '%s' not installed. Please run 'python -m "
                           "%s.download' to install latest compatible "
                           "model." % (name, lang.__module__))
    except CompatiblePackageNotFoundException as e:
        raise RuntimeError("Installed model is not compatible with spaCy "
                           "version. Please run 'python -m %s.download "
                           "--force' to install latest compatible model." %
                           (lang.__module__))


def read_lang_data(package):
    tokenization = package.load_json(('tokenizer', 'specials.json'))
    with package.open(('tokenizer', 'prefix.txt'), default=None) as file_:
        prefix = read_prefix(file_) if file_ is not None else None
    with package.open(('tokenizer', 'suffix.txt'), default=None) as file_:
        suffix = read_suffix(file_) if file_ is not None else None
    with package.open(('tokenizer', 'infix.txt'), default=None) as file_:
        infix = read_infix(file_) if file_ is not None else None
    return tokenization, prefix, suffix, infix


def read_prefix(fileobj):
    entries = fileobj.read().split('\n')
    expression = '|'.join(['^' + re.escape(piece) for piece in entries if piece.strip()])
    return expression


def read_suffix(fileobj):
    entries = fileobj.read().split('\n')
    expression = '|'.join([piece + '$' for piece in entries if piece.strip()])
    return expression


def read_infix(fileobj):
    entries = fileobj.read().split('\n')
    expression = '|'.join([piece for piece in entries if piece.strip()])
    return expression


def align_tokens(ref, indices): # Deprecated, surely?
    start = 0
    queue = list(indices)
    for token in ref:
        end = start + len(token)
        emit = []
        while queue and queue[0][1] <= end:
            emit.append(queue.pop(0))
        yield token, emit
        start = end
    assert not queue


def detokenize(token_rules, words): # Deprecated?
    """To align with treebanks, return a list of "chunks", where a chunk is a
    sequence of tokens that are separated by whitespace in actual strings. Each
    chunk should be a tuple of token indices, e.g.

    >>> detokenize(["ca<SEP>n't", '<SEP>!'], ["I", "ca", "n't", "!"])
    [(0,), (1, 2, 3)]
    """
    string = ' '.join(words)
    for subtoks in token_rules:
        # Algorithmically this is dumb, but writing a little list-based match
        # machine? Ain't nobody got time for that.
        string = string.replace(subtoks.replace('<SEP>', ' '), subtoks)
    positions = []
    i = 0
    for chunk in string.split():
        subtoks = chunk.split('<SEP>')
        positions.append(tuple(range(i, i+len(subtoks))))
        i += len(subtoks)
    return positions
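The read_prefix/read_suffix/read_infix helpers simply join the non-empty lines of a token list into a '|'-separated regular expression, anchored at the start or end where appropriate. An illustrative check (assuming spacy.deprecated imports cleanly, which requires sputnik to be installed):

    # Illustrative only: build a prefix regex from an in-memory "prefix.txt".
    import io
    import re
    from spacy.deprecated import read_prefix

    expression = read_prefix(io.StringIO(u'(\n"\n'))       # e.g. '^\\(|^"'
    assert re.search(expression, u'"Hello') is not None    # matches the leading quote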

View File

@@ -25,13 +25,19 @@ class Defaults(object):
     def __init__(self, lang, path):
         self.lang = lang
         self.path = path
+        self.lex_attr_getters = dict(self.__class__.lex_attr_getters)
+        if (self.path / 'vocab' / 'oov_prob').exists():
+            with (self.path / 'vocab' / 'oov_prob').open() as file_:
+                oov_prob = file_.read().strip()
+            self.lex_attr_getters['PROB'] = lambda string: oov_prob
+        self.lex_attr_getters['LANG'] = lambda string: self.lang,
     def Vectors(self):
         pass
-    def Vocab(self, vectors=None, get_lex_attr=None):
-        if get_lex_attr is None:
-            get_lex_attr = self.lex_attrs()
+    def Vocab(self, vectors=None, lex_attr_getters=None):
+        if lex_attr_getters is None:
+            lex_attr_getters = dict(self.lex_attr_getters)
         if vectors is None:
             vectors = self.Vectors()
         return Vocab.load(self.path, get_lex_attr=get_lex_attr, vectors=vectors)
@@ -64,84 +70,42 @@ class Defaults(object):
                 nlp.parser,
                 nlp.entity]
-    def dep_labels(self):
-        return {0: {'ROOT': True}}
-    def ner_labels(self):
-        return {0: {'PER': True, 'LOC': True, 'ORG': True, 'MISC': True}}
-    def lex_attrs(self, *args, **kwargs):
-        if 'oov_prob' in kwargs:
-            oov_prob = kwargs.get('oov_prob', -20)
-        else:
-            with (self.path / 'vocab' / 'oov_prob').open() as file_:
-                oov_prob = file_.read().strip()
-        return {
-            attrs.LOWER: self.lower,
-            attrs.NORM: self.norm,
-            attrs.SHAPE: orth.word_shape,
-            attrs.PREFIX: self.prefix,
-            attrs.SUFFIX: self.suffix,
-            attrs.CLUSTER: self.cluster,
-            attrs.PROB: lambda string: oov_prob,
-            attrs.LANG: lambda string: self.lang,
-            attrs.IS_ALPHA: orth.is_alpha,
-            attrs.IS_ASCII: orth.is_ascii,
-            attrs.IS_DIGIT: self.is_digit,
-            attrs.IS_LOWER: orth.is_lower,
-            attrs.IS_PUNCT: orth.is_punct,
-            attrs.IS_SPACE: self.is_space,
-            attrs.IS_TITLE: orth.is_title,
-            attrs.IS_UPPER: orth.is_upper,
-            attrs.IS_BRACKET: orth.is_bracket,
-            attrs.IS_QUOTE: orth.is_quote,
-            attrs.IS_LEFT_PUNCT: orth.is_left_punct,
-            attrs.IS_RIGHT_PUNCT: orth.is_right_punct,
-            attrs.LIKE_URL: orth.like_url,
-            attrs.LIKE_NUM: orth.like_number,
-            attrs.LIKE_EMAIL: orth.like_email,
-            attrs.IS_STOP: self.is_stop,
-            attrs.IS_OOV: lambda string: True
-        }
-    @staticmethod
-    def lower(string):
-        return string.lower()
-    @staticmethod
-    def norm(string):
-        return string
-    @staticmethod
-    def prefix(string):
-        return string[0]
-    @staticmethod
-    def suffix(string):
-        return string[-3:]
-    @staticmethod
-    def cluster(string):
-        return 0
-    @staticmethod
-    def is_digit(string):
-        return string.isdigit()
-    @staticmethod
-    def is_space(string):
-        return string.isspace()
-    @staticmethod
-    def is_stop(string):
-        return 0
+    dep_labels = {0: {'ROOT': True}}
+    ner_labels = {0: {'PER': True, 'LOC': True, 'ORG': True, 'MISC': True}}
+    lex_attr_getters = {
+        attrs.LOWER: lambda string: string.lower(),
+        attrs.NORM: lambda string: string,
+        attrs.SHAPE: orth.word_shape,
+        attrs.PREFIX: lambda string: string[0],
+        attrs.SUFFIX: lambda string: string[-3:],
+        attrs.CLUSTER: lambda string: 0,
+        attrs.IS_ALPHA: orth.is_alpha,
+        attrs.IS_ASCII: orth.is_ascii,
+        attrs.IS_DIGIT: lambda string: string.isdigit(),
+        attrs.IS_LOWER: orth.is_lower,
+        attrs.IS_PUNCT: orth.is_punct,
+        attrs.IS_SPACE: lambda string: string.isspace(),
+        attrs.IS_TITLE: orth.is_title,
+        attrs.IS_UPPER: orth.is_upper,
+        attrs.IS_BRACKET: orth.is_bracket,
+        attrs.IS_QUOTE: orth.is_quote,
+        attrs.IS_LEFT_PUNCT: orth.is_left_punct,
+        attrs.IS_RIGHT_PUNCT: orth.is_right_punct,
+        attrs.LIKE_URL: orth.like_url,
+        attrs.LIKE_NUM: orth.like_number,
+        attrs.LIKE_EMAIL: orth.like_email,
+        attrs.IS_STOP: lambda string: False,
+        attrs.IS_OOV: lambda string: True
+    }
 class Language(object):
     '''A text-processing pipeline. Usually you'll load this once per process, and
     pass the instance around your program.
     '''
+    Defaults = Defaults
     lang = None
     def __init__(self,
@@ -180,6 +144,7 @@ class Language(object):
             path = data_dir
         if isinstance(path, basestring):
            path = pathlib.Path(path)
+        self.path = path
         defaults = defaults if defaults is not True else self.get_defaults(self.path)
         self.vocab = vocab if vocab is not True else defaults.Vocab(vectors=vectors)
@@ -291,4 +256,4 @@ class Language(object):
     def get_defaults(self, path):
-        return Defaults(path)
+        return Defaults(self.lang, path)
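The per-attribute methods collapse into a single class-level lex_attr_getters table, which Defaults.__init__ now copies and patches per instance (PROB from the oov_prob file, LANG from the language code). A hedged sketch of how a caller can customise the copied table; the IS_STOP override is invented for illustration, but English.Defaults.lex_attr_getters itself is exercised by the matcher test further down:

    # Sketch: customise the class-level lexical attribute table.
    from spacy import attrs
    from spacy.en import English

    getters = dict(English.Defaults.lex_attr_getters)   # copy, as Defaults.__init__ does
    getters[attrs.IS_STOP] = lambda string: string.lower() in {u'the', u'a', u'an'}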

View File

@@ -8,7 +8,6 @@ except ImportError:
     import json
 from .parts_of_speech import NOUN, VERB, ADJ, PUNCT
-from .util import get_package
 class Lemmatizer(object):

View File

@@ -23,7 +23,6 @@ from .tokens.doc cimport Doc
 from .vocab cimport Vocab
 from .attrs import FLAG61 as U_ENT
-from .util import get_package
 from .attrs import FLAG60 as B2_ENT
 from .attrs import FLAG59 as B3_ENT
@@ -195,14 +194,12 @@ cdef class Matcher:
     cdef vector[TokenPatternC*] patterns
     cdef readonly Vocab vocab
     cdef public object _patterns
     @classmethod
-    def load(cls, data_dir, Vocab vocab):
-        return cls.from_package(get_package(data_dir), vocab=vocab)
-    @classmethod
-    def from_package(cls, package, Vocab vocab):
-        patterns = package.load_json(('vocab', 'gazetteer.json'))
+    def load(cls, path, vocab):
+        if (path / 'patterns.json').exists():
+            with (path / 'patterns.json').open() as file_:
+                patterns = json.load(file_)
         return cls(vocab, patterns)
     def __init__(self, vocab, patterns={}):

View File

@@ -4,6 +4,7 @@ from thinc.structs cimport ExampleC
 from .stateclass cimport StateClass
 from .arc_eager cimport TransitionSystem
+from ..vocab cimport Vocab
 from ..tokens.doc cimport Doc
 from ..structs cimport TokenC
 from ._state cimport StateC
@@ -13,8 +14,8 @@ cdef class ParserModel(AveragedPerceptron):
     cdef void set_featuresC(self, ExampleC* eg, const StateC* state) nogil
 cdef class Parser:
+    cdef readonly Vocab vocab
     cdef readonly ParserModel model
     cdef readonly TransitionSystem moves
-    cdef int _projectivize
     cdef int parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) nogil

View File

@@ -78,34 +78,24 @@ cdef class ParserModel(AveragedPerceptron):
 cdef class Parser:
-    def __init__(self, StringStore strings, transition_system, ParserModel model, int projectivize = 0):
+    @classmethod
+    def load(cls, path, Vocab vocab, moves_class):
+        with (path / 'config.json').open() as file_:
+            cfg = json.loads(file_)
+        moves = moves_class(vocab.strings, cfg['labels'])
+        templates = get_templates(cfg['features'])
+        model = ParserModel(templates)
+        if (path / 'model').exists():
+            model.load(path / 'model')
+        return cls(vocab, moves, model, **cfg)
+    def __init__(self, Vocab vocab, transition_system, ParserModel model, **cfg):
         self.moves = transition_system
         self.model = model
-        self._projectivize = projectivize
+        self.cfg = cfg
-    @classmethod
-    def from_dir(cls, model_dir, strings, transition_system):
-        if not os.path.exists(model_dir):
-            print >> sys.stderr, "Warning: No model found at", model_dir
-        elif not os.path.isdir(model_dir):
-            print >> sys.stderr, "Warning: model path:", model_dir, "is not a directory"
-        cfg = Config.read(model_dir, 'config')
-        moves = transition_system(strings, cfg.labels)
-        templates = get_templates(cfg.features)
-        model = ParserModel(templates)
-        project = cfg.projectivize if hasattr(cfg,'projectivize') else False
-        if path.exists(path.join(model_dir, 'model')):
-            model.load(path.join(model_dir, 'model'))
-        return cls(strings, moves, model, project)
-    @classmethod
-    def load(cls, pkg_or_str_or_file, vocab):
-        # TODO
-        raise NotImplementedError(
-            "This should be here, but isn't yet =/. Use Parser.from_dir")
     def __reduce__(self):
-        return (Parser, (self.moves.strings, self.moves, self.model), None, None)
+        return (Parser, (self.vocab, self.moves, self.model), None, None)
     def __call__(self, Doc tokens):
         cdef int nr_class = self.moves.n_moves
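Parser.from_dir is folded into a path-based Parser.load that reads config.json, builds the transition system from the supplied moves class, and loads weights if a model file exists. A hedged sketch of the intended call; the directory layout and the ArcEager choice are assumptions, and since cfg = json.loads(file_) is still handed a file object rather than a string, this shows how the API is meant to be used rather than a working recipe:

    # Sketch only: intended path-based parser loading.
    import pathlib
    from spacy.vocab import Vocab
    from spacy.syntax.parser import Parser
    from spacy.syntax.arc_eager import ArcEager    # one TransitionSystem subclass

    model_dir = pathlib.Path('/path/to/model')     # illustrative layout
    vocab = Vocab.load(model_dir)
    parser = Parser.load(model_dir / 'deps', vocab, ArcEager)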

View File

@@ -18,8 +18,6 @@ from .parts_of_speech cimport VERB, X, PUNCT, EOL, SPACE
 from .attrs cimport *
-from .util import get_package
 cpdef enum:
     P2_orth
@@ -147,24 +145,21 @@ cdef class Tagger:
         return cls(vocab, model)
     @classmethod
-    def load(cls, data_dir, vocab):
-        return cls.from_package(get_package(data_dir), vocab=vocab)
-    @classmethod
-    def from_package(cls, pkg, vocab):
-        # TODO: templates.json deprecated? not present in latest package
-        # templates = cls.default_templates()
-        templates = pkg.load_json(('pos', 'templates.json'), default=cls.default_templates())
+    def load(cls, path, vocab):
+        if (path / 'pos' / 'templates.json').exists():
+            with (path / 'pos' / 'templates.json').open() as file_:
+                templates = json.load(file_)
+        else:
+            templates = cls.default_templates()
         model = TaggerModel(templates)
-        if pkg.has_file('pos', 'model'):
-            model.load(pkg.file_path('pos', 'model'))
+        if (path / 'pos' / 'model').exists():
+            model.load(path / 'pos' / 'model')
         return cls(vocab, model)
     def __init__(self, Vocab vocab, TaggerModel model):
         self.vocab = vocab
         self.model = model
         # TODO: Move this to tag map
         self.freqs = {TAG: defaultdict(int)}
         for tag in self.tag_names:

View File

@@ -1,4 +1,4 @@
-from spacy.util import align_tokens
+from spacy.deprecated import align_tokens
 def test_perfect_align():

View File

@@ -1,4 +1,4 @@
-from spacy.util import detokenize
+from spacy.deprecated import detokenize
 def test_punct():
     tokens = 'Pierre Vinken , 61 years old .'.split()

View File

@@ -16,7 +16,7 @@ def matcher():
         'GoogleNow': ['PRODUCT', {}, [[{'ORTH': 'Google'}, {'ORTH': 'Now'}]]],
         'Java': ['PRODUCT', {}, [[{'LOWER': 'java'}]]],
     }
-    return Matcher(Vocab(get_lex_attr=English.default_lex_attrs()), patterns)
+    return Matcher(Vocab(get_lex_attr=English.Defaults.lex_attr_getters), patterns)
 def test_compile(matcher):

View File

@@ -1,13 +1,20 @@
 # cython: embedsignature=True
 from __future__ import unicode_literals
-from os import path
 import re
+import pathlib
 from cython.operator cimport dereference as deref
 from cython.operator cimport preincrement as preinc
 from cpython cimport Py_UNICODE_ISSPACE
+try:
+    import ujson as json
+except ImportError:
+    import json
 from cymem.cymem cimport Pool
 from preshed.maps cimport PreshMap
@@ -16,17 +23,53 @@ cimport cython
 from . import util
 from .tokens.doc cimport Doc
-from .util import read_lang_data, get_package
 cdef class Tokenizer:
-    def __init__(self, Vocab vocab, rules, prefix_re, suffix_re, infix_re):
+    @classmethod
+    def load(cls, path, Vocab vocab, rules=None, prefix_search=None, suffix_search=None,
+             infix_finditer=None):
+        '''Load a Tokenizer, reading unsupplied components from the path.
+        Arguments:
+            path pathlib.Path (or string, or Path-like)
+            vocab Vocab
+            rules dict
+            prefix_search callable -- Signature of re.compile(string).search
+            suffix_search callable -- Signature of re.compile(string).search
+            infix_finditer callable -- Signature of re.compile(string).finditer
+        '''
+        if isinstance(path, basestring):
+            path = pathlib.Path(path)
+        if rules is None:
+            with (path / 'tokenizer' / 'specials.json').open() as file_:
+                rules = json.load(file_)
+        if prefix_search is None:
+            prefix_search = util.read_regex(path / 'tokenizer' / 'prefix.txt').search
+        if suffix_search is None:
+            suffix_search = util.read_regex(path / 'tokenizer' / 'suffix.txt').search
+        if infix_finditer is None:
+            infix_finditer = util.read_regex(path / 'tokenizer' / 'infix.txt').finditer
+        return cls(vocab, rules, prefix_search, suffix_search, infix_finditer)
+    def __init__(self, Vocab vocab, rules, prefix_search, suffix_search, infix_finditer):
+        '''Create a Tokenizer, to create Doc objects given unicode text.
+        Arguments:
+            vocab Vocab
+            rules dict
+            prefix_search callable -- Signature of re.compile(string).search
+            suffix_search callable -- Signature of re.compile(string).search
+            infix_finditer callable -- Signature of re.compile(string).finditer
+        '''
         self.mem = Pool()
         self._cache = PreshMap()
         self._specials = PreshMap()
-        self._prefix_re = prefix_re
-        self._suffix_re = suffix_re
-        self._infix_re = infix_re
+        self.prefix_search = prefix_search
+        self.suffix_search = suffix_search
+        self.infix_finditer = infix_finditer
         self.vocab = vocab
         self._rules = {}
         for chunk, substrings in sorted(rules.items()):
@@ -40,19 +83,7 @@ cdef class Tokenizer:
                 self._infix_re)
         return (self.__class__, args, None, None)
-    @classmethod
-    def load(cls, data_dir, Vocab vocab):
-        return cls.from_package(get_package(data_dir), vocab=vocab)
-    @classmethod
-    def from_package(cls, package, Vocab vocab):
-        rules, prefix_re, suffix_re, infix_re = read_lang_data(package)
-        prefix_re = re.compile(prefix_re)
-        suffix_re = re.compile(suffix_re)
-        infix_re = re.compile(infix_re)
-        return cls(vocab, rules, prefix_re, suffix_re, infix_re)
     cpdef Doc tokens_from_list(self, list strings):
         cdef Doc tokens = Doc(self.vocab)
         if sum([len(s) for s in strings]) == 0:
@@ -258,14 +289,14 @@ cdef class Tokenizer:
             self._cache.set(key, cached)
     def find_infix(self, unicode string):
-        return list(self._infix_re.finditer(string))
+        return list(self.infix_finditer(string))
     def find_prefix(self, unicode string):
-        match = self._prefix_re.search(string)
+        match = self.prefix_search(string)
         return (match.end() - match.start()) if match is not None else 0
     def find_suffix(self, unicode string):
-        match = self._suffix_re.search(string)
+        match = self.suffix_search(string)
         return (match.end() - match.start()) if match is not None else 0
     def _load_special_tokenization(self, special_cases):
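Tokenizer construction now takes match callables rather than compiled regex objects, and Tokenizer.load fills in whatever is not supplied from the tokenizer/ directory. A hedged sketch of building one directly from compiled patterns, bypassing the data path; the patterns are illustrative and the vocab is built the same way as in the matcher test above:

    # Sketch: construct a Tokenizer from explicit search/finditer callables.
    import re
    from spacy.en import English
    from spacy.vocab import Vocab
    from spacy.tokenizer import Tokenizer

    prefix_re = re.compile(r'''^["'\(\[]''')     # illustrative patterns
    suffix_re = re.compile(r'''["'\)\],.]$''')
    infix_re = re.compile(r'''[-~]''')

    vocab = Vocab(get_lex_attr=English.Defaults.lex_attr_getters)
    tokenizer = Tokenizer(vocab, rules={},
                          prefix_search=prefix_re.search,
                          suffix_search=suffix_re.search,
                          infix_finditer=infix_re.finditer)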

View File

@@ -5,12 +5,6 @@ import re
 import os.path
 import six
-import sputnik
-from sputnik.dir_package import DirPackage
-from sputnik.package_list import (PackageNotFoundException,
-                                  CompatiblePackageNotFoundException)
-from . import about
 from .attrs import TAG, HEAD, DEP, ENT_IOB, ENT_TYPE
@@ -29,30 +23,6 @@ def get_lang_class(name):
     return LANGUAGES[lang]
-def get_package(data_dir):
-    if not isinstance(data_dir, six.string_types):
-        raise RuntimeError('data_dir must be a string')
-    return DirPackage(data_dir)
-def get_package_by_name(name=None, via=None):
-    if name is None:
-        return
-    lang = get_lang_class(name)
-    try:
-        return sputnik.package(about.__title__, about.__version__,
-                               name, data_path=via)
-    except PackageNotFoundException as e:
-        raise RuntimeError("Model '%s' not installed. Please run 'python -m "
-                           "%s.download' to install latest compatible "
-                           "model." % (name, lang.__module__))
-    except CompatiblePackageNotFoundException as e:
-        raise RuntimeError("Installed model is not compatible with spaCy "
-                           "version. Please run 'python -m %s.download "
-                           "--force' to install latest compatible model." %
-                           (lang.__module__))
 def normalize_slice(length, start, stop, step=None):
     if not (step is None or step == 1):
         raise ValueError("Stepped slices not supported in Span objects."
@@ -75,100 +45,3 @@ def normalize_slice(length, start, stop, step=None):
 def utf8open(loc, mode='r'):
     return io.open(loc, mode, encoding='utf8')
-def read_lang_data(package):
-    tokenization = package.load_json(('tokenizer', 'specials.json'))
-    with package.open(('tokenizer', 'prefix.txt'), default=None) as file_:
-        prefix = read_prefix(file_) if file_ is not None else None
-    with package.open(('tokenizer', 'suffix.txt'), default=None) as file_:
-        suffix = read_suffix(file_) if file_ is not None else None
-    with package.open(('tokenizer', 'infix.txt'), default=None) as file_:
-        infix = read_infix(file_) if file_ is not None else None
-    return tokenization, prefix, suffix, infix
-def read_prefix(fileobj):
-    entries = fileobj.read().split('\n')
-    expression = '|'.join(['^' + re.escape(piece) for piece in entries if piece.strip()])
-    return expression
-def read_suffix(fileobj):
-    entries = fileobj.read().split('\n')
-    expression = '|'.join([piece + '$' for piece in entries if piece.strip()])
-    return expression
-def read_infix(fileobj):
-    entries = fileobj.read().split('\n')
-    expression = '|'.join([piece for piece in entries if piece.strip()])
-    return expression
-# def read_tokenization(lang):
-#     loc = path.join(DATA_DIR, lang, 'tokenization')
-#     entries = []
-#     seen = set()
-#     with utf8open(loc) as file_:
-#         for line in file_:
-#             line = line.strip()
-#             if line.startswith('#'):
-#                 continue
-#             if not line:
-#                 continue
-#             pieces = line.split()
-#             chunk = pieces.pop(0)
-#             assert chunk not in seen, chunk
-#             seen.add(chunk)
-#             entries.append((chunk, list(pieces)))
-#             if chunk[0].isalpha() and chunk[0].islower():
-#                 chunk = chunk[0].title() + chunk[1:]
-#                 pieces[0] = pieces[0][0].title() + pieces[0][1:]
-#                 seen.add(chunk)
-#                 entries.append((chunk, pieces))
-#     return entries
-# def read_detoken_rules(lang): # Deprecated?
-#     loc = path.join(DATA_DIR, lang, 'detokenize')
-#     entries = []
-#     with utf8open(loc) as file_:
-#         for line in file_:
-#             entries.append(line.strip())
-#     return entries
-def align_tokens(ref, indices): # Deprecated, surely?
-    start = 0
-    queue = list(indices)
-    for token in ref:
-        end = start + len(token)
-        emit = []
-        while queue and queue[0][1] <= end:
-            emit.append(queue.pop(0))
-        yield token, emit
-        start = end
-    assert not queue
-def detokenize(token_rules, words): # Deprecated?
-    """To align with treebanks, return a list of "chunks", where a chunk is a
-    sequence of tokens that are separated by whitespace in actual strings. Each
-    chunk should be a tuple of token indices, e.g.
-    >>> detokenize(["ca<SEP>n't", '<SEP>!'], ["I", "ca", "n't", "!"])
-    [(0,), (1, 2, 3)]
-    """
-    string = ' '.join(words)
-    for subtoks in token_rules:
-        # Algorithmically this is dumb, but writing a little list-based match
-        # machine? Ain't nobody got time for that.
-        string = string.replace(subtoks.replace('<SEP>', ' '), subtoks)
-    positions = []
-    i = 0
-    for chunk in string.split():
-        subtoks = chunk.split('<SEP>')
-        positions.append(tuple(range(i, i+len(subtoks))))
-        i += len(subtoks)
-    return positions

View File

@@ -19,7 +19,6 @@ from .orth cimport word_shape
 from .typedefs cimport attr_t
 from .cfile cimport CFile
 from .lemmatizer import Lemmatizer
-from .util import get_package
 from . import attrs
 from . import symbols
@@ -28,6 +27,7 @@ from cymem.cymem cimport Address
 from .serialize.packer cimport Packer
 from .attrs cimport PROB, LANG
 try:
     import copy_reg
 except ImportError:
@@ -47,30 +47,32 @@ cdef class Vocab:
     '''A map container for a language's LexemeC structs.
     '''
     @classmethod
-    def load(cls, data_dir, get_lex_attr=None):
-        return cls.from_package(get_package(data_dir), get_lex_attr=get_lex_attr)
-    @classmethod
-    def from_package(cls, package, get_lex_attr=None, vectors_package=None):
-        tag_map = package.load_json(('vocab', 'tag_map.json'), default={})
-        lemmatizer = Lemmatizer.from_package(package)
-        serializer_freqs = package.load_json(('vocab', 'serializer.json'), default={})
+    def load(cls, path, get_lex_attr=None, vectors=True, lemmatizer=None):
+        if (path / 'vocab' / 'tag_map.json').exists():
+            with (path / 'vocab' / 'tag_map.json').open() as file_:
+                tag_map = json.loads(file_)
+        else:
+            tag_map = {}
+        if lemmatizer is None:
+            lemmatizer = Lemmatizer.load(path)
+        if (path / 'vocab' / 'serializer.json').exists():
+            with (path / 'vocab' / 'serializer.json').open() as file_:
+                serializer_freqs = json.loads(file_)
+        else:
+            serializer_freqs = {}
         cdef Vocab self = cls(get_lex_attr=get_lex_attr, tag_map=tag_map,
                               lemmatizer=lemmatizer, serializer_freqs=serializer_freqs)
-        with package.open(('vocab', 'strings.json')) as file_:
+        with (path / 'vocab' / 'strings.json').open() as file_:
             self.strings.load(file_)
-        self.load_lexemes(package.file_path('vocab', 'lexemes.bin'))
-        if vectors_package and vectors_package.has_file('vocab', 'vec.bin'):
-            self.vectors_length = self.load_vectors_from_bin_loc(
-                vectors_package.file_path('vocab', 'vec.bin'))
-        elif package.has_file('vocab', 'vec.bin'):
-            self.vectors_length = self.load_vectors_from_bin_loc(
-                package.file_path('vocab', 'vec.bin'))
+        self.load_lexemes(path / 'vocab' / 'lexemes.bin')
+        if vectors is True:
+            vectors = lambda self_: self_.load_vectors_from_bin_loc(path / 'vocab' / 'vec.bin')
+        self.vectors_length = vectors(self)
         return self
     def __init__(self, get_lex_attr=None, tag_map=None, lemmatizer=None, serializer_freqs=None):
@@ -87,6 +89,9 @@ cdef class Vocab:
         # is the frequency rank of the word, plus a certain offset. The structural
         # strings are loaded first, because the vocab is open-class, and these
         # symbols are closed class.
+        # TODO: Actually this has turned out to be a pain in the ass...
+        # It means the data is invalidated when we add a symbol :(
+        # Need to rethink this.
         for name in symbols.NAMES + list(sorted(tag_map.keys())):
             if name:
                 _ = self.strings[name]
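Vocab.load now reads everything from a directory path: tag_map.json and serializer.json if present, the lemmatizer unless one is passed in, and vec.bin when vectors is left as True; passing a callable lets the caller decide how vectors are loaded. A hedged sketch of the intended call (the path is illustrative, and as the commit message says, the refactored loaders are not expected to run yet):

    # Sketch only: intended Vocab.load call under the new path-based layout.
    import pathlib
    from spacy.en import English
    from spacy.vocab import Vocab

    data_dir = pathlib.Path('/path/to/model')           # illustrative
    vocab = Vocab.load(data_dir,
                       get_lex_attr=English.Defaults.lex_attr_getters,
                       vectors=lambda self_: 0)         # skip vec.bin; vectors_length = 0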