mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-11 20:28:20 +03:00
95a9615221
This patch addresses #1660, which was caused by keying all pre-trained vectors with the same ID when telling Thinc how to refer to them. This meant that if multiple models with pre-trained vectors were loaded, errors or incorrect behaviour resulted. The vectors class now includes a .name attribute, which defaults to {nlp.meta['lang']}_{nlp.meta['name']}.vectors. The vectors name is set in the cfg of the pipeline components under the key pretrained_vectors. This replaces the previous cfg key pretrained_dims. To keep existing models compatible with this change, we check for the pretrained_dims key when loading models in from_disk and from_bytes, and add the cfg key pretrained_vectors if we find it.
204 lines
5.1 KiB
Python
204 lines
5.1 KiB
Python
# coding: utf-8
|
||
from __future__ import unicode_literals
|
||
|
||
from io import StringIO, BytesIO
|
||
from pathlib import Path
|
||
import pytest
|
||
|
||
from .util import load_test_model
|
||
from ..tokens import Doc
|
||
from ..strings import StringStore
|
||
from .. import util
|
||
|
||
|
||
# Languages covered by the generic tokenizer tests. Only list a language here
# if it uses spaCy's own tokenizer (not a different library).
# TODO: re-implement generic tokenizer tests
_languages = [
    'bn', 'da', 'de', 'en', 'es',
    'fi', 'fr', 'ga', 'he', 'hu',
    'id', 'it', 'nb', 'nl', 'pl',
    'pt', 'ru', 'sv', 'tr', 'xx',
]
|
||
# Model package names to load, keyed either by a language code or by the
# model name itself. The model fixtures read their ``params`` from here and
# ``pytest_addoption`` registers a flag for every key that is not a language.
_models = {
    'en': ['en_core_web_sm'],
    'de': ['de_core_news_md'],
    'fr': ['fr_core_news_sm'],
    'xx': ['xx_ent_web_md'],
    'en_core_web_md': ['en_core_web_md'],
    'es_core_news_md': ['es_core_news_md'],
}
|
||
|
||
|
||
# only used for tests that require loading the models
|
||
# in all other cases, use specific instances
|
||
|
||
@pytest.fixture(params=_models['en'])
def EN(request):
    """Load an English test model.

    The fixture is parametrized over ``_models['en']``; ``request.param`` is
    the model name handed to ``load_test_model``.
    """
    model_name = request.param
    return load_test_model(model_name)
|
||
|
||
|
||
@pytest.fixture(params=_models['de'])
def DE(request):
    """Load a German test model.

    The fixture is parametrized over ``_models['de']``; ``request.param`` is
    the model name handed to ``load_test_model``.
    """
    model_name = request.param
    return load_test_model(model_name)
|
||
|
||
|
||
@pytest.fixture(params=_models['fr'])
def FR(request):
    """Load a French test model.

    The fixture is parametrized over ``_models['fr']``; ``request.param`` is
    the model name handed to ``load_test_model``.
    """
    model_name = request.param
    return load_test_model(model_name)
|
||
|
||
|
||
@pytest.fixture()
def RU(request):
    """Return a fresh instance of the language class registered under 'ru'.

    The test is skipped when ``pymorphy2`` cannot be imported.
    """
    # Only the skip-on-missing-import side effect is needed here.
    pytest.importorskip('pymorphy2')
    russian_cls = util.get_lang_class('ru')
    return russian_cls()
|
||
|
||
|
||
#@pytest.fixture(params=_languages)
|
||
#def tokenizer(request):
|
||
#lang = util.get_lang_class(request.param)
|
||
#return lang.Defaults.create_tokenizer()
|
||
|
||
|
||
@pytest.fixture
def tokenizer():
    """Return a tokenizer created from the defaults of the 'xx' language class."""
    lang_cls = util.get_lang_class('xx')
    return lang_cls.Defaults.create_tokenizer()
|
||
|
||
|
||
@pytest.fixture
def en_tokenizer():
    """Return a tokenizer created from the defaults of the 'en' language class."""
    lang_cls = util.get_lang_class('en')
    return lang_cls.Defaults.create_tokenizer()
