mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
Merge branch 'develop' of https://github.com/explosion/spaCy into develop
This commit is contained in:
commit
273e96b63f
|
@ -94,7 +94,7 @@ def main(model=None, output_dir=None, n_iter=100):
|
|||
|
||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != 'parser']
|
||||
with nlp.disable_pipes(*other_pipes): # only train parser
|
||||
optimizer = nlp.begin_training(lambda: [])
|
||||
optimizer = nlp.begin_training()
|
||||
for itn in range(n_iter):
|
||||
random.shuffle(TRAIN_DATA)
|
||||
losses = {}
|
||||
|
|
|
@ -87,7 +87,7 @@ def main(model=None, new_model_name='animal', output_dir=None, n_iter=50):
|
|||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != 'ner']
|
||||
with nlp.disable_pipes(*other_pipes): # only train NER
|
||||
random.seed(0)
|
||||
optimizer = nlp.begin_training(lambda: [])
|
||||
optimizer = nlp.begin_training()
|
||||
for itn in range(n_iter):
|
||||
losses = {}
|
||||
gold_parses = get_gold_parses(nlp.make_doc, TRAIN_DATA)
|
||||
|
|
|
@ -64,7 +64,7 @@ def main(model=None, output_dir=None, n_iter=1000):
|
|||
# get names of other pipes to disable them during training
|
||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != 'parser']
|
||||
with nlp.disable_pipes(*other_pipes): # only train parser
|
||||
optimizer = nlp.begin_training(lambda: [])
|
||||
optimizer = nlp.begin_training()
|
||||
for itn in range(n_iter):
|
||||
random.shuffle(TRAIN_DATA)
|
||||
losses = {}
|
||||
|
|
|
@ -61,7 +61,7 @@ def main(lang='en', output_dir=None, n_iter=25):
|
|||
tagger = nlp.create_pipe('tagger')
|
||||
nlp.add_pipe(tagger)
|
||||
|
||||
optimizer = nlp.begin_training(lambda: [])
|
||||
optimizer = nlp.begin_training()
|
||||
for i in range(n_iter):
|
||||
random.shuffle(TRAIN_DATA)
|
||||
losses = {}
|
||||
|
|
|
@ -59,7 +59,7 @@ def main(model=None, output_dir=None, n_iter=20):
|
|||
# get names of other pipes to disable them during training
|
||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != 'textcat']
|
||||
with nlp.disable_pipes(*other_pipes): # only train textcat
|
||||
optimizer = nlp.begin_training(lambda: [])
|
||||
optimizer = nlp.begin_training()
|
||||
print("Training the model...")
|
||||
print('{:^5}\t{:^5}\t{:^5}\t{:^5}'.format('LOSS', 'P', 'R', 'F'))
|
||||
for i in range(n_iter):
|
||||
|
|
|
@ -6,7 +6,7 @@ from __future__ import print_function
|
|||
if __name__ == '__main__':
|
||||
import plac
|
||||
import sys
|
||||
from spacy.cli import download, link, info, package, train, convert, model
|
||||
from spacy.cli import download, link, info, package, train, convert
|
||||
from spacy.cli import vocab, profile, evaluate, validate
|
||||
from spacy.util import prints
|
||||
|
||||
|
@ -18,7 +18,6 @@ if __name__ == '__main__':
|
|||
'evaluate': evaluate,
|
||||
'convert': convert,
|
||||
'package': package,
|
||||
'model': model,
|
||||
'vocab': vocab,
|
||||
'profile': profile,
|
||||
'validate': validate
|
||||
|
|
|
@ -6,6 +6,5 @@ from .profile import profile
|
|||
from .train import train
|
||||
from .evaluate import evaluate
|
||||
from .convert import convert
|
||||
from .model import model
|
||||
from .vocab import make_vocab as vocab
|
||||
from .validate import validate
|
||||
|
|
|
@ -1,140 +0,0 @@
|
|||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
try:
|
||||
import bz2
|
||||
import gzip
|
||||
except ImportError:
|
||||
pass
|
||||
import math
|
||||
from ast import literal_eval
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import spacy
|
||||
from preshed.counter import PreshCounter
|
||||
|
||||
from .. import util
|
||||
from ..compat import fix_text
|
||||
|
||||
|
||||
def model(cmd, lang, model_dir, freqs_data, clusters_data, vectors_data,
          min_doc_freq=5, min_word_freq=200):
    """Build a model directory from frequency, cluster and vector files.

    cmd: Not used by this function.
    lang (unicode): Language code passed to `spacy.blank`.
    model_dir (unicode): Output directory for the saved model.
    freqs_data (unicode): Path to the frequencies file (required).
    clusters_data (unicode): Path to the Brown clusters file, or a falsy
        value to skip clusters.
    vectors_data (unicode): Path to the word vectors file, or a falsy value
        to skip vectors.
    min_doc_freq (int): Minimum document frequency for a word to be kept.
    min_word_freq (int): Minimum word frequency for a word to be kept.
    """
    model_path = Path(model_dir)
    freqs_path = Path(freqs_data)
    clusters_path = Path(clusters_data) if clusters_data else None
    vectors_path = Path(vectors_data) if vectors_data else None

    check_dirs(freqs_path, clusters_path, vectors_path)
    nlp = spacy.blank(lang)
    vocab = nlp.vocab
    # The word-frequency threshold must come from min_word_freq; passing
    # min_doc_freq for both thresholds left min_word_freq without effect.
    probs, oov_prob = read_probs(
        freqs_path, min_doc_freq=int(min_doc_freq),
        min_freq=int(min_word_freq))
    clusters = read_clusters(clusters_path) if clusters_path else {}
    populate_vocab(vocab, clusters, probs, oov_prob)
    # Vectors are optional, just like clusters: only load them when a path
    # was actually supplied.
    if vectors_path:
        add_vectors(vocab, vectors_path)
    create_model(model_path, nlp)
