various small fixes

This commit is contained in:
svlandeg 2020-06-18 15:55:00 +02:00
parent 1c71f2310c
commit 0b6d45eae1
4 changed files with 16 additions and 17 deletions

View File

@ -4,8 +4,10 @@ import random
import warnings
import srsly
import spacy
from spacy.gold import Example
from spacy.util import minibatch, compounding
# TODO: further fix & test this script for v.3 ? (read_gold_data is never called)
LABEL = "ANIMAL"
TRAIN_DATA = [
@ -35,15 +37,13 @@ def read_raw_data(nlp, jsonl_loc):
def read_gold_data(nlp, gold_loc):
    """Load gold-standard entity annotations from a JSONL file as Examples.

    Each JSONL record is read with ``srsly.read_jsonl`` and must provide a
    ``"text"`` string and a ``"spans"`` list whose items carry ``"start"``,
    ``"end"`` and ``"label"`` keys.

    nlp: The pipeline object whose ``make_doc`` is used to create each Doc.
    gold_loc: Location of the JSONL file with the annotations.
    RETURNS (list): One ``Example`` per record, annotated with its entities.
    """
    examples = []
    for json_obj in srsly.read_jsonl(gold_loc):
        doc = nlp.make_doc(json_obj["text"])
        # Entity annotations are passed as (start, end, label) triples.
        ents = [(ent["start"], ent["end"], ent["label"]) for ent in json_obj["spans"]]
        # Example.from_dict replaces the former GoldParse(doc, entities=...)
        # construction; the reference annotations go in a plain dict.
        example = Example.from_dict(doc, {"entities": ents})
        examples.append(example)
    return examples
def main(model_name, unlabelled_loc):

View File

@ -62,11 +62,10 @@ def main(config_path, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=Non
train_examples = []
for text, cats in zip(train_texts, train_cats):
doc = nlp.make_doc(text)
gold = GoldParse(doc, cats=cats)
example = Example.from_dict(doc, {"cats": cats})
for cat in cats:
textcat.add_label(cat)
ex = Example.from_gold(gold, doc=doc)
train_examples.append(ex)
train_examples.append(example)
with nlp.select_pipes(enable="textcat"): # only train textcat
optimizer = nlp.begin_training()

View File

@ -467,7 +467,7 @@ def train_while_improving(
Every iteration, the function yields out a tuple with:
* batch: A zipped sequence of Tuple[Doc, GoldParse] pairs.
* batch: A list of Example objects.
* info: A dict with various information about the last update (see below).
* is_best_checkpoint: A value in None, False, True, indicating whether this
was the best evaluation so far. You should use this to save the model

View File

@ -11,6 +11,7 @@ from spacy.util import fix_random_seed
from ..util import make_tempdir
from spacy.pipeline.defaults import default_tok2vec
from ...gold import Example
TRAIN_DATA = [
("I'm so happy.", {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}),
@ -50,21 +51,20 @@ def test_textcat_learns_multilabel():
cats = {letter: float(w2 == letter) for letter in letters}
docs.append((Doc(nlp.vocab, words=["d"] * 3 + [w1, w2] + ["d"] * 3), cats))
random.shuffle(docs)
model = TextCategorizer(nlp.vocab, width=8)
textcat = TextCategorizer(nlp.vocab, width=8)
for letter in letters:
model.add_label(letter)
optimizer = model.begin_training()
textcat.add_label(letter)
optimizer = textcat.begin_training()
for i in range(30):
    losses = {}
    # Unpack each pair as (doc, cats) so every Example is built from its own
    # gold categories. Binding the second item to any other name would make
    # the dict below silently pick up a stale `cats` from the enclosing scope
    # and train every document on the same labels.
    examples = [Example.from_dict(doc, {"cats": cats}) for doc, cats in docs]
    textcat.update(examples, sgd=optimizer, losses=losses)
    # Reshuffle so the next iteration sees the documents in a new order.
    random.shuffle(docs)
for w1 in letters:
for w2 in letters:
doc = Doc(nlp.vocab, words=["d"] * 3 + [w1, w2] + ["d"] * 3)
truth = {letter: w2 == letter for letter in letters}
model(doc)
textcat(doc)
for cat, score in doc.cats.items():
if not truth[cat]:
assert score < 0.5