spaCy/spacy/cli/profile.py

# coding: utf8
from __future__ import unicode_literals, division, print_function
import plac
from pathlib import Path
import ujson
import cProfile
import pstats
import spacy
import sys
import tqdm
import cytoolz
import thinc.extra.datasets


def read_inputs(loc):
    if loc is None:
        # No path given: read newline-delimited JSON from stdin.
        file_ = sys.stdin
        file_ = (line.encode('utf8') for line in file_)
    else:
        file_ = Path(loc).open()
    for line in file_:
        data = ujson.loads(line)
        text = data['text']
        yield text
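

# Illustrative only (not part of the original module): read_inputs() expects
# newline-delimited JSON with one object per line, each carrying a 'text'
# field, e.g.:
#
#     {"text": "This is the first document to profile."}
#     {"text": "And this is the second one."}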
@plac.annotations(
    lang=("Model/language to load", "positional", None, str),
    inputs=("Location of input file", "positional", None, read_inputs))
def profile(lang, inputs=None):
    """
    Profile a spaCy pipeline, to find out which functions take the most time.
    """
    if inputs is None:
        # No input file: fall back to the IMDB corpus shipped with Thinc.
        imdb_train, _ = thinc.extra.datasets.imdb()
        inputs, _ = zip(*imdb_train)
        inputs = inputs[:25000]
    nlp = spacy.load(lang)
    texts = list(cytoolz.take(10000, inputs))
    # Run the pipeline under cProfile, dump the stats to Profile.prof, then
    # print a summary sorted by time spent inside each function.
    cProfile.runctx("parse_texts(nlp, texts)", globals(), locals(),
                    "Profile.prof")
    s = pstats.Stats("Profile.prof")
    s.strip_dirs().sort_stats("time").print_stats()


def parse_texts(nlp, texts):
    for doc in nlp.pipe(tqdm.tqdm(texts), batch_size=16):
        pass
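

# Usage sketch (assumptions: the command is wired into spaCy's CLI dispatcher,
# a model named 'en_core_web_sm' is installed, and 'texts.jsonl' is a
# hypothetical input file in the format shown above):
#
#     python -m spacy profile en_core_web_sm texts.jsonl
#
# The function can also be driven directly through plac, which applies the
# annotations above to parse the argument list:
#
#     import plac
#     from spacy.cli.profile import profile
#     plac.call(profile, ['en_core_web_sm', 'texts.jsonl'])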