mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-06 06:30:35 +03:00
various small fixes
This commit is contained in:
parent
1c71f2310c
commit
0b6d45eae1
|
@ -4,8 +4,10 @@ import random
|
|||
import warnings
|
||||
import srsly
|
||||
import spacy
|
||||
from spacy.gold import Example
|
||||
from spacy.util import minibatch, compounding
|
||||
|
||||
# TODO: further fix & test this script for v.3 ? (read_gold_data is never called)
|
||||
|
||||
LABEL = "ANIMAL"
|
||||
TRAIN_DATA = [
|
||||
|
@ -35,15 +37,13 @@ def read_raw_data(nlp, jsonl_loc):
|
|||
|
||||
|
||||
def read_gold_data(nlp, gold_loc):
|
||||
docs = []
|
||||
golds = []
|
||||
examples = []
|
||||
for json_obj in srsly.read_jsonl(gold_loc):
|
||||
doc = nlp.make_doc(json_obj["text"])
|
||||
ents = [(ent["start"], ent["end"], ent["label"]) for ent in json_obj["spans"]]
|
||||
gold = GoldParse(doc, entities=ents)
|
||||
docs.append(doc)
|
||||
golds.append(gold)
|
||||
return list(zip(docs, golds))
|
||||
example = Example.from_dict(doc, {"entities": ents})
|
||||
examples.append(example)
|
||||
return examples
|
||||
|
||||
|
||||
def main(model_name, unlabelled_loc):
|
||||
|
|
|
@ -62,11 +62,10 @@ def main(config_path, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=Non
|
|||
train_examples = []
|
||||
for text, cats in zip(train_texts, train_cats):
|
||||
doc = nlp.make_doc(text)
|
||||
gold = GoldParse(doc, cats=cats)
|
||||
example = Example.from_dict(doc, {"cats": cats})
|
||||
for cat in cats:
|
||||
textcat.add_label(cat)
|
||||
ex = Example.from_gold(gold, doc=doc)
|
||||
train_examples.append(ex)
|
||||
train_examples.append(example)
|
||||
|
||||
with nlp.select_pipes(enable="textcat"): # only train textcat
|
||||
optimizer = nlp.begin_training()
|
||||
|
|
|
@ -467,7 +467,7 @@ def train_while_improving(
|
|||
|
||||
Every iteration, the function yields out a tuple with:
|
||||
|
||||
* batch: A zipped sequence of Tuple[Doc, GoldParse] pairs.
|
||||
* batch: A list of Example objects.
|
||||
* info: A dict with various information about the last update (see below).
|
||||
* is_best_checkpoint: A value in None, False, True, indicating whether this
|
||||
was the best evaluation so far. You should use this to save the model
|
||||
|
|
|
@ -11,6 +11,7 @@ from spacy.util import fix_random_seed
|
|||
|
||||
from ..util import make_tempdir
|
||||
from spacy.pipeline.defaults import default_tok2vec
|
||||
from ...gold import Example
|
||||
|
||||
TRAIN_DATA = [
|
||||
("I'm so happy.", {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}),
|
||||
|
@ -50,21 +51,20 @@ def test_textcat_learns_multilabel():
|
|||
cats = {letter: float(w2 == letter) for letter in letters}
|
||||
docs.append((Doc(nlp.vocab, words=["d"] * 3 + [w1, w2] + ["d"] * 3), cats))
|
||||
random.shuffle(docs)
|
||||
model = TextCategorizer(nlp.vocab, width=8)
|
||||
textcat = TextCategorizer(nlp.vocab, width=8)
|
||||
for letter in letters:
|
||||
model.add_label(letter)
|
||||
optimizer = model.begin_training()
|
||||
textcat.add_label(letter)
|
||||
optimizer = textcat.begin_training()
|
||||
for i in range(30):
|
||||
losses = {}
|
||||
Ys = [GoldParse(doc, cats=cats) for doc, cats in docs]
|
||||
Xs = [doc for doc, cats in docs]
|
||||
model.update(Xs, Ys, sgd=optimizer, losses=losses)
|
||||
examples = [Example.from_dict(doc, {"cats": cats}) for doc, cat in docs]
|
||||
textcat.update(examples, sgd=optimizer, losses=losses)
|
||||
random.shuffle(docs)
|
||||
for w1 in letters:
|
||||
for w2 in letters:
|
||||
doc = Doc(nlp.vocab, words=["d"] * 3 + [w1, w2] + ["d"] * 3)
|
||||
truth = {letter: w2 == letter for letter in letters}
|
||||
model(doc)
|
||||
textcat(doc)
|
||||
for cat, score in doc.cats.items():
|
||||
if not truth[cat]:
|
||||
assert score < 0.5
|
||||
|
|
Loading…
Reference in New Issue
Block a user