From a95974ad3fb942a5ce8333f59d6cbc4969e9dbc5 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sat, 6 Feb 2016 15:13:55 +0100 Subject: [PATCH] * Fix oov probability --- spacy/language.py | 17 +++++++++-------- spacy/vocab.pyx | 25 ++++++++++++------------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/spacy/language.py b/spacy/language.py index e85854735..ae8aa4560 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -46,10 +46,6 @@ class Language(object): def suffix(string): return string[-3:] - @staticmethod - def prob(string): - return -30 - @staticmethod def cluster(string): return 0 @@ -119,7 +115,8 @@ class Language(object): return 0 @classmethod - def default_lex_attrs(cls): + def default_lex_attrs(cls, *args, **kwargs): + oov_prob = kwargs.get('oov_prob', -20) return { attrs.LOWER: cls.lower, attrs.NORM: cls.norm, @@ -127,8 +124,7 @@ class Language(object): attrs.PREFIX: cls.prefix, attrs.SUFFIX: cls.suffix, attrs.CLUSTER: cls.cluster, - attrs.PROB: lambda string: -10.0, - + attrs.PROB: lambda string: oov_prob, attrs.IS_ALPHA: cls.is_alpha, attrs.IS_ASCII: cls.is_ascii, attrs.IS_DIGIT: cls.is_digit, @@ -159,7 +155,12 @@ class Language(object): @classmethod def default_vocab(cls, package, get_lex_attr=None): if get_lex_attr is None: - get_lex_attr = cls.default_lex_attrs() + if package.has_file('vocab', 'oov_prob'): + with package.open(('vocab', 'oov_prob')) as file_: + oov_prob = float(file_.read().strip()) + get_lex_attr = cls.default_lex_attrs(oov_prob=oov_prob) + else: + get_lex_attr = cls.default_lex_attrs() if hasattr(package, 'dir_path'): return Vocab.from_package(package, get_lex_attr=get_lex_attr) else: diff --git a/spacy/vocab.pyx b/spacy/vocab.pyx index 5946d8169..a0a07f305 100644 --- a/spacy/vocab.pyx +++ b/spacy/vocab.pyx @@ -110,21 +110,20 @@ cdef class Vocab: # TODO: This is hopelessly broken. The state is transferred as just # a temp directory! We then fail to clean this up. This method therefore # only pretends to work. What we need to do is form an archive file. - raise NotImplementedError - #tmp_dir = tempfile.mkdtemp() - #lex_loc = path.join(tmp_dir, 'lexemes.bin') - #str_loc = path.join(tmp_dir, 'strings.json') - #vec_loc = path.join(tmp_dir, 'vec.bin') + tmp_dir = tempfile.mkdtemp() + lex_loc = path.join(tmp_dir, 'lexemes.bin') + str_loc = path.join(tmp_dir, 'strings.json') + vec_loc = path.join(tmp_dir, 'vec.bin') - #self.dump(lex_loc) - #with io.open(str_loc, 'w', encoding='utf8') as file_: - # self.strings.dump(file_) + self.dump(lex_loc) + with io.open(str_loc, 'w', encoding='utf8') as file_: + self.strings.dump(file_) - #self.dump_vectors(vec_loc) - # - #state = (str_loc, lex_loc, vec_loc, self.morphology, self.get_lex_attr, - # self.serializer_freqs, self.data_dir) - #return (unpickle_vocab, state, None, None) + self.dump_vectors(vec_loc) + + state = (str_loc, lex_loc, vec_loc, self.morphology, self.get_lex_attr, + self.serializer_freqs, self.data_dir) + return (unpickle_vocab, state, None, None) cdef const LexemeC* get(self, Pool mem, unicode string) except NULL: '''Get a pointer to a LexemeC from the lexicon, creating a new Lexeme