* Fix oov probability

This commit is contained in:
Matthew Honnibal 2016-02-06 15:13:55 +01:00
parent af8514cb0c
commit a95974ad3f
2 changed files with 21 additions and 21 deletions

View File

@ -46,10 +46,6 @@ class Language(object):
def suffix(string):
return string[-3:]
@staticmethod
def prob(string):
return -30
@staticmethod
def cluster(string):
return 0
@ -119,7 +115,8 @@ class Language(object):
return 0
@classmethod
def default_lex_attrs(cls):
def default_lex_attrs(cls, *args, **kwargs):
oov_prob = kwargs.get('oov_prob', -20)
return {
attrs.LOWER: cls.lower,
attrs.NORM: cls.norm,
@ -127,8 +124,7 @@ class Language(object):
attrs.PREFIX: cls.prefix,
attrs.SUFFIX: cls.suffix,
attrs.CLUSTER: cls.cluster,
attrs.PROB: lambda string: -10.0,
attrs.PROB: lambda string: oov_prob,
attrs.IS_ALPHA: cls.is_alpha,
attrs.IS_ASCII: cls.is_ascii,
attrs.IS_DIGIT: cls.is_digit,
@ -159,7 +155,12 @@ class Language(object):
@classmethod
def default_vocab(cls, package, get_lex_attr=None):
if get_lex_attr is None:
get_lex_attr = cls.default_lex_attrs()
if package.has_file('vocab', 'oov_prob'):
with package.open(('vocab', 'oov_prob')) as file_:
oov_prob = float(file_.read().strip())
get_lex_attr = cls.default_lex_attrs(oov_prob=oov_prob)
else:
get_lex_attr = cls.default_lex_attrs()
if hasattr(package, 'dir_path'):
return Vocab.from_package(package, get_lex_attr=get_lex_attr)
else:

View File

@ -110,21 +110,20 @@ cdef class Vocab:
# TODO: This is hopelessly broken. The state is transferred as just
# a temp directory! We then fail to clean this up. This method therefore
# only pretends to work. What we need to do is form an archive file.
raise NotImplementedError
#tmp_dir = tempfile.mkdtemp()
#lex_loc = path.join(tmp_dir, 'lexemes.bin')
#str_loc = path.join(tmp_dir, 'strings.json')
#vec_loc = path.join(tmp_dir, 'vec.bin')
tmp_dir = tempfile.mkdtemp()
lex_loc = path.join(tmp_dir, 'lexemes.bin')
str_loc = path.join(tmp_dir, 'strings.json')
vec_loc = path.join(tmp_dir, 'vec.bin')
#self.dump(lex_loc)
#with io.open(str_loc, 'w', encoding='utf8') as file_:
# self.strings.dump(file_)
self.dump(lex_loc)
with io.open(str_loc, 'w', encoding='utf8') as file_:
self.strings.dump(file_)
#self.dump_vectors(vec_loc)
#
#state = (str_loc, lex_loc, vec_loc, self.morphology, self.get_lex_attr,
# self.serializer_freqs, self.data_dir)
#return (unpickle_vocab, state, None, None)
self.dump_vectors(vec_loc)
state = (str_loc, lex_loc, vec_loc, self.morphology, self.get_lex_attr,
self.serializer_freqs, self.data_dir)
return (unpickle_vocab, state, None, None)
cdef const LexemeC* get(self, Pool mem, unicode string) except NULL:
'''Get a pointer to a LexemeC from the lexicon, creating a new Lexeme