|
||||
|
||||
|
||||
def add_vectors(vocab, vectors_path):
    """Load word vectors from a bz2-compressed text file into the vocab.

    The first line of the file holds two whitespace-separated fields, of
    which the second is the vector width. Every following line holds a word
    followed by its space-separated vector values. Only words that are
    already in the vocab receive a vector.

    vocab: Object providing `clear_vectors`, `set_vector` and `in` checks.
    vectors_path (Path): Path to the bz2 file, or None to do nothing.
    """
    if vectors_path is None:
        # No vectors file was supplied; leave the vocab untouched.
        return
    with bz2.BZ2File(vectors_path.as_posix()) as f:
        # Header line: only the dimension is needed here.
        _, dim = next(f).split()
        vocab.clear_vectors(int(dim))
        for line in f:
            word_w_vector = line.decode("utf8").strip().split(" ")
            word = word_w_vector[0]
            vector = np.array([float(val) for val in word_w_vector[1:]])
            if word in vocab:
                vocab.set_vector(word, vector)
|
||||
|
||||
|
||||
def create_model(model_path, model):
    """Save `model` to `model_path`, creating the directory if needed.

    model_path (Path): Target directory. Missing parent directories are
        created as well.
    model: Object providing `to_disk`; it is called with the posix string
        form of `model_path`.
    """
    if not model_path.exists():
        # parents=True so a nested output path does not fail on mkdir.
        model_path.mkdir(parents=True)
    model.to_disk(model_path.as_posix())
|
||||
|
||||
|
||||
def read_probs(freqs_path, max_length=100, min_doc_freq=5, min_freq=200):
    """Read a frequencies file and compute smoothed log-probabilities.

    Each line must hold three tab-separated fields: word frequency, document
    frequency and the key, where the key is a Python literal parsed with
    `literal_eval`. The file is read twice: once to collect the counts used
    for smoothing and the total frequency, and once to compute the
    probability of every word that passes the filters.

    freqs_path (Path): Path to the frequencies file (opened via
        `check_unzip`).
    max_length (int): Keys whose raw text is this long or longer are skipped.
    min_doc_freq (int): Minimum document frequency for a word to be kept.
    min_freq (int): Minimum word frequency for a word to be kept.
    RETURNS (tuple): `(probs, oov_prob)`, a dict mapping each kept word to
        its log-probability and the log-probability for a count of zero.
    RAISES (ValueError): If the frequencies sum to zero (e.g. empty file).
    """
    counts = PreshCounter()
    total = 0
    # First pass: gather counts for smoothing. The context manager makes sure
    # the file handle is closed again.
    with check_unzip(freqs_path) as freqs_file:
        for i, line in enumerate(freqs_file):
            freq, _, _ = line.rstrip().split('\t', 2)
            freq = int(freq)
            counts.inc(i + 1, freq)
            total += freq
    if total <= 0:
        # math.log(0) would otherwise fail with an opaque "math domain error".
        raise ValueError(
            "No word frequencies found in {}".format(freqs_path.as_posix()))
    counts.smooth()
    log_total = math.log(total)
    probs = {}
    # Second pass: compute the log-probability of every word that is kept.
    with check_unzip(freqs_path) as freqs_file:
        for line in freqs_file:
            freq, doc_freq, key = line.rstrip().split('\t', 2)
            doc_freq = int(doc_freq)
            freq = int(freq)
            if (doc_freq >= min_doc_freq and freq >= min_freq
                    and len(key) < max_length):
                word = literal_eval(key)
                smooth_count = counts.smoother(freq)
                probs[word] = math.log(smooth_count) - log_total
    oov_prob = math.log(counts.smoother(0)) - log_total
    return probs, oov_prob
