mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 10:46:29 +03:00
Merge pull request #5920 from explosion/fix/logging-warning-various
This commit is contained in:
commit
3272a63430
|
@ -14,7 +14,7 @@ from . import pipeline # noqa: F401
|
||||||
from .cli.info import info # noqa: F401
|
from .cli.info import info # noqa: F401
|
||||||
from .glossary import explain # noqa: F401
|
from .glossary import explain # noqa: F401
|
||||||
from .about import __version__ # noqa: F401
|
from .about import __version__ # noqa: F401
|
||||||
from .util import registry # noqa: F401
|
from .util import registry, logger # noqa: F401
|
||||||
|
|
||||||
from .errors import Errors
|
from .errors import Errors
|
||||||
from .language import Language
|
from .language import Language
|
||||||
|
|
|
@ -60,7 +60,6 @@ def evaluate(
|
||||||
fix_random_seed()
|
fix_random_seed()
|
||||||
if use_gpu >= 0:
|
if use_gpu >= 0:
|
||||||
require_gpu(use_gpu)
|
require_gpu(use_gpu)
|
||||||
util.set_env_log(False)
|
|
||||||
data_path = util.ensure_path(data_path)
|
data_path = util.ensure_path(data_path)
|
||||||
output_path = util.ensure_path(output)
|
output_path = util.ensure_path(output)
|
||||||
displacy_path = util.ensure_path(displacy_path)
|
displacy_path = util.ensure_path(displacy_path)
|
||||||
|
|
|
@ -9,6 +9,7 @@ from thinc.api import use_pytorch_for_gpu_memory, require_gpu, fix_random_seed
|
||||||
from thinc.api import Config, Optimizer
|
from thinc.api import Config, Optimizer
|
||||||
import random
|
import random
|
||||||
import typer
|
import typer
|
||||||
|
import logging
|
||||||
|
|
||||||
from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
|
from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
|
||||||
from ._util import import_code, get_sourced_components
|
from ._util import import_code, get_sourced_components
|
||||||
|
@ -17,7 +18,6 @@ from .. import util
|
||||||
from ..gold.example import Example
|
from ..gold.example import Example
|
||||||
from ..errors import Errors
|
from ..errors import Errors
|
||||||
|
|
||||||
|
|
||||||
# Don't remove - required to load the built-in architectures
|
# Don't remove - required to load the built-in architectures
|
||||||
from ..ml import models # noqa: F401
|
from ..ml import models # noqa: F401
|
||||||
|
|
||||||
|
@ -48,7 +48,7 @@ def train_cli(
|
||||||
used to register custom functions and architectures that can then be
|
used to register custom functions and architectures that can then be
|
||||||
referenced in the config.
|
referenced in the config.
|
||||||
"""
|
"""
|
||||||
util.set_env_log(verbose)
|
util.logger.setLevel(logging.DEBUG if verbose else logging.ERROR)
|
||||||
verify_cli_args(config_path, output_path)
|
verify_cli_args(config_path, output_path)
|
||||||
overrides = parse_config_overrides(ctx.args)
|
overrides = parse_config_overrides(ctx.args)
|
||||||
import_code(code_path)
|
import_code(code_path)
|
||||||
|
@ -102,9 +102,9 @@ def train(
|
||||||
if resume_components:
|
if resume_components:
|
||||||
with nlp.select_pipes(enable=resume_components):
|
with nlp.select_pipes(enable=resume_components):
|
||||||
msg.info(f"Resuming training for: {resume_components}")
|
msg.info(f"Resuming training for: {resume_components}")
|
||||||
nlp.resume_training()
|
nlp.resume_training(sgd=optimizer)
|
||||||
with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
|
with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
|
||||||
nlp.begin_training(lambda: train_corpus(nlp))
|
nlp.begin_training(lambda: train_corpus(nlp), sgd=optimizer)
|
||||||
|
|
||||||
if tag_map:
|
if tag_map:
|
||||||
# Replace tag map with provided mapping
|
# Replace tag map with provided mapping
|
||||||
|
|
|
@ -55,12 +55,6 @@ class Warnings:
|
||||||
"loaded. (Shape: {shape})")
|
"loaded. (Shape: {shape})")
|
||||||
W021 = ("Unexpected hash collision in PhraseMatcher. Matches may be "
|
W021 = ("Unexpected hash collision in PhraseMatcher. Matches may be "
|
||||||
"incorrect. Modify PhraseMatcher._terminal_hash to fix.")
|
"incorrect. Modify PhraseMatcher._terminal_hash to fix.")
|
||||||
W022 = ("Training a new part-of-speech tagger using a model with no "
|
|
||||||
"lemmatization rules or data. This means that the trained model "
|
|
||||||
"may not be able to lemmatize correctly. If this is intentional "
|
|
||||||
"or the language you're using doesn't have lemmatization data, "
|
|
||||||
"you can ignore this warning. If this is surprising, make sure you "
|
|
||||||
"have the spacy-lookups-data package installed.")
|
|
||||||
W024 = ("Entity '{entity}' - Alias '{alias}' combination already exists in "
|
W024 = ("Entity '{entity}' - Alias '{alias}' combination already exists in "
|
||||||
"the Knowledge Base.")
|
"the Knowledge Base.")
|
||||||
W026 = ("Unable to set all sentence boundaries from dependency parses.")
|
W026 = ("Unable to set all sentence boundaries from dependency parses.")
|
||||||
|
|
|
@ -62,7 +62,7 @@ class Corpus:
|
||||||
if str(path) in seen:
|
if str(path) in seen:
|
||||||
continue
|
continue
|
||||||
seen.add(str(path))
|
seen.add(str(path))
|
||||||
if path.parts[-1].startswith("."):
|
if path.parts and path.parts[-1].startswith("."):
|
||||||
continue
|
continue
|
||||||
elif path.is_dir():
|
elif path.is_dir():
|
||||||
paths.extend(path.iterdir())
|
paths.extend(path.iterdir())
|
||||||
|
|
|
@ -193,7 +193,8 @@ class Tok2Vec(Pipe):
|
||||||
batch_id = Tok2VecListener.get_batch_id(docs)
|
batch_id = Tok2VecListener.get_batch_id(docs)
|
||||||
for listener in self.listeners[:-1]:
|
for listener in self.listeners[:-1]:
|
||||||
listener.receive(batch_id, tokvecs, accumulate_gradient)
|
listener.receive(batch_id, tokvecs, accumulate_gradient)
|
||||||
self.listeners[-1].receive(batch_id, tokvecs, backprop)
|
if self.listeners:
|
||||||
|
self.listeners[-1].receive(batch_id, tokvecs, backprop)
|
||||||
if set_annotations:
|
if set_annotations:
|
||||||
self.set_annotations(docs, tokvecs)
|
self.set_annotations(docs, tokvecs)
|
||||||
return losses
|
return losses
|
||||||
|
|
|
@ -409,7 +409,7 @@ cdef class Parser(Pipe):
|
||||||
lexeme_norms = self.vocab.lookups.get_table("lexeme_norm", {})
|
lexeme_norms = self.vocab.lookups.get_table("lexeme_norm", {})
|
||||||
if len(lexeme_norms) == 0 and self.vocab.lang in util.LEXEME_NORM_LANGS:
|
if len(lexeme_norms) == 0 and self.vocab.lang in util.LEXEME_NORM_LANGS:
|
||||||
langs = ", ".join(util.LEXEME_NORM_LANGS)
|
langs = ", ".join(util.LEXEME_NORM_LANGS)
|
||||||
warnings.warn(Warnings.W033.format(model="parser or NER", langs=langs))
|
util.logger.debug(Warnings.W033.format(model="parser or NER", langs=langs))
|
||||||
actions = self.moves.get_actions(
|
actions = self.moves.get_actions(
|
||||||
examples=get_examples(),
|
examples=get_examples(),
|
||||||
min_freq=self.cfg['min_action_freq'],
|
min_freq=self.cfg['min_action_freq'],
|
||||||
|
|
|
@ -1,17 +1,17 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from spacy import util
|
from spacy import util
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
|
|
||||||
from spacy.language import Language
|
from spacy.language import Language
|
||||||
from spacy.lookups import Lookups
|
from spacy.lookups import Lookups
|
||||||
from spacy.pipeline._parser_internals.ner import BiluoPushDown
|
from spacy.pipeline._parser_internals.ner import BiluoPushDown
|
||||||
from spacy.gold import Example
|
from spacy.gold import Example
|
||||||
from spacy.tokens import Doc
|
from spacy.tokens import Doc
|
||||||
from spacy.vocab import Vocab
|
from spacy.vocab import Vocab
|
||||||
|
import logging
|
||||||
|
|
||||||
from ..util import make_tempdir
|
from ..util import make_tempdir
|
||||||
|
|
||||||
|
|
||||||
TRAIN_DATA = [
|
TRAIN_DATA = [
|
||||||
("Who is Shaka Khan?", {"entities": [(7, 17, "PERSON")]}),
|
("Who is Shaka Khan?", {"entities": [(7, 17, "PERSON")]}),
|
||||||
("I like London and Berlin.", {"entities": [(7, 13, "LOC"), (18, 24, "LOC")]}),
|
("I like London and Berlin.", {"entities": [(7, 13, "LOC"), (18, 24, "LOC")]}),
|
||||||
|
@ -56,6 +56,7 @@ def test_get_oracle_moves(tsys, doc, entity_annots):
|
||||||
assert names == ["U-PERSON", "O", "O", "B-GPE", "L-GPE", "O"]
|
assert names == ["U-PERSON", "O", "O", "B-GPE", "L-GPE", "O"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.filterwarnings("ignore::UserWarning")
|
||||||
def test_get_oracle_moves_negative_entities(tsys, doc, entity_annots):
|
def test_get_oracle_moves_negative_entities(tsys, doc, entity_annots):
|
||||||
entity_annots = [(s, e, "!" + label) for s, e, label in entity_annots]
|
entity_annots = [(s, e, "!" + label) for s, e, label in entity_annots]
|
||||||
example = Example.from_dict(doc, {"entities": entity_annots})
|
example = Example.from_dict(doc, {"entities": entity_annots})
|
||||||
|
@ -332,19 +333,21 @@ def test_overfitting_IO():
|
||||||
assert ents2[0].label_ == "LOC"
|
assert ents2[0].label_ == "LOC"
|
||||||
|
|
||||||
|
|
||||||
def test_ner_warns_no_lookups():
|
def test_ner_warns_no_lookups(caplog):
|
||||||
nlp = English()
|
nlp = English()
|
||||||
assert nlp.lang in util.LEXEME_NORM_LANGS
|
assert nlp.lang in util.LEXEME_NORM_LANGS
|
||||||
nlp.vocab.lookups = Lookups()
|
nlp.vocab.lookups = Lookups()
|
||||||
assert not len(nlp.vocab.lookups)
|
assert not len(nlp.vocab.lookups)
|
||||||
nlp.add_pipe("ner")
|
nlp.add_pipe("ner")
|
||||||
with pytest.warns(UserWarning):
|
with caplog.at_level(logging.DEBUG):
|
||||||
nlp.begin_training()
|
nlp.begin_training()
|
||||||
|
assert "W033" in caplog.text
|
||||||
|
caplog.clear()
|
||||||
nlp.vocab.lookups.add_table("lexeme_norm")
|
nlp.vocab.lookups.add_table("lexeme_norm")
|
||||||
nlp.vocab.lookups.get_table("lexeme_norm")["a"] = "A"
|
nlp.vocab.lookups.get_table("lexeme_norm")["a"] = "A"
|
||||||
with pytest.warns(None) as record:
|
with caplog.at_level(logging.DEBUG):
|
||||||
nlp.begin_training()
|
nlp.begin_training()
|
||||||
assert not record.list
|
assert "W033" not in caplog.text
|
||||||
|
|
||||||
|
|
||||||
@Language.factory("blocker")
|
@Language.factory("blocker")
|
||||||
|
|
|
@ -25,7 +25,6 @@ def test_issue2070():
|
||||||
assert len(doc) == 11
|
assert len(doc) == 11
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.filterwarnings("ignore::UserWarning")
|
|
||||||
def test_issue2179():
|
def test_issue2179():
|
||||||
"""Test that spurious 'extra_labels' aren't created when initializing NER."""
|
"""Test that spurious 'extra_labels' aren't created when initializing NER."""
|
||||||
nlp = Italian()
|
nlp = Italian()
|
||||||
|
@ -135,7 +134,6 @@ def test_issue2464(en_vocab):
|
||||||
assert len(matches) == 3
|
assert len(matches) == 3
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.filterwarnings("ignore::UserWarning")
|
|
||||||
def test_issue2482():
|
def test_issue2482():
|
||||||
"""Test we can serialize and deserialize a blank NER or parser model."""
|
"""Test we can serialize and deserialize a blank NER or parser model."""
|
||||||
nlp = Italian()
|
nlp = Italian()
|
||||||
|
|
|
@ -136,7 +136,6 @@ def test_issue2782(text, lang_cls):
|
||||||
assert doc[0].like_num
|
assert doc[0].like_num
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.filterwarnings("ignore::UserWarning")
|
|
||||||
def test_issue2800():
|
def test_issue2800():
|
||||||
"""Test issue that arises when too many labels are added to NER model.
|
"""Test issue that arises when too many labels are added to NER model.
|
||||||
Used to cause segfault.
|
Used to cause segfault.
|
||||||
|
|
|
@ -90,7 +90,6 @@ def test_issue3199():
|
||||||
assert list(doc[0:3].noun_chunks) == []
|
assert list(doc[0:3].noun_chunks) == []
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.filterwarnings("ignore::UserWarning")
|
|
||||||
def test_issue3209():
|
def test_issue3209():
|
||||||
"""Test issue that occurred in spaCy nightly where NER labels were being
|
"""Test issue that occurred in spaCy nightly where NER labels were being
|
||||||
mapped to classes incorrectly after loading the model, when the labels
|
mapped to classes incorrectly after loading the model, when the labels
|
||||||
|
|
|
@ -91,7 +91,6 @@ def test_issue_3526_3(en_vocab):
|
||||||
assert new_ruler.overwrite is not ruler.overwrite
|
assert new_ruler.overwrite is not ruler.overwrite
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.filterwarnings("ignore::UserWarning")
|
|
||||||
def test_issue_3526_4(en_vocab):
|
def test_issue_3526_4(en_vocab):
|
||||||
nlp = Language(vocab=en_vocab)
|
nlp = Language(vocab=en_vocab)
|
||||||
patterns = [{"label": "ORG", "pattern": "Apple"}]
|
patterns = [{"label": "ORG", "pattern": "Apple"}]
|
||||||
|
@ -252,7 +251,6 @@ def test_issue3803():
|
||||||
assert [t.like_num for t in doc] == [True, True, True, True, True, True]
|
assert [t.like_num for t in doc] == [True, True, True, True, True, True]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.filterwarnings("ignore::UserWarning")
|
|
||||||
def test_issue3830_no_subtok():
|
def test_issue3830_no_subtok():
|
||||||
"""Test that the parser doesn't have subtok label if not learn_tokens"""
|
"""Test that the parser doesn't have subtok label if not learn_tokens"""
|
||||||
config = {
|
config = {
|
||||||
|
@ -270,7 +268,6 @@ def test_issue3830_no_subtok():
|
||||||
assert "subtok" not in parser.labels
|
assert "subtok" not in parser.labels
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.filterwarnings("ignore::UserWarning")
|
|
||||||
def test_issue3830_with_subtok():
|
def test_issue3830_with_subtok():
|
||||||
"""Test that the parser does have subtok label if learn_tokens=True."""
|
"""Test that the parser does have subtok label if learn_tokens=True."""
|
||||||
config = {
|
config = {
|
||||||
|
@ -333,7 +330,6 @@ def test_issue3879(en_vocab):
|
||||||
assert len(matcher(doc)) == 2 # fails because of a FP match 'is a test'
|
assert len(matcher(doc)) == 2 # fails because of a FP match 'is a test'
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.filterwarnings("ignore::UserWarning")
|
|
||||||
def test_issue3880():
|
def test_issue3880():
|
||||||
"""Test that `nlp.pipe()` works when an empty string ends the batch.
|
"""Test that `nlp.pipe()` works when an empty string ends the batch.
|
||||||
|
|
||||||
|
|
|
@ -81,7 +81,6 @@ def test_issue4030():
|
||||||
assert doc.cats["inoffensive"] == 0.0
|
assert doc.cats["inoffensive"] == 0.0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.filterwarnings("ignore::UserWarning")
|
|
||||||
def test_issue4042():
|
def test_issue4042():
|
||||||
"""Test that serialization of an EntityRuler before NER works fine."""
|
"""Test that serialization of an EntityRuler before NER works fine."""
|
||||||
nlp = English()
|
nlp = English()
|
||||||
|
@ -110,7 +109,6 @@ def test_issue4042():
|
||||||
assert doc2.ents[0].label_ == "MY_ORG"
|
assert doc2.ents[0].label_ == "MY_ORG"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.filterwarnings("ignore::UserWarning")
|
|
||||||
def test_issue4042_bug2():
|
def test_issue4042_bug2():
|
||||||
"""
|
"""
|
||||||
Test that serialization of an NER works fine when new labels were added.
|
Test that serialization of an NER works fine when new labels were added.
|
||||||
|
@ -242,7 +240,6 @@ def test_issue4190():
|
||||||
assert result_1b == result_2
|
assert result_1b == result_2
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.filterwarnings("ignore::UserWarning")
|
|
||||||
def test_issue4267():
|
def test_issue4267():
|
||||||
""" Test that running an entity_ruler after ner gives consistent results"""
|
""" Test that running an entity_ruler after ner gives consistent results"""
|
||||||
nlp = English()
|
nlp = English()
|
||||||
|
@ -324,7 +321,6 @@ def test_issue4313():
|
||||||
entity_scores[(start, end, label)] += score
|
entity_scores[(start, end, label)] += score
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.filterwarnings("ignore::UserWarning")
|
|
||||||
def test_issue4348():
|
def test_issue4348():
|
||||||
"""Test that training the tagger with empty data, doesn't throw errors"""
|
"""Test that training the tagger with empty data, doesn't throw errors"""
|
||||||
nlp = English()
|
nlp = English()
|
||||||
|
|
|
@ -179,7 +179,6 @@ def test_issue4707():
|
||||||
assert "entity_ruler" in new_nlp.pipe_names
|
assert "entity_ruler" in new_nlp.pipe_names
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.filterwarnings("ignore::UserWarning")
|
|
||||||
def test_issue4725_1():
|
def test_issue4725_1():
|
||||||
""" Ensure the pickling of the NER goes well"""
|
""" Ensure the pickling of the NER goes well"""
|
||||||
vocab = Vocab(vectors_name="test_vocab_add_vector")
|
vocab = Vocab(vectors_name="test_vocab_add_vector")
|
||||||
|
@ -198,7 +197,6 @@ def test_issue4725_1():
|
||||||
assert ner2.cfg["update_with_oracle_cut_size"] == 111
|
assert ner2.cfg["update_with_oracle_cut_size"] == 111
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.filterwarnings("ignore::UserWarning")
|
|
||||||
def test_issue4725_2():
|
def test_issue4725_2():
|
||||||
# ensures that this runs correctly and doesn't hang or crash because of the global vectors
|
# ensures that this runs correctly and doesn't hang or crash because of the global vectors
|
||||||
# if it does crash, it's usually because of calling 'spawn' for multiprocessing (e.g. on Windows),
|
# if it does crash, it's usually because of calling 'spawn' for multiprocessing (e.g. on Windows),
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
import pytest
|
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.filterwarnings("ignore::UserWarning")
|
|
||||||
def test_issue5152():
|
def test_issue5152():
|
||||||
# Test that the comparison between a Span and a Token, goes well
|
# Test that the comparison between a Span and a Token, goes well
|
||||||
# There was a bug when the number of tokens in the span equaled the number of characters in the token (!)
|
# There was a bug when the number of tokens in the span equaled the number of characters in the token (!)
|
||||||
|
@ -14,6 +13,8 @@ def test_issue5152():
|
||||||
span_2 = text[0:3] # Talk about being
|
span_2 = text[0:3] # Talk about being
|
||||||
span_3 = text_var[0:3] # Talk of being
|
span_3 = text_var[0:3] # Talk of being
|
||||||
token = y[0] # Let
|
token = y[0] # Let
|
||||||
assert span.similarity(token) == 0.0
|
with pytest.warns(UserWarning):
|
||||||
|
assert span.similarity(token) == 0.0
|
||||||
assert span.similarity(span_2) == 1.0
|
assert span.similarity(span_2) == 1.0
|
||||||
assert span_2.similarity(span_3) < 1.0
|
with pytest.warns(UserWarning):
|
||||||
|
assert span_2.similarity(span_3) < 1.0
|
||||||
|
|
|
@ -154,6 +154,7 @@ def test_example_from_dict_some_ner(en_vocab):
|
||||||
assert ner_tags == ["U-LOC", None, None, None]
|
assert ner_tags == ["U-LOC", None, None, None]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.filterwarnings("ignore::UserWarning")
|
||||||
def test_json2docs_no_ner(en_vocab):
|
def test_json2docs_no_ner(en_vocab):
|
||||||
data = [
|
data = [
|
||||||
{
|
{
|
||||||
|
@ -506,6 +507,7 @@ def test_roundtrip_docs_to_docbin(doc):
|
||||||
assert cats["BAKING"] == reloaded_example.reference.cats["BAKING"]
|
assert cats["BAKING"] == reloaded_example.reference.cats["BAKING"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.filterwarnings("ignore::UserWarning")
|
||||||
def test_make_orth_variants(doc):
|
def test_make_orth_variants(doc):
|
||||||
nlp = English()
|
nlp = English()
|
||||||
with make_tempdir() as tmpdir:
|
with make_tempdir() as tmpdir:
|
||||||
|
@ -586,7 +588,7 @@ def test_tuple_format_implicit():
|
||||||
("Uber blew through $1 million a week", {"entities": [(0, 4, "ORG")]}),
|
("Uber blew through $1 million a week", {"entities": [(0, 4, "ORG")]}),
|
||||||
(
|
(
|
||||||
"Spotify steps up Asia expansion",
|
"Spotify steps up Asia expansion",
|
||||||
{"entities": [(0, 8, "ORG"), (17, 21, "LOC")]},
|
{"entities": [(0, 7, "ORG"), (17, 21, "LOC")]},
|
||||||
),
|
),
|
||||||
("Google rebrands its business apps", {"entities": [(0, 6, "ORG")]}),
|
("Google rebrands its business apps", {"entities": [(0, 6, "ORG")]}),
|
||||||
]
|
]
|
||||||
|
@ -601,7 +603,7 @@ def test_tuple_format_implicit_invalid():
|
||||||
("Uber blew through $1 million a week", {"frumble": [(0, 4, "ORG")]}),
|
("Uber blew through $1 million a week", {"frumble": [(0, 4, "ORG")]}),
|
||||||
(
|
(
|
||||||
"Spotify steps up Asia expansion",
|
"Spotify steps up Asia expansion",
|
||||||
{"entities": [(0, 8, "ORG"), (17, 21, "LOC")]},
|
{"entities": [(0, 7, "ORG"), (17, 21, "LOC")]},
|
||||||
),
|
),
|
||||||
("Google rebrands its business apps", {"entities": [(0, 6, "ORG")]}),
|
("Google rebrands its business apps", {"entities": [(0, 6, "ORG")]}),
|
||||||
]
|
]
|
||||||
|
|
|
@ -46,6 +46,7 @@ def test_Example_from_dict_with_tags(pred_words, annots):
|
||||||
assert aligned_tags == ["NN" for _ in predicted]
|
assert aligned_tags == ["NN" for _ in predicted]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.filterwarnings("ignore::UserWarning")
|
||||||
def test_aligned_tags():
|
def test_aligned_tags():
|
||||||
pred_words = ["Apply", "some", "sunscreen", "unless", "you", "can", "not"]
|
pred_words = ["Apply", "some", "sunscreen", "unless", "you", "can", "not"]
|
||||||
gold_words = ["Apply", "some", "sun", "screen", "unless", "you", "cannot"]
|
gold_words = ["Apply", "some", "sun", "screen", "unless", "you", "cannot"]
|
||||||
|
@ -198,8 +199,8 @@ def test_Example_from_dict_with_entities(annots):
|
||||||
def test_Example_from_dict_with_entities_invalid(annots):
|
def test_Example_from_dict_with_entities_invalid(annots):
|
||||||
vocab = Vocab()
|
vocab = Vocab()
|
||||||
predicted = Doc(vocab, words=annots["words"])
|
predicted = Doc(vocab, words=annots["words"])
|
||||||
example = Example.from_dict(predicted, annots)
|
with pytest.warns(UserWarning):
|
||||||
# TODO: shouldn't this throw some sort of warning ?
|
example = Example.from_dict(predicted, annots)
|
||||||
assert len(list(example.reference.ents)) == 0
|
assert len(list(example.reference.ents)) == 0
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -24,6 +24,7 @@ import tempfile
|
||||||
import shutil
|
import shutil
|
||||||
import shlex
|
import shlex
|
||||||
import inspect
|
import inspect
|
||||||
|
import logging
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import cupy.random
|
import cupy.random
|
||||||
|
@ -54,11 +55,14 @@ if TYPE_CHECKING:
|
||||||
from .vocab import Vocab # noqa: F401
|
from .vocab import Vocab # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
_PRINT_ENV = False
|
|
||||||
OOV_RANK = numpy.iinfo(numpy.uint64).max
|
OOV_RANK = numpy.iinfo(numpy.uint64).max
|
||||||
LEXEME_NORM_LANGS = ["da", "de", "el", "en", "id", "lb", "pt", "ru", "sr", "ta", "th"]
|
LEXEME_NORM_LANGS = ["da", "de", "el", "en", "id", "lb", "pt", "ru", "sr", "ta", "th"]
|
||||||
|
|
||||||
|
|
||||||
|
logging.basicConfig()
|
||||||
|
logger = logging.getLogger("spacy")
|
||||||
|
|
||||||
|
|
||||||
class registry(thinc.registry):
|
class registry(thinc.registry):
|
||||||
languages = catalogue.create("spacy", "languages", entry_points=True)
|
languages = catalogue.create("spacy", "languages", entry_points=True)
|
||||||
architectures = catalogue.create("spacy", "architectures", entry_points=True)
|
architectures = catalogue.create("spacy", "architectures", entry_points=True)
|
||||||
|
@ -109,11 +113,6 @@ class SimpleFrozenDict(dict):
|
||||||
raise NotImplementedError(self.error)
|
raise NotImplementedError(self.error)
|
||||||
|
|
||||||
|
|
||||||
def set_env_log(value: bool) -> None:
|
|
||||||
global _PRINT_ENV
|
|
||||||
_PRINT_ENV = value
|
|
||||||
|
|
||||||
|
|
||||||
def lang_class_is_loaded(lang: str) -> bool:
|
def lang_class_is_loaded(lang: str) -> bool:
|
||||||
"""Check whether a Language class is already loaded. Language classes are
|
"""Check whether a Language class is already loaded. Language classes are
|
||||||
loaded lazily, to avoid expensive setup code associated with the language
|
loaded lazily, to avoid expensive setup code associated with the language
|
||||||
|
@ -602,27 +601,6 @@ def get_async(stream, numpy_array):
|
||||||
return array
|
return array
|
||||||
|
|
||||||
|
|
||||||
def env_opt(name: str, default: Optional[Any] = None) -> Optional[Any]:
|
|
||||||
if type(default) is float:
|
|
||||||
type_convert = float
|
|
||||||
else:
|
|
||||||
type_convert = int
|
|
||||||
if "SPACY_" + name.upper() in os.environ:
|
|
||||||
value = type_convert(os.environ["SPACY_" + name.upper()])
|
|
||||||
if _PRINT_ENV:
|
|
||||||
print(name, "=", repr(value), "via", "$SPACY_" + name.upper())
|
|
||||||
return value
|
|
||||||
elif name in os.environ:
|
|
||||||
value = type_convert(os.environ[name])
|
|
||||||
if _PRINT_ENV:
|
|
||||||
print(name, "=", repr(value), "via", "$" + name)
|
|
||||||
return value
|
|
||||||
else:
|
|
||||||
if _PRINT_ENV:
|
|
||||||
print(name, "=", repr(default), "by default")
|
|
||||||
return default
|
|
||||||
|
|
||||||
|
|
||||||
def read_regex(path: Union[str, Path]) -> Pattern:
|
def read_regex(path: Union[str, Path]) -> Pattern:
|
||||||
path = ensure_path(path)
|
path = ensure_path(path)
|
||||||
with path.open(encoding="utf8") as file_:
|
with path.open(encoding="utf8") as file_:
|
||||||
|
@ -1067,24 +1045,7 @@ class DummyTokenizer:
|
||||||
|
|
||||||
|
|
||||||
def create_default_optimizer() -> Optimizer:
|
def create_default_optimizer() -> Optimizer:
|
||||||
# TODO: Do we still want to allow env_opt?
|
return Adam()
|
||||||
learn_rate = env_opt("learn_rate", 0.001)
|
|
||||||
beta1 = env_opt("optimizer_B1", 0.9)
|
|
||||||
beta2 = env_opt("optimizer_B2", 0.999)
|
|
||||||
eps = env_opt("optimizer_eps", 1e-8)
|
|
||||||
L2 = env_opt("L2_penalty", 1e-6)
|
|
||||||
grad_clip = env_opt("grad_norm_clip", 10.0)
|
|
||||||
L2_is_weight_decay = env_opt("L2_is_weight_decay", False)
|
|
||||||
optimizer = Adam(
|
|
||||||
learn_rate,
|
|
||||||
L2=L2,
|
|
||||||
beta1=beta1,
|
|
||||||
beta2=beta2,
|
|
||||||
eps=eps,
|
|
||||||
grad_clip=grad_clip,
|
|
||||||
L2_is_weight_decay=L2_is_weight_decay,
|
|
||||||
)
|
|
||||||
return optimizer
|
|
||||||
|
|
||||||
|
|
||||||
def minibatch(items, size):
|
def minibatch(items, size):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user