|
||
|
||
|
||
@pytest.fixture
def en_vocab():
    """Return a vocab created from the defaults of the 'en' language class."""
    lang_cls = util.get_lang_class('en')
    return lang_cls.Defaults.create_vocab()
|
||
|
||
|
||
@pytest.fixture
def en_parser(en_vocab):
    """Return a 'parser' pipe created by an 'en' pipeline built on ``en_vocab``."""
    english_cls = util.get_lang_class('en')
    pipeline = english_cls(en_vocab)
    return pipeline.create_pipe('parser')
|
||
|
||
|
||
@pytest.fixture
def es_tokenizer():
    """Return a tokenizer created from the defaults of the 'es' language class."""
    lang_cls = util.get_lang_class('es')
    return lang_cls.Defaults.create_tokenizer()
|
||
|
||
|
||
@pytest.fixture
def de_tokenizer():
    """Return a tokenizer created from the defaults of the 'de' language class."""
    lang_cls = util.get_lang_class('de')
    return lang_cls.Defaults.create_tokenizer()
|
||
|
||
|
||
@pytest.fixture
def fr_tokenizer():
    """Return a tokenizer created from the defaults of the 'fr' language class."""
    lang_cls = util.get_lang_class('fr')
    return lang_cls.Defaults.create_tokenizer()
|
||
|
||
|
||
@pytest.fixture
def hu_tokenizer():
    """Return a tokenizer created from the defaults of the 'hu' language class."""
    lang_cls = util.get_lang_class('hu')
    return lang_cls.Defaults.create_tokenizer()
|
||
|
||
|
||
@pytest.fixture
def fi_tokenizer():
    """Return a tokenizer created from the defaults of the 'fi' language class."""
    lang_cls = util.get_lang_class('fi')
    return lang_cls.Defaults.create_tokenizer()
|
||
|
||
|
||
@pytest.fixture
def id_tokenizer():
    """Return a tokenizer created from the defaults of the 'id' language class."""
    lang_cls = util.get_lang_class('id')
    return lang_cls.Defaults.create_tokenizer()
|
||
|
||
|
||
@pytest.fixture
def sv_tokenizer():
    """Return a tokenizer created from the defaults of the 'sv' language class."""
    lang_cls = util.get_lang_class('sv')
    return lang_cls.Defaults.create_tokenizer()
|
||
|
||
|
||
@pytest.fixture
def bn_tokenizer():
    """Return a tokenizer created from the defaults of the 'bn' language class."""
    lang_cls = util.get_lang_class('bn')
    return lang_cls.Defaults.create_tokenizer()
|
||
|
||
|
||
@pytest.fixture
def ga_tokenizer():
    """Return a tokenizer created from the defaults of the 'ga' language class."""
    lang_cls = util.get_lang_class('ga')
    return lang_cls.Defaults.create_tokenizer()
|
||
|
||
|
||
@pytest.fixture
def he_tokenizer():
    """Return a tokenizer created from the defaults of the 'he' language class."""
    lang_cls = util.get_lang_class('he')
    return lang_cls.Defaults.create_tokenizer()
|
||
|
||
|
||
@pytest.fixture
def nb_tokenizer():
    """Return a tokenizer created from the defaults of the 'nb' language class."""
    lang_cls = util.get_lang_class('nb')
    return lang_cls.Defaults.create_tokenizer()
|
||
|
||
@pytest.fixture
def da_tokenizer():
    """Return a tokenizer created from the defaults of the 'da' language class."""
    lang_cls = util.get_lang_class('da')
    return lang_cls.Defaults.create_tokenizer()
|
||
|
||
@pytest.fixture
def ja_tokenizer():
    """Return a tokenizer created from the defaults of the 'ja' language class.

    The test is skipped when ``janome`` cannot be imported.
    """
    # Only the skip-on-missing-import side effect is needed here.
    pytest.importorskip("janome")
    lang_cls = util.get_lang_class('ja')
    return lang_cls.Defaults.create_tokenizer()
|
||
|
||
|
||
@pytest.fixture
def th_tokenizer():
    """Return a tokenizer created from the defaults of the 'th' language class.

    The test is skipped when ``pythainlp`` cannot be imported.
    """
    # Only the skip-on-missing-import side effect is needed here.
    pytest.importorskip("pythainlp")
    lang_cls = util.get_lang_class('th')
    return lang_cls.Defaults.create_tokenizer()
