From dad5621166955481d3c86ad81bb3726d06cb6ff3 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Sat, 31 Aug 2019 13:39:31 +0200 Subject: [PATCH] Tidy up and auto-format [ci skip] --- .../keras_parikh_entailment/spacy_hook.py | 30 +++++++++---------- examples/training/rehearsal.py | 4 +-- examples/training/train_textcat.py | 8 ++--- 3 files changed, 18 insertions(+), 24 deletions(-) diff --git a/examples/keras_parikh_entailment/spacy_hook.py b/examples/keras_parikh_entailment/spacy_hook.py index 98c355738..307669a70 100644 --- a/examples/keras_parikh_entailment/spacy_hook.py +++ b/examples/keras_parikh_entailment/spacy_hook.py @@ -12,15 +12,15 @@ class KerasSimilarityShim(object): @classmethod def load(cls, path, nlp, max_length=100, get_features=None): - + if get_features is None: get_features = get_word_ids - - with (path / 'config.json').open() as file_: + + with (path / "config.json").open() as file_: model = model_from_json(file_.read()) - with (path / 'model').open('rb') as file_: + with (path / "model").open("rb") as file_: weights = pickle.load(file_) - + embeddings = get_embeddings(nlp.vocab) weights.insert(1, embeddings) model.set_weights(weights) @@ -33,8 +33,8 @@ class KerasSimilarityShim(object): self.max_length = max_length def __call__(self, doc): - doc.user_hooks['similarity'] = self.predict - doc.user_span_hooks['similarity'] = self.predict + doc.user_hooks["similarity"] = self.predict + doc.user_span_hooks["similarity"] = self.predict return doc @@ -48,24 +48,24 @@ class KerasSimilarityShim(object): def get_embeddings(vocab, nr_unk=100): # the extra +1 is for a zero vector representing sentence-final padding - num_vectors = max(lex.rank for lex in vocab) + 2 - + num_vectors = max(lex.rank for lex in vocab) + 2 + # create random vectors for OOV tokens oov = np.random.normal(size=(nr_unk, vocab.vectors_length)) oov = oov / oov.sum(axis=1, keepdims=True) - - vectors = np.zeros((num_vectors + nr_unk, vocab.vectors_length), dtype='float32') - vectors[1:(nr_unk + 1), ] = oov + + vectors = np.zeros((num_vectors + nr_unk, vocab.vectors_length), dtype="float32") + vectors[1 : (nr_unk + 1),] = oov for lex in vocab: if lex.has_vector and lex.vector_norm > 0: - vectors[nr_unk + lex.rank + 1] = lex.vector / lex.vector_norm + vectors[nr_unk + lex.rank + 1] = lex.vector / lex.vector_norm return vectors def get_word_ids(docs, max_length=100, nr_unk=100): - Xs = np.zeros((len(docs), max_length), dtype='int32') - + Xs = np.zeros((len(docs), max_length), dtype="int32") + for i, doc in enumerate(docs): for j, token in enumerate(doc): if j == max_length: diff --git a/examples/training/rehearsal.py b/examples/training/rehearsal.py index 21e897ced..123f5049d 100644 --- a/examples/training/rehearsal.py +++ b/examples/training/rehearsal.py @@ -80,7 +80,7 @@ def main(model_name, unlabelled_loc): nlp.rehearse(raw_batch, sgd=optimizer, losses=r_losses) print("Losses", losses) print("R. Losses", r_losses) - print(nlp.get_pipe('ner').model.unseen_classes) + print(nlp.get_pipe("ner").model.unseen_classes) test_text = "Do you like horses?" doc = nlp(test_text) print("Entities in '%s'" % test_text) @@ -88,7 +88,5 @@ def main(model_name, unlabelled_loc): print(ent.label_, ent.text) - - if __name__ == "__main__": plac.call(main) diff --git a/examples/training/train_textcat.py b/examples/training/train_textcat.py index 7cd492f75..4d4ebf396 100644 --- a/examples/training/train_textcat.py +++ b/examples/training/train_textcat.py @@ -24,7 +24,7 @@ from spacy.util import minibatch, compounding output_dir=("Optional output directory", "option", "o", Path), n_texts=("Number of texts to train from", "option", "t", int), n_iter=("Number of training iterations", "option", "n", int), - init_tok2vec=("Pretrained tok2vec weights", "option", "t2v", Path) + init_tok2vec=("Pretrained tok2vec weights", "option", "t2v", Path), ) def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None): if output_dir is not None: @@ -43,11 +43,7 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None # nlp.create_pipe works for built-ins that are registered with spaCy if "textcat" not in nlp.pipe_names: textcat = nlp.create_pipe( - "textcat", - config={ - "exclusive_classes": True, - "architecture": "simple_cnn", - } + "textcat", config={"exclusive_classes": True, "architecture": "simple_cnn"} ) nlp.add_pipe(textcat, last=True) # otherwise, get it, so we can add labels to it