Fix reading in GloVe vectors

This commit is contained in:
Matthew Honnibal 2020-09-12 17:31:18 +02:00
parent b41be87213
commit 37347830d4

View File

@ -256,6 +256,7 @@ def add_vectors(
def read_vectors(msg: Printer, vectors_loc: Path, truncate_vectors: int): def read_vectors(msg: Printer, vectors_loc: Path, truncate_vectors: int):
f = open_file(vectors_loc) f = open_file(vectors_loc)
f = ensure_shape(f)
shape = tuple(int(size) for size in next(f).split()) shape = tuple(int(size) for size in next(f).split())
if truncate_vectors >= 1: if truncate_vectors >= 1:
shape = (truncate_vectors, shape[1]) shape = (truncate_vectors, shape[1])
@ -274,6 +275,31 @@ def read_vectors(msg: Printer, vectors_loc: Path, truncate_vectors: int):
return vectors_data, vectors_keys return vectors_data, vectors_keys
def ensure_shape(lines):
"""Ensure that the first line of the data is the vectors shape.
If it's not, we read in the data and output the shape as the first result,
so that the reader doesn't have to deal with the problem.
"""
first_line = next(lines)
try:
shape = tuple(int(size) for size in first_line.split())
except ValueError:
shape = None
if shape is not None:
# All good, give the data
yield first_line
yield from lines
else:
# Figure out the shape, make it the first value, and then give the
# rest of the data.
width = len(first_line.split()) - 1
captured = [first_line] + list(lines)
length = len(captured)
yield f"{length} {width}"
yield from captured
def read_freqs( def read_freqs(
freqs_loc: Path, max_length: int = 100, min_doc_freq: int = 5, min_freq: int = 50 freqs_loc: Path, max_length: int = 100, min_doc_freq: int = 5, min_freq: int = 50
): ):