Mirror of https://github.com/explosion/spaCy.git (synced 2025-02-04 13:40:34 +03:00)
Feature/example only (#5707)
* remove _convert_examples
* fix test_gold, raise TypeError if tuples are used instead of Example objects
* throw proper errors when the wrong type of objects are passed
* fix deprecated format in tests
* fix deprecated format in parser tests
* fix tests for NEL, morph, senter, tagger, textcat
* update regression tests with new Example format
* use make_doc
* more fixes to nlp.update calls
* a few more small fixes for rehearse and evaluate
* only import ml_datasets if really necessary
This commit is contained in:
parent 63247cbe87
commit fcbf899b08
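For orientation before the diff: the commit replaces the old (text, annotations) tuple format with spacy.gold.Example objects everywhere that nlp.update, nlp.rehearse and nlp.evaluate are called. The following is a minimal sketch of the resulting training loop, assembled from the patterns that appear in the diff below; it is not part of the commit, and the toy TRAIN_DATA and the "ORG" label are made up for illustration.

import random
import spacy
from spacy.gold import Example

# Hypothetical toy data in the old (text, annotations) tuple format.
TRAIN_DATA = [
    ("Uber blew through $1 million a week", {"entities": [(0, 4, "ORG")]}),
    ("Google rebrands its business apps", {"entities": [(0, 6, "ORG")]}),
]

nlp = spacy.blank("en")
ner = nlp.create_pipe("ner")
ner.add_label("ORG")
nlp.add_pipe(ner)

# Convert the tuples up front: make_doc only tokenizes, and Example.from_dict
# pairs the predicted doc with its gold-standard annotations.
train_examples = []
for text, annotations in TRAIN_DATA:
    train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))

optimizer = nlp.begin_training()
for itn in range(10):
    random.shuffle(train_examples)
    losses = {}
    # Passing raw (text, annotations) tuples here now raises a TypeError (E978).
    nlp.update(train_examples, sgd=optimizer, losses=losses)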
@@ -33,7 +33,7 @@ def read_raw_data(nlp, jsonl_loc):
     for json_obj in srsly.read_jsonl(jsonl_loc):
         if json_obj["text"].strip():
             doc = nlp.make_doc(json_obj["text"])
-            yield doc
+            yield Example.from_dict(doc, {})


 def read_gold_data(nlp, gold_loc):
@@ -52,7 +52,7 @@ def main(model_name, unlabelled_loc):
     batch_size = 4
     nlp = spacy.load(model_name)
     nlp.get_pipe("ner").add_label(LABEL)
-    raw_docs = list(read_raw_data(nlp, unlabelled_loc))
+    raw_examples = list(read_raw_data(nlp, unlabelled_loc))
     optimizer = nlp.resume_training()
     # Avoid use of Adam when resuming training. I don't understand this well
     # yet, but I'm getting weird results from Adam. Try commenting out the
@@ -61,20 +61,24 @@ def main(model_name, unlabelled_loc):
     optimizer.learn_rate = 0.1
     optimizer.b1 = 0.0
     optimizer.b2 = 0.0

     sizes = compounding(1.0, 4.0, 1.001)

+    train_examples = []
+    for text, annotations in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
+
     with nlp.select_pipes(enable="ner") and warnings.catch_warnings():
         # show warnings for misaligned entity spans once
         warnings.filterwarnings("once", category=UserWarning, module="spacy")

         for itn in range(n_iter):
-            random.shuffle(TRAIN_DATA)
-            random.shuffle(raw_docs)
+            random.shuffle(train_examples)
+            random.shuffle(raw_examples)
             losses = {}
             r_losses = {}
             # batch up the examples using spaCy's minibatch
-            raw_batches = minibatch(raw_docs, size=4)
-            for batch in minibatch(TRAIN_DATA, size=sizes):
+            raw_batches = minibatch(raw_examples, size=4)
+            for batch in minibatch(train_examples, size=sizes):
                 nlp.update(batch, sgd=optimizer, drop=dropout, losses=losses)
                 raw_batch = list(next(raw_batches))
                 nlp.rehearse(raw_batch, sgd=optimizer, losses=r_losses)

@@ -20,6 +20,8 @@ from pathlib import Path
 from spacy.vocab import Vocab
 import spacy
 from spacy.kb import KnowledgeBase
+
+from spacy.gold import Example
 from spacy.pipeline import EntityRuler
 from spacy.util import minibatch, compounding

@@ -94,7 +96,7 @@ def main(kb_path, vocab_path=None, output_dir=None, n_iter=50):
     # Convert the texts to docs to make sure we have doc.ents set for the training examples.
     # Also ensure that the annotated examples correspond to known identifiers in the knowledge base.
     kb_ids = nlp.get_pipe("entity_linker").kb.get_entity_strings()
-    TRAIN_DOCS = []
+    train_examples = []
     for text, annotation in TRAIN_DATA:
         with nlp.select_pipes(disable="entity_linker"):
             doc = nlp(text)
@@ -109,17 +111,17 @@ def main(kb_path, vocab_path=None, output_dir=None, n_iter=50):
                     "Removed", kb_id, "from training because it is not in the KB."
                 )
             annotation_clean["links"][offset] = new_dict
-        TRAIN_DOCS.append((doc, annotation_clean))
+        train_examples.append(Example.from_dict(doc, annotation_clean))

     with nlp.select_pipes(enable="entity_linker"):  # only train entity linker
         # reset and initialize the weights randomly
         optimizer = nlp.begin_training()

         for itn in range(n_iter):
-            random.shuffle(TRAIN_DOCS)
+            random.shuffle(train_examples)
             losses = {}
             # batch up the examples using spaCy's minibatch
-            batches = minibatch(TRAIN_DOCS, size=compounding(4.0, 32.0, 1.001))
+            batches = minibatch(train_examples, size=compounding(4.0, 32.0, 1.001))
             for batch in batches:
                 nlp.update(
                     batch,

@@ -23,6 +23,7 @@ import plac
 import random
 from pathlib import Path
 import spacy
+from spacy.gold import Example
 from spacy.util import minibatch, compounding


@@ -120,17 +121,19 @@ def main(model=None, output_dir=None, n_iter=15):
     parser = nlp.create_pipe("parser")
     nlp.add_pipe(parser, first=True)

+    train_examples = []
     for text, annotations in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
         for dep in annotations.get("deps", []):
             parser.add_label(dep)

     with nlp.select_pipes(enable="parser"):  # only train parser
         optimizer = nlp.begin_training()
         for itn in range(n_iter):
-            random.shuffle(TRAIN_DATA)
+            random.shuffle(train_examples)
             losses = {}
             # batch up the examples using spaCy's minibatch
-            batches = minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
+            batches = minibatch(train_examples, size=compounding(4.0, 32.0, 1.001))
             for batch in batches:
                 nlp.update(batch, sgd=optimizer, losses=losses)
             print("Losses", losses)

@@ -14,6 +14,7 @@ import plac
 import random
 from pathlib import Path
 import spacy
+from spacy.gold import Example
 from spacy.util import minibatch, compounding
 from spacy.morphology import Morphology

@@ -84,8 +85,10 @@ def main(lang="en", output_dir=None, n_iter=25):
     morphologizer = nlp.create_pipe("morphologizer")
     nlp.add_pipe(morphologizer)

-    # add labels
-    for _, annotations in TRAIN_DATA:
+    # add labels and create the Example instances
+    train_examples = []
+    for text, annotations in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
         morph_labels = annotations.get("morphs")
         pos_labels = annotations.get("pos", [""] * len(annotations.get("morphs")))
         assert len(morph_labels) == len(pos_labels)
@@ -98,10 +101,10 @@ def main(lang="en", output_dir=None, n_iter=25):

     optimizer = nlp.begin_training()
     for i in range(n_iter):
-        random.shuffle(TRAIN_DATA)
+        random.shuffle(train_examples)
         losses = {}
         # batch up the examples using spaCy's minibatch
-        batches = minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
+        batches = minibatch(train_examples, size=compounding(4.0, 32.0, 1.001))
         for batch in batches:
             nlp.update(batch, sgd=optimizer, losses=losses)
         print("Losses", losses)

@@ -17,6 +17,7 @@ import random
 import warnings
 from pathlib import Path
 import spacy
+from spacy.gold import Example
 from spacy.util import minibatch, compounding


@@ -50,8 +51,10 @@ def main(model=None, output_dir=None, n_iter=100):
     else:
         ner = nlp.get_pipe("simple_ner")

-    # add labels
-    for _, annotations in TRAIN_DATA:
+    # add labels and create Example objects
+    train_examples = []
+    for text, annotations in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
         for ent in annotations.get("entities"):
             print("Add label", ent[2])
             ner.add_label(ent[2])
@@ -68,10 +71,10 @@ def main(model=None, output_dir=None, n_iter=100):
             "Transitions", list(enumerate(nlp.get_pipe("simple_ner").get_tag_names()))
         )
         for itn in range(n_iter):
-            random.shuffle(TRAIN_DATA)
+            random.shuffle(train_examples)
             losses = {}
             # batch up the examples using spaCy's minibatch
-            batches = minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
+            batches = minibatch(train_examples, size=compounding(4.0, 32.0, 1.001))
             for batch in batches:
                 nlp.update(
                     batch,

@@ -80,6 +80,10 @@ def main(model=None, new_model_name="animal", output_dir=None, n_iter=30):
     print("Created blank 'en' model")
     # Add entity recognizer to model if it's not in the pipeline
     # nlp.create_pipe works for built-ins that are registered with spaCy
+    train_examples = []
+    for text, annotation in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp(text), annotation))
+
     if "ner" not in nlp.pipe_names:
         ner = nlp.create_pipe("ner")
         nlp.add_pipe(ner)
@@ -102,8 +106,8 @@ def main(model=None, new_model_name="animal", output_dir=None, n_iter=30):
         sizes = compounding(1.0, 4.0, 1.001)
         # batch up the examples using spaCy's minibatch
         for itn in range(n_iter):
-            random.shuffle(TRAIN_DATA)
-            batches = minibatch(TRAIN_DATA, size=sizes)
+            random.shuffle(train_examples)
+            batches = minibatch(train_examples, size=sizes)
             losses = {}
             for batch in batches:
                 nlp.update(batch, sgd=optimizer, drop=0.35, losses=losses)

@@ -14,6 +14,7 @@ import plac
 import random
 from pathlib import Path
 import spacy
+from spacy.gold import Example
 from spacy.util import minibatch, compounding


@@ -59,18 +60,20 @@ def main(model=None, output_dir=None, n_iter=15):
     else:
         parser = nlp.get_pipe("parser")

-    # add labels to the parser
-    for _, annotations in TRAIN_DATA:
+    # add labels to the parser and create the Example objects
+    train_examples = []
+    for text, annotations in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
         for dep in annotations.get("deps", []):
             parser.add_label(dep)

     with nlp.select_pipes(enable="parser"):  # only train parser
         optimizer = nlp.begin_training()
         for itn in range(n_iter):
-            random.shuffle(TRAIN_DATA)
+            random.shuffle(train_examples)
             losses = {}
             # batch up the examples using spaCy's minibatch
-            batches = minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
+            batches = minibatch(train_examples, size=compounding(4.0, 32.0, 1.001))
             for batch in batches:
                 nlp.update(batch, sgd=optimizer, losses=losses)
             print("Losses", losses)

@@ -17,6 +17,7 @@ import plac
 import random
 from pathlib import Path
 import spacy
+from spacy.gold import Example
 from spacy.util import minibatch, compounding


@@ -58,12 +59,16 @@ def main(lang="en", output_dir=None, n_iter=25):
         tagger.add_label(tag, values)
     nlp.add_pipe(tagger)

+    train_examples = []
+    for text, annotations in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
+
     optimizer = nlp.begin_training()
     for i in range(n_iter):
-        random.shuffle(TRAIN_DATA)
+        random.shuffle(train_examples)
         losses = {}
         # batch up the examples using spaCy's minibatch
-        batches = minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
+        batches = minibatch(train_examples, size=compounding(4.0, 32.0, 1.001))
         for batch in batches:
             nlp.update(batch, sgd=optimizer, losses=losses)
         print("Losses", losses)

@@ -31,17 +31,20 @@ def profile_cli(


 def profile(model: str, inputs: Optional[Path] = None, n_texts: int = 10000) -> None:
-    try:
-        import ml_datasets
-    except ImportError:
-        msg.fail(
-            "This command requires the ml_datasets library to be installed:"
-            "pip install ml_datasets",
-            exits=1,
-        )
     if inputs is not None:
         inputs = _read_inputs(inputs, msg)
     if inputs is None:
+        try:
+            import ml_datasets
+        except ImportError:
+            msg.fail(
+                "This command, when run without an input file, "
+                "requires the ml_datasets library to be installed: "
+                "pip install ml_datasets",
+                exits=1,
+            )
         n_inputs = 25000
         with msg.loading("Loading IMDB dataset via Thinc..."):
             imdb_train, _ = ml_datasets.imdb()

@@ -12,7 +12,7 @@ from thinc.api import Model, use_pytorch_for_gpu_memory
 import random

 from ._app import app, Arg, Opt
-from ..gold import Corpus
+from ..gold import Corpus, Example
 from ..lookups import Lookups
 from .. import util
 from ..errors import Errors
@@ -423,9 +423,8 @@ def train_while_improving(

     if raw_text:
         random.shuffle(raw_text)
-        raw_batches = util.minibatch(
-            (nlp.make_doc(rt["text"]) for rt in raw_text), size=8
-        )
+        raw_examples = [Example.from_dict(nlp.make_doc(rt["text"]), {}) for rt in raw_text]
+        raw_batches = util.minibatch(raw_examples, size=8)

     for step, (epoch, batch) in enumerate(train_data):
         dropout = next(dropouts)

@@ -547,13 +547,13 @@ class Errors(object):
     E972 = ("Example.__init__ got None for '{arg}'. Requires Doc.")
     E973 = ("Unexpected type for NER data")
     E974 = ("Unknown {obj} attribute: {key}")
-    E975 = ("The method Example.from_dict expects a Doc as first argument, "
+    E975 = ("The method 'Example.from_dict' expects a Doc as first argument, "
             "but got {type}")
-    E976 = ("The method Example.from_dict expects a dict as second argument, "
+    E976 = ("The method 'Example.from_dict' expects a dict as second argument, "
             "but received None.")
     E977 = ("Can not compare a MorphAnalysis with a string object. "
             "This is likely a bug in spaCy, so feel free to open an issue.")
-    E978 = ("The {method} method of component {name} takes a list of Example objects, "
+    E978 = ("The '{method}' method of {name} takes a list of Example objects, "
            "but found {types} instead.")
     E979 = ("Cannot convert {type} to an Example object.")
     E980 = ("Each link annotation should refer to a dictionary with at most one "

@@ -2,6 +2,7 @@ import random
 import itertools
 import weakref
 import functools
+from collections import Iterable
 from contextlib import contextmanager
 from copy import copy, deepcopy
 from pathlib import Path
@@ -529,22 +530,6 @@ class Language(object):
     def make_doc(self, text):
         return self.tokenizer(text)

-    def _convert_examples(self, examples):
-        converted_examples = []
-        if isinstance(examples, tuple):
-            examples = [examples]
-        for eg in examples:
-            if isinstance(eg, Example):
-                converted_examples.append(eg.copy())
-            elif isinstance(eg, tuple):
-                doc, annot = eg
-                if isinstance(doc, str):
-                    doc = self.make_doc(doc)
-                converted_examples.append(Example.from_dict(doc, annot))
-            else:
-                raise ValueError(Errors.E979.format(type=type(eg)))
-        return converted_examples
-
     def update(
         self,
         examples,
@@ -557,7 +542,7 @@ class Language(object):
     ):
         """Update the models in the pipeline.

-        examples (iterable): A batch of `Example` or `Doc` objects.
+        examples (iterable): A batch of `Example` objects.
         dummy: Should not be set - serves to catch backwards-incompatible scripts.
         drop (float): The dropout rate.
         sgd (callable): An optimizer.
@@ -569,10 +554,13 @@ class Language(object):
         """
         if dummy is not None:
             raise ValueError(Errors.E989)

         if len(examples) == 0:
             return
-        examples = self._convert_examples(examples)
+        if not isinstance(examples, Iterable):
+            raise TypeError(Errors.E978.format(name="language", method="update", types=type(examples)))
+        wrong_types = set([type(eg) for eg in examples if not isinstance(eg, Example)])
+        if wrong_types:
+            raise TypeError(Errors.E978.format(name="language", method="update", types=wrong_types))

         if sgd is None:
             if self._optimizer is None:
@@ -605,22 +593,26 @@ class Language(object):
         initial ones. This is useful for keeping a pretrained model on-track,
         even if you're updating it with a smaller set of examples.

-        examples (iterable): A batch of `Doc` objects.
+        examples (iterable): A batch of `Example` objects.
         drop (float): The dropout rate.
         sgd (callable): An optimizer.
         RETURNS (dict): Results from the update.

         EXAMPLE:
             >>> raw_text_batches = minibatch(raw_texts)
-            >>> for labelled_batch in minibatch(zip(train_docs, train_golds)):
+            >>> for labelled_batch in minibatch(examples):
             >>>     nlp.update(labelled_batch)
-            >>>     raw_batch = [nlp.make_doc(text) for text in next(raw_text_batches)]
+            >>>     raw_batch = [Example.from_dict(nlp.make_doc(text), {}) for text in next(raw_text_batches)]
             >>>     nlp.rehearse(raw_batch)
         """
         # TODO: document
         if len(examples) == 0:
             return
-        examples = self._convert_examples(examples)
+        if not isinstance(examples, Iterable):
+            raise TypeError(Errors.E978.format(name="language", method="rehearse", types=type(examples)))
+        wrong_types = set([type(eg) for eg in examples if not isinstance(eg, Example)])
+        if wrong_types:
+            raise TypeError(Errors.E978.format(name="language", method="rehearse", types=wrong_types))
         if sgd is None:
             if self._optimizer is None:
                 self._optimizer = create_default_optimizer()
@@ -696,7 +688,7 @@ class Language(object):
         component that has a .rehearse() method. Rehearsal is used to prevent
         models from "forgetting" their initialised "knowledge". To perform
         rehearsal, collect samples of text you want the models to retain performance
-        on, and call nlp.rehearse() with a batch of Doc objects.
+        on, and call nlp.rehearse() with a batch of Example objects.
         """
         if cfg.get("device", -1) >= 0:
             util.use_gpu(cfg["device"])
@@ -728,7 +720,11 @@ class Language(object):

         DOCS: https://spacy.io/api/language#evaluate
         """
-        examples = self._convert_examples(examples)
+        if not isinstance(examples, Iterable):
+            raise TypeError(Errors.E978.format(name="language", method="evaluate", types=type(examples)))
+        wrong_types = set([type(eg) for eg in examples if not isinstance(eg, Example)])
+        if wrong_types:
+            raise TypeError(Errors.E978.format(name="language", method="evaluate", types=wrong_types))
         if scorer is None:
             scorer = Scorer(pipeline=self.pipeline)
         if component_cfg is None:

@@ -295,7 +295,7 @@ class Tagger(Pipe):
                 return
         except AttributeError:
             types = set([type(eg) for eg in examples])
-            raise ValueError(Errors.E978.format(name="Tagger", method="update", types=types))
+            raise TypeError(Errors.E978.format(name="Tagger", method="update", types=types))
         set_dropout_rate(self.model, drop)
         tag_scores, bp_tag_scores = self.model.begin_update(
             [eg.predicted for eg in examples])
@@ -321,7 +321,7 @@ class Tagger(Pipe):
             docs = [eg.predicted for eg in examples]
         except AttributeError:
             types = set([type(eg) for eg in examples])
-            raise ValueError(Errors.E978.format(name="Tagger", method="rehearse", types=types))
+            raise TypeError(Errors.E978.format(name="Tagger", method="rehearse", types=types))
         if self._rehearsal_model is None:
             return
         if not any(len(doc) for doc in docs):
@@ -358,7 +358,7 @@ class Tagger(Pipe):
             try:
                 y = example.y
             except AttributeError:
-                raise ValueError(Errors.E978.format(name="Tagger", method="begin_training", types=type(example)))
+                raise TypeError(Errors.E978.format(name="Tagger", method="begin_training", types=type(example)))
             for token in y:
                 tag = token.tag_
                 if tag in orig_tag_map:
@@ -790,7 +790,7 @@ class ClozeMultitask(Pipe):
             predictions, bp_predictions = self.model.begin_update([eg.predicted for eg in examples])
         except AttributeError:
             types = set([type(eg) for eg in examples])
-            raise ValueError(Errors.E978.format(name="ClozeMultitask", method="rehearse", types=types))
+            raise TypeError(Errors.E978.format(name="ClozeMultitask", method="rehearse", types=types))
         loss, d_predictions = self.get_loss(examples, self.vocab.vectors.data, predictions)
         bp_predictions(d_predictions)
         if sgd is not None:
@@ -856,7 +856,7 @@ class TextCategorizer(Pipe):
                 return
         except AttributeError:
             types = set([type(eg) for eg in examples])
-            raise ValueError(Errors.E978.format(name="TextCategorizer", method="update", types=types))
+            raise TypeError(Errors.E978.format(name="TextCategorizer", method="update", types=types))
         set_dropout_rate(self.model, drop)
         scores, bp_scores = self.model.begin_update(
             [eg.predicted for eg in examples]
@@ -879,7 +879,7 @@ class TextCategorizer(Pipe):
             docs = [eg.predicted for eg in examples]
         except AttributeError:
             types = set([type(eg) for eg in examples])
-            raise ValueError(Errors.E978.format(name="TextCategorizer", method="rehearse", types=types))
+            raise TypeError(Errors.E978.format(name="TextCategorizer", method="rehearse", types=types))
         if not any(len(doc) for doc in docs):
             # Handle cases where there are no tokens in any docs.
             return
@@ -940,7 +940,7 @@ class TextCategorizer(Pipe):
             try:
                 y = example.y
             except AttributeError:
-                raise ValueError(Errors.E978.format(name="TextCategorizer", method="update", types=type(example)))
+                raise TypeError(Errors.E978.format(name="TextCategorizer", method="update", types=type(example)))
             for cat in y.cats:
                 self.add_label(cat)
         self.require_labels()
@@ -1105,7 +1105,7 @@ class EntityLinker(Pipe):
             docs = [eg.predicted for eg in examples]
         except AttributeError:
             types = set([type(eg) for eg in examples])
-            raise ValueError(Errors.E978.format(name="EntityLinker", method="update", types=types))
+            raise TypeError(Errors.E978.format(name="EntityLinker", method="update", types=types))
         if set_annotations:
             # This seems simpler than other ways to get that exact output -- but
             # it does run the model twice :(

@@ -209,6 +209,10 @@ def test_train_empty():
     ]

     nlp = English()
+    train_examples = []
+    for t in train_data:
+        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
+
     ner = nlp.create_pipe("ner")
     ner.add_label("PERSON")
     nlp.add_pipe(ner, last=True)
@@ -216,10 +220,9 @@ def test_train_empty():
     nlp.begin_training()
     for itn in range(2):
         losses = {}
-        batches = util.minibatch(train_data)
+        batches = util.minibatch(train_examples)
         for batch in batches:
-            texts, annotations = zip(*batch)
-            nlp.update(train_data, losses=losses)
+            nlp.update(batch, losses=losses)


 def test_overwrite_token():
@@ -328,7 +331,9 @@ def test_overfitting_IO():
     # Simple test to try and quickly overfit the NER component - ensuring the ML models work correctly
     nlp = English()
     ner = nlp.create_pipe("ner")
-    for _, annotations in TRAIN_DATA:
+    train_examples = []
+    for text, annotations in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
         for ent in annotations.get("entities"):
             ner.add_label(ent[2])
     nlp.add_pipe(ner)
@@ -336,7 +341,7 @@ def test_overfitting_IO():

     for i in range(50):
         losses = {}
-        nlp.update(TRAIN_DATA, sgd=optimizer, losses=losses)
+        nlp.update(train_examples, sgd=optimizer, losses=losses)
     assert losses["ner"] < 0.00001

     # test the trained model

@@ -3,6 +3,7 @@ import pytest
 from spacy.lang.en import English
 from ..util import get_doc, apply_transition_sequence, make_tempdir
 from ... import util
+from ...gold import Example

 TRAIN_DATA = [
     (
@@ -189,7 +190,9 @@ def test_overfitting_IO():
     # Simple test to try and quickly overfit the dependency parser - ensuring the ML models work correctly
     nlp = English()
     parser = nlp.create_pipe("parser")
-    for _, annotations in TRAIN_DATA:
+    train_examples = []
+    for text, annotations in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
         for dep in annotations.get("deps", []):
             parser.add_label(dep)
     nlp.add_pipe(parser)
@@ -197,7 +200,7 @@ def test_overfitting_IO():

     for i in range(50):
         losses = {}
-        nlp.update(TRAIN_DATA, sgd=optimizer, losses=losses)
+        nlp.update(train_examples, sgd=optimizer, losses=losses)
     assert losses["parser"] < 0.00001

     # test the trained model

@@ -3,6 +3,7 @@ import pytest
 from spacy.kb import KnowledgeBase

 from spacy import util
+from spacy.gold import Example
 from spacy.lang.en import English
 from spacy.pipeline import EntityRuler
 from spacy.tests.util import make_tempdir
@@ -283,11 +284,10 @@ def test_overfitting_IO():
     nlp.add_pipe(ruler)

     # Convert the texts to docs to make sure we have doc.ents set for the training examples
-    TRAIN_DOCS = []
+    train_examples = []
     for text, annotation in TRAIN_DATA:
         doc = nlp(text)
-        annotation_clean = annotation
-        TRAIN_DOCS.append((doc, annotation_clean))
+        train_examples.append(Example.from_dict(doc, annotation))

     # create artificial KB - assign same prior weight to the two russ cochran's
     # Q2146908 (Russ Cochran): American golfer
@@ -309,7 +309,7 @@ def test_overfitting_IO():
     optimizer = nlp.begin_training()
     for i in range(50):
         losses = {}
-        nlp.update(TRAIN_DOCS, sgd=optimizer, losses=losses)
+        nlp.update(train_examples, sgd=optimizer, losses=losses)
     assert losses["entity_linker"] < 0.001

     # test the trained model

@@ -1,6 +1,7 @@
 import pytest

 from spacy import util
+from spacy.gold import Example
 from spacy.lang.en import English
 from spacy.language import Language
 from spacy.tests.util import make_tempdir
@@ -33,7 +34,9 @@ def test_overfitting_IO():
     # Simple test to try and quickly overfit the morphologizer - ensuring the ML models work correctly
     nlp = English()
     morphologizer = nlp.create_pipe("morphologizer")
+    train_examples = []
     for inst in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(inst[0]), inst[1]))
         for morph, pos in zip(inst[1]["morphs"], inst[1]["pos"]):
             morphologizer.add_label(morph + "|POS=" + pos)
     nlp.add_pipe(morphologizer)
@@ -41,7 +44,7 @@ def test_overfitting_IO():

     for i in range(50):
         losses = {}
-        nlp.update(TRAIN_DATA, sgd=optimizer, losses=losses)
+        nlp.update(train_examples, sgd=optimizer, losses=losses)
     assert losses["morphologizer"] < 0.00001

     # test the trained model

@@ -1,6 +1,7 @@
 import pytest

 from spacy import util
+from spacy.gold import Example
 from spacy.lang.en import English
 from spacy.language import Language
 from spacy.tests.util import make_tempdir
@@ -34,12 +35,15 @@ def test_overfitting_IO():
     # Simple test to try and quickly overfit the senter - ensuring the ML models work correctly
     nlp = English()
     senter = nlp.create_pipe("senter")
+    train_examples = []
+    for t in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
     nlp.add_pipe(senter)
     optimizer = nlp.begin_training()

     for i in range(200):
         losses = {}
-        nlp.update(TRAIN_DATA, sgd=optimizer, losses=losses)
+        nlp.update(train_examples, sgd=optimizer, losses=losses)
     assert losses["senter"] < 0.001

     # test the trained model

@@ -1,6 +1,7 @@
 import pytest

 from spacy import util
+from spacy.gold import Example
 from spacy.lang.en import English
 from spacy.language import Language
 from spacy.tests.util import make_tempdir
@@ -28,12 +29,15 @@ def test_overfitting_IO():
     tagger = nlp.create_pipe("tagger")
     for tag, values in TAG_MAP.items():
         tagger.add_label(tag, values)
+    train_examples = []
+    for t in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
     nlp.add_pipe(tagger)
     optimizer = nlp.begin_training()

     for i in range(50):
         losses = {}
-        nlp.update(TRAIN_DATA, sgd=optimizer, losses=losses)
+        nlp.update(train_examples, sgd=optimizer, losses=losses)
     assert losses["tagger"] < 0.00001

     # test the trained model

@@ -85,7 +85,9 @@ def test_overfitting_IO():
     fix_random_seed(0)
     nlp = English()
     textcat = nlp.create_pipe("textcat")
-    for _, annotations in TRAIN_DATA:
+    train_examples = []
+    for text, annotations in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
         for label, value in annotations.get("cats").items():
             textcat.add_label(label)
     nlp.add_pipe(textcat)
@@ -93,7 +95,7 @@ def test_overfitting_IO():

     for i in range(50):
         losses = {}
-        nlp.update(TRAIN_DATA, sgd=optimizer, losses=losses)
+        nlp.update(train_examples, sgd=optimizer, losses=losses)
     assert losses["textcat"] < 0.01

     # test the trained model
@@ -134,11 +136,13 @@ def test_textcat_configs(textcat_config):
     pipe_config = {"model": textcat_config}
     nlp = English()
     textcat = nlp.create_pipe("textcat", pipe_config)
-    for _, annotations in TRAIN_DATA:
+    train_examples = []
+    for text, annotations in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
         for label, value in annotations.get("cats").items():
             textcat.add_label(label)
     nlp.add_pipe(textcat)
     optimizer = nlp.begin_training()
     for i in range(5):
         losses = {}
-        nlp.update(TRAIN_DATA, sgd=optimizer, losses=losses)
+        nlp.update(train_examples, sgd=optimizer, losses=losses)

@@ -1,5 +1,6 @@
 import pytest
 from spacy import displacy
+from spacy.gold import Example
 from spacy.lang.en import English
 from spacy.lang.ja import Japanese
 from spacy.lang.xx import MultiLanguage
@@ -141,10 +142,10 @@ def test_issue2800():
     """Test issue that arises when too many labels are added to NER model.
     Used to cause segfault.
     """
-    train_data = []
-    train_data.extend([("One sentence", {"entities": []})])
-    entity_types = [str(i) for i in range(1000)]
     nlp = English()
+    train_data = []
+    train_data.extend([Example.from_dict(nlp.make_doc("One sentence"), {"entities": []})])
+    entity_types = [str(i) for i in range(1000)]
     ner = nlp.create_pipe("ner")
     nlp.add_pipe(ner)
     for entity_type in list(entity_types):
@@ -153,8 +154,8 @@ def test_issue2800():
     for i in range(20):
         losses = {}
         random.shuffle(train_data)
-        for statement, entities in train_data:
-            nlp.update((statement, entities), sgd=optimizer, losses=losses, drop=0.5)
+        for example in train_data:
+            nlp.update([example], sgd=optimizer, losses=losses, drop=0.5)


 def test_issue2822(it_tokenizer):

@@ -1,4 +1,5 @@
 import spacy
+from spacy.gold import Example
 from spacy.util import minibatch, compounding


@@ -12,15 +13,15 @@ def test_issue3611():
     ]
     y_train = ["offensive", "offensive", "inoffensive"]

-    # preparing the data
-    pos_cats = list()
-    for train_instance in y_train:
-        pos_cats.append({label: label == train_instance for label in unique_classes})
-    train_data = list(zip(x_train, [{"cats": cats} for cats in pos_cats]))
-
-    # set up the spacy model with a text categorizer component
     nlp = spacy.blank("en")

+    # preparing the data
+    train_data = []
+    for text, train_instance in zip(x_train, y_train):
+        cat_dict = {label: label == train_instance for label in unique_classes}
+        train_data.append(Example.from_dict(nlp.make_doc(text), {"cats": cat_dict}))
+
+    # add a text categorizer component
     textcat = nlp.create_pipe(
         "textcat",
         config={"exclusive_classes": True, "architecture": "bow", "ngram_size": 2},

@@ -1,4 +1,5 @@
 import spacy
+from spacy.gold import Example
 from spacy.util import minibatch, compounding


@@ -12,15 +13,15 @@ def test_issue4030():
     ]
     y_train = ["offensive", "offensive", "inoffensive"]

-    # preparing the data
-    pos_cats = list()
-    for train_instance in y_train:
-        pos_cats.append({label: label == train_instance for label in unique_classes})
-    train_data = list(zip(x_train, [{"cats": cats} for cats in pos_cats]))
-
-    # set up the spacy model with a text categorizer component
     nlp = spacy.blank("en")

+    # preparing the data
+    train_data = []
+    for text, train_instance in zip(x_train, y_train):
+        cat_dict = {label: label == train_instance for label in unique_classes}
+        train_data.append(Example.from_dict(nlp.make_doc(text), {"cats": cat_dict}))
+
+    # add a text categorizer component
     textcat = nlp.create_pipe(
         "textcat",
         config={"exclusive_classes": True, "architecture": "bow", "ngram_size": 2},

@@ -1,3 +1,4 @@
+from spacy.gold import Example
 from spacy.lang.en import English
 from spacy.util import minibatch, compounding
 import pytest
@@ -7,9 +8,10 @@ import pytest
 def test_issue4348():
     """Test that training the tagger with empty data, doesn't throw errors"""

-    TRAIN_DATA = [("", {"tags": []}), ("", {"tags": []})]
-
     nlp = English()
+    example = Example.from_dict(nlp.make_doc(""), {"tags": []})
+    TRAIN_DATA = [example, example]
+
     tagger = nlp.create_pipe("tagger")
     nlp.add_pipe(tagger)

@@ -1,7 +1,8 @@
+from spacy.gold import Example
 from spacy.language import Language


 def test_issue4924():
     nlp = Language()
-    docs_golds = [("", {})]
-    nlp.evaluate(docs_golds)
+    example = Example.from_dict(nlp.make_doc(""), {})
+    nlp.evaluate([example])

@@ -589,7 +589,7 @@ def test_tuple_format_implicit():
         ("Google rebrands its business apps", {"entities": [(0, 6, "ORG")]}),
     ]

-    _train(train_data)
+    _train_tuples(train_data)


 def test_tuple_format_implicit_invalid():
@@ -605,20 +605,24 @@ def test_tuple_format_implicit_invalid():
     ]

     with pytest.raises(KeyError):
-        _train(train_data)
+        _train_tuples(train_data)


-def _train(train_data):
+def _train_tuples(train_data):
     nlp = English()
     ner = nlp.create_pipe("ner")
     ner.add_label("ORG")
     ner.add_label("LOC")
     nlp.add_pipe(ner)

+    train_examples = []
+    for t in train_data:
+        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
+
     optimizer = nlp.begin_training()
     for i in range(5):
         losses = {}
-        batches = minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
+        batches = minibatch(train_examples, size=compounding(4.0, 32.0, 1.001))
         for batch in batches:
             nlp.update(batch, sgd=optimizer, losses=losses)

@@ -5,6 +5,7 @@ from spacy.tokens import Doc, Span
 from spacy.vocab import Vocab

 from .util import add_vecs_to_vocab, assert_docs_equal
+from ..gold import Example


 @pytest.fixture
@@ -23,26 +24,45 @@ def test_language_update(nlp):
     annots = {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}
     wrongkeyannots = {"LABEL": True}
     doc = Doc(nlp.vocab, words=text.split(" "))
-    # Update with text and dict
-    nlp.update((text, annots))
+    example = Example.from_dict(doc, annots)
+    nlp.update([example])

+    # Not allowed to call with just one Example
+    with pytest.raises(TypeError):
+        nlp.update(example)
+
+    # Update with text and dict: not supported anymore since v.3
+    with pytest.raises(TypeError):
+        nlp.update((text, annots))
     # Update with doc object and dict
-    nlp.update((doc, annots))
-    # Update badly
+    with pytest.raises(TypeError):
+        nlp.update((doc, annots))

+    # Create examples badly
     with pytest.raises(ValueError):
-        nlp.update((doc, None))
+        example = Example.from_dict(doc, None)
     with pytest.raises(KeyError):
-        nlp.update((text, wrongkeyannots))
+        example = Example.from_dict(doc, wrongkeyannots)


 def test_language_evaluate(nlp):
     text = "hello world"
     annots = {"doc_annotation": {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}}
     doc = Doc(nlp.vocab, words=text.split(" "))
-    # Evaluate with text and dict
-    nlp.evaluate([(text, annots)])
+    example = Example.from_dict(doc, annots)
+    nlp.evaluate([example])

+    # Not allowed to call with just one Example
+    with pytest.raises(TypeError):
+        nlp.evaluate(example)
+
+    # Evaluate with text and dict: not supported anymore since v.3
+    with pytest.raises(TypeError):
+        nlp.evaluate([(text, annots)])
     # Evaluate with doc object and dict
-    nlp.evaluate([(doc, annots)])
-    with pytest.raises(Exception):
+    with pytest.raises(TypeError):
+        nlp.evaluate([(doc, annots)])
+    with pytest.raises(TypeError):
         nlp.evaluate([text, annots])


@@ -56,8 +76,9 @@ def test_evaluate_no_pipe(nlp):
     text = "hello world"
     annots = {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}
     nlp = Language(Vocab())
+    doc = nlp(text)
     nlp.add_pipe(pipe)
-    nlp.evaluate([(text, annots)])
+    nlp.evaluate([Example.from_dict(doc, annots)])


 def vector_modification_pipe(doc):