mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
* Fix oov probability
This commit is contained in:
parent
af8514cb0c
commit
a95974ad3f
|
@ -46,10 +46,6 @@ class Language(object):
|
|||
def suffix(string):
|
||||
return string[-3:]
|
||||
|
||||
@staticmethod
|
||||
def prob(string):
|
||||
return -30
|
||||
|
||||
@staticmethod
|
||||
def cluster(string):
|
||||
return 0
|
||||
|
@ -119,7 +115,8 @@ class Language(object):
|
|||
return 0
|
||||
|
||||
@classmethod
|
||||
def default_lex_attrs(cls):
|
||||
def default_lex_attrs(cls, *args, **kwargs):
|
||||
oov_prob = kwargs.get('oov_prob', -20)
|
||||
return {
|
||||
attrs.LOWER: cls.lower,
|
||||
attrs.NORM: cls.norm,
|
||||
|
@ -127,8 +124,7 @@ class Language(object):
|
|||
attrs.PREFIX: cls.prefix,
|
||||
attrs.SUFFIX: cls.suffix,
|
||||
attrs.CLUSTER: cls.cluster,
|
||||
attrs.PROB: lambda string: -10.0,
|
||||
|
||||
attrs.PROB: lambda string: oov_prob,
|
||||
attrs.IS_ALPHA: cls.is_alpha,
|
||||
attrs.IS_ASCII: cls.is_ascii,
|
||||
attrs.IS_DIGIT: cls.is_digit,
|
||||
|
@ -159,7 +155,12 @@ class Language(object):
|
|||
@classmethod
|
||||
def default_vocab(cls, package, get_lex_attr=None):
|
||||
if get_lex_attr is None:
|
||||
get_lex_attr = cls.default_lex_attrs()
|
||||
if package.has_file('vocab', 'oov_prob'):
|
||||
with package.open(('vocab', 'oov_prob')) as file_:
|
||||
oov_prob = float(file_.read().strip())
|
||||
get_lex_attr = cls.default_lex_attrs(oov_prob=oov_prob)
|
||||
else:
|
||||
get_lex_attr = cls.default_lex_attrs()
|
||||
if hasattr(package, 'dir_path'):
|
||||
return Vocab.from_package(package, get_lex_attr=get_lex_attr)
|
||||
else:
|
||||
|
|
|
@ -110,21 +110,20 @@ cdef class Vocab:
|
|||
# TODO: This is hopelessly broken. The state is transferred as just
|
||||
# a temp directory! We then fail to clean this up. This method therefore
|
||||
# only pretends to work. What we need to do is form an archive file.
|
||||
raise NotImplementedError
|
||||
#tmp_dir = tempfile.mkdtemp()
|
||||
#lex_loc = path.join(tmp_dir, 'lexemes.bin')
|
||||
#str_loc = path.join(tmp_dir, 'strings.json')
|
||||
#vec_loc = path.join(tmp_dir, 'vec.bin')
|
||||
tmp_dir = tempfile.mkdtemp()
|
||||
lex_loc = path.join(tmp_dir, 'lexemes.bin')
|
||||
str_loc = path.join(tmp_dir, 'strings.json')
|
||||
vec_loc = path.join(tmp_dir, 'vec.bin')
|
||||
|
||||
#self.dump(lex_loc)
|
||||
#with io.open(str_loc, 'w', encoding='utf8') as file_:
|
||||
# self.strings.dump(file_)
|
||||
self.dump(lex_loc)
|
||||
with io.open(str_loc, 'w', encoding='utf8') as file_:
|
||||
self.strings.dump(file_)
|
||||
|
||||
#self.dump_vectors(vec_loc)
|
||||
#
|
||||
#state = (str_loc, lex_loc, vec_loc, self.morphology, self.get_lex_attr,
|
||||
# self.serializer_freqs, self.data_dir)
|
||||
#return (unpickle_vocab, state, None, None)
|
||||
self.dump_vectors(vec_loc)
|
||||
|
||||
state = (str_loc, lex_loc, vec_loc, self.morphology, self.get_lex_attr,
|
||||
self.serializer_freqs, self.data_dir)
|
||||
return (unpickle_vocab, state, None, None)
|
||||
|
||||
cdef const LexemeC* get(self, Pool mem, unicode string) except NULL:
|
||||
'''Get a pointer to a LexemeC from the lexicon, creating a new Lexeme
|
||||
|
|
Loading…
Reference in New Issue
Block a user