pulling tqdm imports in functions to avoid bug (tmp fix) (#4263)

Sofie Van Landeghem 2019-09-09 16:32:11 +02:00 committed by Matthew Honnibal
parent 25aecd504f
commit 482c7cd1b9
8 changed files with 38 additions and 10 deletions
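Every file in this commit gets the same treatment: the module-level `import tqdm` is removed and the import is deferred into each function that actually uses tqdm, so merely importing the module no longer triggers the problem tracked at https://github.com/explosion/spaCy/issues/4200. A minimal standalone sketch of the pattern, using a hypothetical `count_tokens` helper rather than code from the commit:

# Before the fix, the import ran at module load time:
# import tqdm

def count_tokens(texts):
    # temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
    import tqdm  # deferred: resolved when the function is called, not when the module is imported

    total = 0
    for text in tqdm.tqdm(texts, desc="counting"):  # progress bar over the iterable
        total += len(text.split())
    return total

The trade-off is a small lookup cost on each call and a repeated comment in every affected function, which is why the commit title flags it as a temporary fix.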

View File

@@ -5,7 +5,6 @@
 from __future__ import unicode_literals
 import plac
-import tqdm
 from pathlib import Path
 import re
 import sys

View File

@@ -5,7 +5,6 @@
 from __future__ import unicode_literals
 import plac
-import tqdm
 from pathlib import Path
 import re
 import sys
@@ -462,6 +461,9 @@ def main(
     vectors_dir=None,
     use_oracle_segments=False,
 ):
+    # temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
+    import tqdm
     spacy.util.fix_random_seed()
     lang.zh.Chinese.Defaults.use_jieba = False
     lang.ja.Japanese.Defaults.use_janome = False

View File

@@ -3,11 +3,9 @@
 """
 from __future__ import unicode_literals
 import plac
-import tqdm
 import attr
 from pathlib import Path
 import re
-import sys
 import json
 import spacy
@@ -23,7 +21,7 @@ import itertools
 import random
 import numpy.random
-import conll17_ud_eval
+from bin.ud import conll17_ud_eval
 import spacy.lang.zh
 import spacy.lang.ja
@@ -394,6 +392,9 @@ class TreebankPaths(object):
     limit=("Size limit", "option", "n", int),
 )
 def main(ud_dir, parses_dir, config, corpus, limit=0):
+    # temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
+    import tqdm
     paths = TreebankPaths(ud_dir, corpus)
     if not (parses_dir / corpus).exists():
         (parses_dir / corpus).mkdir()

View File

@@ -18,7 +18,6 @@ import random
 import spacy
 import thinc.extra.datasets
 from spacy.util import minibatch, use_gpu, compounding
-import tqdm
 from spacy._ml import Tok2Vec
 from spacy.pipeline import TextCategorizer
 import numpy
@@ -107,6 +106,9 @@ def create_pipeline(width, embed_size, vectors_model):
 def train_tensorizer(nlp, texts, dropout, n_iter):
+    # temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
+    import tqdm
     tensorizer = nlp.create_pipe("tensorizer")
     nlp.add_pipe(tensorizer)
     optimizer = nlp.begin_training()
@@ -120,6 +122,9 @@ def train_tensorizer(nlp, texts, dropout, n_iter):
 def train_textcat(nlp, n_texts, n_iter=10):
+    # temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
+    import tqdm
     textcat = nlp.get_pipe("textcat")
     tok2vec_weights = textcat.model.tok2vec.to_bytes()
     (train_texts, train_cats), (dev_texts, dev_cats) = load_textcat_data(limit=n_texts)

View File

@@ -13,7 +13,6 @@ import numpy
 import plac
 import spacy
 import tensorflow as tf
-import tqdm
 from tensorflow.contrib.tensorboard.plugins.projector import (
     visualize_embeddings,
     ProjectorConfig,
@@ -36,6 +35,9 @@ from tensorflow.contrib.tensorboard.plugins.projector import (
     ),
 )
 def main(vectors_loc, out_loc, name="spaCy_vectors"):
+    # temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
+    import tqdm
     meta_file = "{}.tsv".format(name)
     out_meta_file = path.join(out_loc, meta_file)

View File

@@ -3,7 +3,6 @@ from __future__ import unicode_literals
 import plac
 import math
-from tqdm import tqdm
 import numpy
 from ast import literal_eval
 from pathlib import Path
@@ -109,6 +108,9 @@ def open_file(loc):
 def read_attrs_from_deprecated(freqs_loc, clusters_loc):
+    # temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
+    from tqdm import tqdm
     if freqs_loc is not None:
         with msg.loading("Counting frequencies..."):
             probs, _ = read_freqs(freqs_loc)
@@ -186,6 +188,9 @@ def add_vectors(nlp, vectors_loc, prune_vectors):
 def read_vectors(vectors_loc):
+    # temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
+    from tqdm import tqdm
     f = open_file(vectors_loc)
     shape = tuple(int(size) for size in next(f).split())
     vectors_data = numpy.zeros(shape=shape, dtype="f")
@@ -202,6 +207,9 @@ def read_vectors(vectors_loc):
 def read_freqs(freqs_loc, max_length=100, min_doc_freq=5, min_freq=50):
+    # temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
+    from tqdm import tqdm
     counts = PreshCounter()
     total = 0
     with freqs_loc.open() as f:
@@ -231,6 +239,9 @@ def read_freqs(freqs_loc, max_length=100, min_doc_freq=5, min_freq=50):
 def read_clusters(clusters_loc):
+    # temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
+    from tqdm import tqdm
     clusters = {}
     if ftfy is None:
         user_warning(Warnings.W004)
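Unlike the other files, this one imports the function directly with `from tqdm import tqdm`, so its call sites read `tqdm(...)` rather than `tqdm.tqdm(...)`; the deferred imports above simply preserve that spelling. A small illustrative sketch of the two equivalent forms (hypothetical helpers, not part of the commit):

def wrap_module_form(items):
    # temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
    import tqdm
    return tqdm.tqdm(items)  # module-object form used in most files of this commit

def wrap_name_form(items):
    # temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
    from tqdm import tqdm
    return tqdm(items)  # imported-name form used in this file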

View File

@@ -7,7 +7,6 @@ import srsly
 import cProfile
 import pstats
 import sys
-import tqdm
 import itertools
 import thinc.extra.datasets
 from wasabi import Printer
@@ -48,6 +47,9 @@ def profile(model, inputs=None, n_texts=10000):
 def parse_texts(nlp, texts):
+    # temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
+    import tqdm
     for doc in nlp.pipe(tqdm.tqdm(texts), batch_size=16):
         pass

View File

@@ -4,7 +4,6 @@ from __future__ import unicode_literals, division, print_function
 import plac
 import os
 from pathlib import Path
-import tqdm
 from thinc.neural._classes.model import Model
 from timeit import default_timer as timer
 import shutil
@@ -101,6 +100,10 @@ def train(
     JSON format. To convert data from other formats, use the `spacy convert`
     command.
     """
+    # temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
+    import tqdm
     msg = Printer()
     util.fix_random_seed()
     util.set_env_log(verbose)
@@ -390,6 +393,9 @@ def _score_for_model(meta):
 @contextlib.contextmanager
 def _create_progress_bar(total):
+    # temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
+    import tqdm
     if int(os.environ.get("LOG_FRIENDLY", 0)):
         yield
     else:
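The last hunk is cut off after the `else:` branch. For orientation only, here is a self-contained sketch of how a context manager like this can defer the tqdm import and yield a progress bar; the function name and the body of the `else` branch are assumptions, not text from the commit:

import contextlib
import os


@contextlib.contextmanager
def make_progress_bar(total):
    # temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
    import tqdm

    if int(os.environ.get("LOG_FRIENDLY", 0)):
        # Log-friendly mode: no live progress bar, callers get None.
        yield None
    else:
        # Assumed branch: yield a real tqdm bar and close it when the block exits.
        pbar = tqdm.tqdm(total=total, leave=False)
        try:
            yield pbar
        finally:
            pbar.close()

A caller would use it as `with make_progress_bar(n_batches) as pbar:` and call `pbar.update(1)` per batch, guarding for the `None` case when LOG_FRIENDLY is set.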