def ensure_shape(lines):
    """Ensure that the first line of the data is the vectors shape.

    If the first line consists solely of integer fields it is taken to be the
    shape header and the input is passed through unchanged. Otherwise the
    first line is treated as a data row: all rows are read into memory, a
    synthetic "<rows> <width>" header is yielded first, and then the rows
    themselves, so that the reader doesn't have to deal with the problem.

    lines: Any iterable of text lines (e.g. an open file or a list of
        strings). Each data row is expected to be a key followed by
        whitespace-separated values.
    YIELDS (str): The shape header followed by every data row. A synthesized
        header carries no trailing newline; rows are yielded exactly as
        received. Nothing is yielded for empty input.
    """
    # Accept any iterable, not only iterators, so that next() is always valid.
    lines = iter(lines)
    try:
        first_line = next(lines)
    except StopIteration:
        # Empty input. A bare next() here would leak StopIteration out of the
        # generator body, which Python turns into a RuntimeError (PEP 479).
        return
    try:
        shape = tuple(int(size) for size in first_line.split())
    except ValueError:
        # At least one field is not an integer, so this is a data row.
        shape = None
    if shape is not None:
        # All good, give the data
        yield first_line
        yield from lines
    else:
        # Figure out the shape, make it the first value, and then give the
        # rest of the data. The first field of a row is the key, hence the -1.
        width = len(first_line.split()) - 1
        # The row count is only known once everything has been read, so the
        # remaining rows have to be buffered before the header can be emitted.
        captured = [first_line, *lines]
        yield f"{len(captured)} {width}"
        yield from captured