mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Add profile function
This commit is contained in:
parent
5e50a65252
commit
7be5f30f17
|
@ -2,6 +2,7 @@ from .download import download
|
|||
from .info import info
|
||||
from .link import link
|
||||
from .package import package
|
||||
from .profile import profile
|
||||
from .train import train
|
||||
from .convert import convert
|
||||
from .model import model
|
||||
|
|
45
spacy/cli/profile.py
Normal file
45
spacy/cli/profile.py
Normal file
|
@ -0,0 +1,45 @@
|
|||
# coding: utf8
|
||||
from __future__ import unicode_literals, division, print_function
|
||||
|
||||
import plac
|
||||
from pathlib import Path
|
||||
import ujson
|
||||
import cProfile
|
||||
import pstats
|
||||
|
||||
import spacy
|
||||
import sys
|
||||
import tqdm
|
||||
import cytoolz
|
||||
|
||||
|
||||
def read_inputs(loc):
|
||||
if loc is None:
|
||||
file_ = sys.stdin
|
||||
file_ = (line.encode('utf8') for line in file_)
|
||||
else:
|
||||
file_ = Path(loc).open()
|
||||
for line in file_:
|
||||
data = ujson.loads(line)
|
||||
text = data['text']
|
||||
yield text
|
||||
|
||||
|
||||
@plac.annotations(
|
||||
lang=("model/language", "positional", None, str),
|
||||
inputs=("Location of input file", "positional", None, read_inputs)
|
||||
)
|
||||
def profile(cmd, lang, inputs=None):
|
||||
"""
|
||||
Profile a spaCy pipeline, to find out which functions take the most time.
|
||||
"""
|
||||
nlp = spacy.load(lang)
|
||||
texts = list(cytoolz.take(10000, inputs))
|
||||
cProfile.runctx("parse_texts(nlp, texts)", globals(), locals(), "Profile.prof")
|
||||
s = pstats.Stats("Profile.prof")
|
||||
s.strip_dirs().sort_stats("time").print_stats()
|
||||
|
||||
|
||||
def parse_texts(nlp, texts):
|
||||
for doc in nlp.pipe(tqdm.tqdm(texts), batch_size=128):
|
||||
pass
|
Loading…
Reference in New Issue
Block a user