* Fix OOV (out-of-vocabulary) probability

This commit is contained in:
Matthew Honnibal 2016-02-06 15:13:55 +01:00
parent af8514cb0c
commit a95974ad3f
2 changed files with 21 additions and 21 deletions

View File

@ -46,10 +46,6 @@ class Language(object):
def suffix(string): def suffix(string):
return string[-3:] return string[-3:]
@staticmethod
def prob(string):
return -30
@staticmethod @staticmethod
def cluster(string): def cluster(string):
return 0 return 0
@ -119,7 +115,8 @@ class Language(object):
return 0 return 0
@classmethod @classmethod
def default_lex_attrs(cls): def default_lex_attrs(cls, *args, **kwargs):
oov_prob = kwargs.get('oov_prob', -20)
return { return {
attrs.LOWER: cls.lower, attrs.LOWER: cls.lower,
attrs.NORM: cls.norm, attrs.NORM: cls.norm,
@ -127,8 +124,7 @@ class Language(object):
attrs.PREFIX: cls.prefix, attrs.PREFIX: cls.prefix,
attrs.SUFFIX: cls.suffix, attrs.SUFFIX: cls.suffix,
attrs.CLUSTER: cls.cluster, attrs.CLUSTER: cls.cluster,
attrs.PROB: lambda string: -10.0, attrs.PROB: lambda string: oov_prob,
attrs.IS_ALPHA: cls.is_alpha, attrs.IS_ALPHA: cls.is_alpha,
attrs.IS_ASCII: cls.is_ascii, attrs.IS_ASCII: cls.is_ascii,
attrs.IS_DIGIT: cls.is_digit, attrs.IS_DIGIT: cls.is_digit,
@ -159,6 +155,11 @@ class Language(object):
@classmethod @classmethod
def default_vocab(cls, package, get_lex_attr=None): def default_vocab(cls, package, get_lex_attr=None):
if get_lex_attr is None: if get_lex_attr is None:
if package.has_file('vocab', 'oov_prob'):
with package.open(('vocab', 'oov_prob')) as file_:
oov_prob = float(file_.read().strip())
get_lex_attr = cls.default_lex_attrs(oov_prob=oov_prob)
else:
get_lex_attr = cls.default_lex_attrs() get_lex_attr = cls.default_lex_attrs()
if hasattr(package, 'dir_path'): if hasattr(package, 'dir_path'):
return Vocab.from_package(package, get_lex_attr=get_lex_attr) return Vocab.from_package(package, get_lex_attr=get_lex_attr)

View File

@ -110,21 +110,20 @@ cdef class Vocab:
# TODO: This is hopelessly broken. The state is transferred as just # TODO: This is hopelessly broken. The state is transferred as just
# a temp directory! We then fail to clean this up. This method therefore # a temp directory! We then fail to clean this up. This method therefore
# only pretends to work. What we need to do is form an archive file. # only pretends to work. What we need to do is form an archive file.
raise NotImplementedError tmp_dir = tempfile.mkdtemp()
#tmp_dir = tempfile.mkdtemp() lex_loc = path.join(tmp_dir, 'lexemes.bin')
#lex_loc = path.join(tmp_dir, 'lexemes.bin') str_loc = path.join(tmp_dir, 'strings.json')
#str_loc = path.join(tmp_dir, 'strings.json') vec_loc = path.join(tmp_dir, 'vec.bin')
#vec_loc = path.join(tmp_dir, 'vec.bin')
#self.dump(lex_loc) self.dump(lex_loc)
#with io.open(str_loc, 'w', encoding='utf8') as file_: with io.open(str_loc, 'w', encoding='utf8') as file_:
# self.strings.dump(file_) self.strings.dump(file_)
#self.dump_vectors(vec_loc) self.dump_vectors(vec_loc)
#
#state = (str_loc, lex_loc, vec_loc, self.morphology, self.get_lex_attr, state = (str_loc, lex_loc, vec_loc, self.morphology, self.get_lex_attr,
# self.serializer_freqs, self.data_dir) self.serializer_freqs, self.data_dir)
#return (unpickle_vocab, state, None, None) return (unpickle_vocab, state, None, None)
cdef const LexemeC* get(self, Pool mem, unicode string) except NULL: cdef const LexemeC* get(self, Pool mem, unicode string) except NULL:
'''Get a pointer to a LexemeC from the lexicon, creating a new Lexeme '''Get a pointer to a LexemeC from the lexicon, creating a new Lexeme