mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
* Pass OOV probability around
This commit is contained in:
parent
5b6bf4d4a6
commit
fd525f0675
|
@ -110,8 +110,8 @@ def _read_freqs(loc):
|
|||
smooth_count = counts.smoother(int(freq))
|
||||
log_smooth_count = math.log(smooth_count)
|
||||
probs[word] = math.log(smooth_count) - log_total
|
||||
probs['-OOV-'] = math.log(counts.smoother(0)) - log_total
|
||||
return probs
|
||||
oov_prob = math.log(counts.smoother(0)) - log_total
|
||||
return probs, oov_prob
|
||||
|
||||
|
||||
def _read_senses(loc):
|
||||
|
@ -144,29 +144,30 @@ def setup_vocab(src_dir, dst_dir):
|
|||
print("Warning: Word vectors file not found")
|
||||
vocab = Vocab(data_dir=None, get_lex_props=get_lex_props)
|
||||
clusters = _read_clusters(src_dir / 'clusters.txt')
|
||||
probs = _read_probs(src_dir / 'words.sgt.prob')
|
||||
probs, oov_prob = _read_probs(src_dir / 'words.sgt.prob')
|
||||
if not probs:
|
||||
probs = _read_freqs(src_dir / 'freqs.txt')
|
||||
probs, oov_prob = _read_freqs(src_dir / 'freqs.txt')
|
||||
if not probs:
|
||||
min_prob = 0.0
|
||||
oov_prob = 0.0
|
||||
else:
|
||||
min_prob = min(probs.values())
|
||||
oov_prob = min(probs.values())
|
||||
for word in clusters:
|
||||
if word not in probs:
|
||||
probs[word] = min_prob
|
||||
probs[word] = oov_prob
|
||||
|
||||
lexicon = []
|
||||
for word, prob in reversed(sorted(list(probs.items()), key=lambda item: item[1])):
|
||||
entry = get_lex_props(word)
|
||||
if word in clusters:
|
||||
entry['prob'] = float(prob)
|
||||
cluster = clusters.get(word, '0')
|
||||
# Decode as a little-endian string, so that we can do & 15 to get
|
||||
# the first 4 bits. See _parse_features.pyx
|
||||
entry['cluster'] = int(cluster[::-1], 2)
|
||||
vocab[word] = entry
|
||||
entry['prob'] = float(prob)
|
||||
cluster = clusters.get(word, '0')
|
||||
# Decode as a little-endian string, so that we can do & 15 to get
|
||||
# the first 4 bits. See _parse_features.pyx
|
||||
entry['cluster'] = int(cluster[::-1], 2)
|
||||
vocab[word] = entry
|
||||
vocab.dump(str(dst_dir / 'lexemes.bin'))
|
||||
vocab.strings.dump(str(dst_dir / 'strings.txt'))
|
||||
with (dst_dir / 'oov_prob').open('w') as file_:
|
||||
file_.write('%f' % oov_prob)
|
||||
|
||||
|
||||
def main(lang_data_dir, corpora_dir, model_dir):
|
||||
|
|
|
@ -31,6 +31,7 @@ cdef class Vocab:
|
|||
cdef readonly int length
|
||||
cdef public object _serializer
|
||||
cdef public object data_dir
|
||||
cdef public float oov_prob
|
||||
|
||||
cdef const LexemeC* get(self, Pool mem, unicode string) except NULL
|
||||
cdef const LexemeC* get_by_orth(self, Pool mem, attr_t orth) except NULL
|
||||
|
|
|
@ -37,7 +37,7 @@ cdef class Vocab:
|
|||
'''A map container for a language's LexemeC structs.
|
||||
'''
|
||||
def __init__(self, data_dir=None, get_lex_props=None, load_vectors=True,
|
||||
pos_tags=None):
|
||||
pos_tags=None, oov_prob=-30):
|
||||
self.mem = Pool()
|
||||
self._by_hash = PreshMap()
|
||||
self._by_orth = PreshMap()
|
||||
|
@ -61,6 +61,7 @@ cdef class Vocab:
|
|||
|
||||
self._serializer = None
|
||||
self.data_dir = data_dir
|
||||
self.oov_prob = oov_prob
|
||||
|
||||
property serializer:
|
||||
def __get__(self):
|
||||
|
@ -90,7 +91,7 @@ cdef class Vocab:
|
|||
if len(string) < 3:
|
||||
mem = self.mem
|
||||
lex = <LexemeC*>mem.alloc(sizeof(LexemeC), 1)
|
||||
props = self.lexeme_props_getter(string)
|
||||
props = self.lexeme_props_getter(string, self.oov_prob)
|
||||
set_lex_struct_props(lex, props, self.strings, EMPTY_VEC)
|
||||
if is_oov:
|
||||
lex.id = 0
|
||||
|
|
Loading…
Reference in New Issue
Block a user