|
||
|
||
@pytest.fixture
def tr_tokenizer():
    """Return a tokenizer created from the defaults of the 'tr' language class."""
    lang_cls = util.get_lang_class('tr')
    return lang_cls.Defaults.create_tokenizer()
|
||
|
||
|
||
@pytest.fixture
def ru_tokenizer():
    """Return a tokenizer created from the defaults of the 'ru' language class.

    The test is skipped when ``pymorphy2`` cannot be imported.
    """
    # Only the skip-on-missing-import side effect is needed here.
    pytest.importorskip('pymorphy2')
    lang_cls = util.get_lang_class('ru')
    return lang_cls.Defaults.create_tokenizer()
|
||
|
||
|
||
@pytest.fixture
def stringstore():
    """Return a new ``StringStore`` constructed with no arguments."""
    store = StringStore()
    return store
|
||
|
||
|
||
@pytest.fixture
def en_entityrecognizer():
    """Return the result of ``create_entity()`` on the 'en' language defaults."""
    lang_cls = util.get_lang_class('en')
    return lang_cls.Defaults.create_entity()
|
||
|
||
|
||
@pytest.fixture
def text_file():
    """Return a new, empty in-memory text buffer (``io.StringIO``)."""
    buffer = StringIO()
    return buffer
|
||
|
||
|
||
@pytest.fixture
def text_file_b():
    """Return a new, empty in-memory bytes buffer (``io.BytesIO``)."""
    buffer = BytesIO()
    return buffer
|
||
|
||
|
||
def pytest_addoption(parser):
    """Register the test suite's custom command-line flags.

    Adds, in this order:
    - ``--models``, ``--vectors`` and ``--slow``;
    - one ``--<lang>`` flag per entry in ``_languages``, plus ``--all``;
    - one ``--<model>`` flag per key of ``_models`` that is not already a
      language code.

    All flags are boolean (``action="store_true"``).

    parser: the option parser object pytest passes to this hook.
    """
    category_flags = (
        ("--models", "include tests that require full models"),
        ("--vectors", "include word vectors tests"),
        ("--slow", "include slow tests"),
    )
    for flag, description in category_flags:
        parser.addoption(flag, action="store_true", help=description)

    for code in _languages + ['all']:
        parser.addoption("--%s" % code, action="store_true",
                         help="Use %s models" % code)

    # Keys of _models that are language codes already got a flag above;
    # registering them again would be a duplicate.
    model_only_keys = [key for key in _models if key not in _languages]
    for key in model_only_keys:
        parser.addoption("--%s" % key, action="store_true",
                         help="Use %s model" % key)
|
||
|
||
|
||
def pytest_runtest_setup(item):
    """Skip a test when the command-line flags it depends on were not given.

    Two checks are made:

    1. If 'models', 'vectors' or 'slow' is among ``item.keywords`` and the
       matching ``--models`` / ``--vectors`` / ``--slow`` option is not set,
       the test is skipped.
    2. If the test carries a ``models`` marker with arguments, the test is
       skipped unless, for every argument, either ``--<arg>`` or ``--all``
       is set.

    item: the test item pytest is about to run.
    """
    for opt in ('models', 'vectors', 'slow'):
        if opt in item.keywords and not item.config.getoption("--%s" % opt):
            pytest.skip("need --%s option to run" % opt)

    # Node.get_marker() was removed in pytest 4.0; get_closest_marker() is its
    # replacement (available since pytest 3.6). Prefer the new API and fall
    # back to the old one so the hook works on both sides of the removal.
    lookup_marker = getattr(item, 'get_closest_marker', None) or item.get_marker

    # Check if test is marked with models and has arguments set, i.e. specific
    # language. If so, skip test if flag not set.
    models_marker = lookup_marker('models')
    if models_marker:
        for arg in models_marker.args:
            if not item.config.getoption("--%s" % arg) and not item.config.getoption("--all"):
                pytest.skip("need --%s or --all option to run" % arg)
|