|
||||
|
||||
|
||||
def read_clusters(clusters_path):
    """Read a clusters file into a dict mapping words to cluster strings.

    Each usable line holds three whitespace-separated fields: cluster, word
    and frequency. Lines that do not have exactly three fields, or whose
    frequency is not an integer, are skipped. After reading, lower-cased,
    title-cased and upper-cased variants of every word are added unless that
    variant already has an entry.

    clusters_path (Path): Path to the clusters file.
    RETURNS (dict): Word -> cluster string ('0' for low-frequency words).
    """
    clusters = {}
    with clusters_path.open() as f:
        for line in f:
            try:
                cluster, word, freq = line.split()
                word = fix_text(word)
                # Parsed inside the try so a malformed frequency skips the
                # line like any other malformed line instead of crashing.
                freq = int(freq)
            except ValueError:
                continue
            # If the clusterer has only seen the word a few times, its
            # cluster is unreliable.
            clusters[word] = cluster if freq >= 3 else '0'
    # Expand clusters with re-casing
    for word, cluster in list(clusters.items()):
        for variant in (word.lower(), word.title(), word.upper()):
            clusters.setdefault(variant, cluster)
    return clusters
|
||||
|
||||
|
||||
def populate_vocab(vocab, clusters, probs, oov_prob):
    """Set probability, OOV flag and cluster on the vocab entry of each word.

    Words are visited from highest to lowest probability.

    vocab: Object indexed by word that returns an entry with writable
        `prob`, `is_oov` and `cluster` attributes.
    clusters (dict): Word -> cluster given as a string of binary digits.
    probs (dict): Word -> log-probability.
    oov_prob (float): Not used by this function.
    """
    # Sort ascending and walk backwards (rather than sorting with
    # reverse=True) so words with equal probability keep the same relative
    # order as before.
    by_prob = sorted(probs.items(), key=lambda item: item[1])
    for word, prob in reversed(by_prob):
        lexeme = vocab[word]
        lexeme.prob = prob
        lexeme.is_oov = False
        # Decode as a little-endian string, so that we can do & 15 to get
        # the first 4 bits. See _parse_features.pyx
        if word in clusters:
            lexeme.cluster = int(clusters[word][::-1], 2)
        else:
            lexeme.cluster = 0
|
||||
|
||||
|
||||
def check_unzip(file_path):
    """Open `file_path` for reading text, decompressing gzip files.

    Paths whose posix form ends in 'gz' are opened through gzip and decoded
    as UTF-8; every other path is opened with `file_path.open()` and its
    default settings. Both branches return a text-mode file object, so
    callers can treat the lines uniformly. The caller is responsible for
    closing the returned object.

    file_path (Path): Path of the file to open.
    RETURNS: An open text-mode file object.
    """
    import io  # local import: only needed for the gzip branch

    file_path_str = file_path.as_posix()
    if file_path_str.endswith('gz'):
        # gzip.open() defaults to binary mode and would yield bytes, while the
        # plain branch yields text. Wrap it so both branches yield text. The
        # BufferedReader layer is there for interpreters whose GzipFile
        # lacks read1().
        return io.TextIOWrapper(
            io.BufferedReader(gzip.open(file_path_str)), encoding='utf8')
    return file_path.open()
|
||||
|
||||
|
||||
def check_dirs(freqs_data, clusters_data, vectors_data):
    """Report any input file that is missing via `util.sys_exit`.

    freqs_data (Path): Frequencies file; always checked.
    clusters_data (Path): Brown clusters file; checked only when truthy.
    vectors_data (Path): Word vectors file; checked only when truthy.
    """
    if not freqs_data.is_file():
        util.sys_exit(freqs_data.as_posix(), title="No frequencies file found")
    optional_inputs = (
        (clusters_data, "No Brown clusters file found"),
        (vectors_data, "No word vectors file found"),
    )
    for path, title in optional_inputs:
        if path and not path.is_file():
            util.sys_exit(path.as_posix(), title=title)
|
|
@ -436,8 +436,10 @@ class Language(object):
|
|||
**cfg: Config parameters.
|
||||
RETURNS: An optimizer
|
||||
"""
|
||||
if get_gold_tuples is None:
|
||||
get_gold_tuples = lambda: []
|
||||
# Populate vocab
|
||||
if get_gold_tuples is not None:
|
||||
else:
|
||||
for _, annots_brackets in get_gold_tuples():
|
||||
for annots, _ in annots_brackets:
|
||||
for word in annots[1]:
|
||||
|
|
|
@ -161,10 +161,16 @@ p
|
|||
+cell float64 (double)
|
||||
|
||||
+code.
|
||||
from spacy.vectors import Vectors
|
||||
nlp = spacy.load('en')
|
||||
nlp.vocab.vectors.from_glove('/path/to/vectors')
|
||||
|
||||
vectors = Vectors([], 128)
|
||||
vectors.from_glove('/path/to/vectors')
|
||||
p
|
||||
| If your instance of #[code Language] already contains vectors, they will
|
||||
| be overwritten. To create your own GloVe vectors model package like
|
||||
| spaCy's #[+a("/models/en#en_vectors_web_lg") #[code en_vectors_web_lg]],
|
||||
| you can call #[+api("language#to_disk") #[code nlp.to_disk]], and then
|
||||
| package the model using the #[+api("cli#package") #[code package]]
|
||||
| command.
|
||||
|
||||
+h(3, "custom-loading-other") Loading other vectors
|
||||
+tag-new(2)
|
||||
|
|
Loading…
Reference in New Issue
Block a user