various small fixes

This commit is contained in:
svlandeg 2020-06-18 15:55:00 +02:00
parent 1c71f2310c
commit 0b6d45eae1
4 changed files with 16 additions and 17 deletions

View File

@ -4,8 +4,10 @@ import random
import warnings import warnings
import srsly import srsly
import spacy import spacy
from spacy.gold import Example
from spacy.util import minibatch, compounding from spacy.util import minibatch, compounding
# TODO: further fix & test this script for v.3 ? (read_gold_data is never called)
LABEL = "ANIMAL" LABEL = "ANIMAL"
TRAIN_DATA = [ TRAIN_DATA = [
@ -35,15 +37,13 @@ def read_raw_data(nlp, jsonl_loc):
def read_gold_data(nlp, gold_loc): def read_gold_data(nlp, gold_loc):
docs = [] examples = []
golds = []
for json_obj in srsly.read_jsonl(gold_loc): for json_obj in srsly.read_jsonl(gold_loc):
doc = nlp.make_doc(json_obj["text"]) doc = nlp.make_doc(json_obj["text"])
ents = [(ent["start"], ent["end"], ent["label"]) for ent in json_obj["spans"]] ents = [(ent["start"], ent["end"], ent["label"]) for ent in json_obj["spans"]]
gold = GoldParse(doc, entities=ents) example = Example.from_dict(doc, {"entities": ents})
docs.append(doc) examples.append(example)
golds.append(gold) return examples
return list(zip(docs, golds))
def main(model_name, unlabelled_loc): def main(model_name, unlabelled_loc):

View File

@ -62,11 +62,10 @@ def main(config_path, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=Non
train_examples = [] train_examples = []
for text, cats in zip(train_texts, train_cats): for text, cats in zip(train_texts, train_cats):
doc = nlp.make_doc(text) doc = nlp.make_doc(text)
gold = GoldParse(doc, cats=cats) example = Example.from_dict(doc, {"cats": cats})
for cat in cats: for cat in cats:
textcat.add_label(cat) textcat.add_label(cat)
ex = Example.from_gold(gold, doc=doc) train_examples.append(example)
train_examples.append(ex)
with nlp.select_pipes(enable="textcat"): # only train textcat with nlp.select_pipes(enable="textcat"): # only train textcat
optimizer = nlp.begin_training() optimizer = nlp.begin_training()

View File

@ -467,7 +467,7 @@ def train_while_improving(
Every iteration, the function yields out a tuple with: Every iteration, the function yields out a tuple with:
* batch: A zipped sequence of Tuple[Doc, GoldParse] pairs. * batch: A list of Example objects.
* info: A dict with various information about the last update (see below). * info: A dict with various information about the last update (see below).
* is_best_checkpoint: A value in None, False, True, indicating whether this * is_best_checkpoint: A value in None, False, True, indicating whether this
was the best evaluation so far. You should use this to save the model was the best evaluation so far. You should use this to save the model

View File

@ -11,6 +11,7 @@ from spacy.util import fix_random_seed
from ..util import make_tempdir from ..util import make_tempdir
from spacy.pipeline.defaults import default_tok2vec from spacy.pipeline.defaults import default_tok2vec
from ...gold import Example
TRAIN_DATA = [ TRAIN_DATA = [
("I'm so happy.", {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}), ("I'm so happy.", {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}),
@ -50,21 +51,20 @@ def test_textcat_learns_multilabel():
cats = {letter: float(w2 == letter) for letter in letters} cats = {letter: float(w2 == letter) for letter in letters}
docs.append((Doc(nlp.vocab, words=["d"] * 3 + [w1, w2] + ["d"] * 3), cats)) docs.append((Doc(nlp.vocab, words=["d"] * 3 + [w1, w2] + ["d"] * 3), cats))
random.shuffle(docs) random.shuffle(docs)
model = TextCategorizer(nlp.vocab, width=8) textcat = TextCategorizer(nlp.vocab, width=8)
for letter in letters: for letter in letters:
model.add_label(letter) textcat.add_label(letter)
optimizer = model.begin_training() optimizer = textcat.begin_training()
for i in range(30): for i in range(30):
losses = {} losses = {}
Ys = [GoldParse(doc, cats=cats) for doc, cats in docs] examples = [Example.from_dict(doc, {"cats": cats}) for doc, cat in docs]
Xs = [doc for doc, cats in docs] textcat.update(examples, sgd=optimizer, losses=losses)
model.update(Xs, Ys, sgd=optimizer, losses=losses)
random.shuffle(docs) random.shuffle(docs)
for w1 in letters: for w1 in letters:
for w2 in letters: for w2 in letters:
doc = Doc(nlp.vocab, words=["d"] * 3 + [w1, w2] + ["d"] * 3) doc = Doc(nlp.vocab, words=["d"] * 3 + [w1, w2] + ["d"] * 3)
truth = {letter: w2 == letter for letter in letters} truth = {letter: w2 == letter for letter in letters}
model(doc) textcat(doc)
for cat, score in doc.cats.items(): for cat, score in doc.cats.items():
if not truth[cat]: if not truth[cat]:
assert score < 0.5 assert score < 0.5