mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Merge pull request #1279 from oroszgy/model_cli_v2
Added vector loading to model cli
This commit is contained in:
commit
876f38c548
|
@ -1,33 +1,51 @@
|
||||||
# coding: utf8
|
# coding: utf8
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
import bz2
|
||||||
import gzip
|
import gzip
|
||||||
import math
|
import math
|
||||||
from ast import literal_eval
|
from ast import literal_eval
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import spacy
|
||||||
from preshed.counter import PreshCounter
|
from preshed.counter import PreshCounter
|
||||||
|
|
||||||
import spacy
|
|
||||||
from ..compat import fix_text
|
|
||||||
from .. import util
|
from .. import util
|
||||||
|
from ..compat import fix_text
|
||||||
|
|
||||||
|
|
||||||
def model(cmd, lang, model_dir, freqs_data, clusters_data, vectors_data):
|
def model(cmd, lang, model_dir, freqs_data, clusters_data, vectors_data,
|
||||||
|
min_doc_freq=5, min_word_freq=200):
|
||||||
model_path = Path(model_dir)
|
model_path = Path(model_dir)
|
||||||
freqs_path = Path(freqs_data)
|
freqs_path = Path(freqs_data)
|
||||||
clusters_path = Path(clusters_data) if clusters_data else None
|
clusters_path = Path(clusters_data) if clusters_data else None
|
||||||
vectors_path = Path(vectors_data) if vectors_data else None
|
vectors_path = Path(vectors_data) if vectors_data else None
|
||||||
|
|
||||||
check_dirs(freqs_path, clusters_path, vectors_path)
|
check_dirs(freqs_path, clusters_path, vectors_path)
|
||||||
# vocab = util.get_lang_class(lang).Defaults.create_vocab()
|
vocab = util.get_lang_class(lang).Defaults.create_vocab()
|
||||||
nlp = spacy.blank(lang)
|
nlp = spacy.blank(lang)
|
||||||
vocab = nlp.vocab
|
vocab = nlp.vocab
|
||||||
probs, oov_prob = read_probs(freqs_path)
|
probs, oov_prob = read_probs(
|
||||||
|
freqs_path, min_doc_freq=int(min_doc_freq), min_freq=int(min_doc_freq))
|
||||||
clusters = read_clusters(clusters_path) if clusters_path else {}
|
clusters = read_clusters(clusters_path) if clusters_path else {}
|
||||||
populate_vocab(vocab, clusters, probs, oov_prob)
|
populate_vocab(vocab, clusters, probs, oov_prob)
|
||||||
|
add_vectors(vocab, vectors_path)
|
||||||
create_model(model_path, nlp)
|
create_model(model_path, nlp)
|
||||||
|
|
||||||
|
|
||||||
|
def add_vectors(vocab, vectors_path):
|
||||||
|
with bz2.BZ2File(vectors_path.as_posix()) as f:
|
||||||
|
num_words, dim = next(f).split()
|
||||||
|
vocab.clear_vectors(int(dim))
|
||||||
|
for line in f:
|
||||||
|
word_w_vector = line.decode("utf8").strip().split(" ")
|
||||||
|
word = word_w_vector[0]
|
||||||
|
vector = np.array([float(val) for val in word_w_vector[1:]])
|
||||||
|
if word in vocab:
|
||||||
|
vocab.set_vector(word, vector)
|
||||||
|
|
||||||
|
|
||||||
def create_model(model_path, model):
|
def create_model(model_path, model):
|
||||||
if not model_path.exists():
|
if not model_path.exists():
|
||||||
model_path.mkdir()
|
model_path.mkdir()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user