From b71a11ff6dd7b47582fbffd45121c05ff3b89977 Mon Sep 17 00:00:00 2001
From: adrianeboyd
Date: Thu, 2 Apr 2020 14:46:32 +0200
Subject: [PATCH] Update morphologizer (#5108)

* Add pos and morph scoring to Scorer

Add pos, morph, and morph_per_type to `Scorer`. Report pos and morph
accuracy in `spacy evaluate`.

* Update morphologizer for v3

* switch to tagger-based morphologizer

* use `spacy.HashCharEmbedCNN` for morphologizer defaults

* add `Doc.is_morphed` flag

* Add morphologizer to train CLI

* Add basic morphologizer pipeline tests

* Add simple morphologizer training example

* Remove subword_features from CharEmbed models

Remove `subword_features` argument from `spacy.HashCharEmbedCNN.v1` and
`spacy.HashCharEmbedBiLSTM.v1` since in these cases `subword_features`
is always `False`.

* Rename setting in morphologizer example

Use `with_pos_tags` instead of `without_pos_tags`.

* Fix kwargs for spacy.HashCharEmbedBiLSTM.v1

* Remove defaults for spacy.HashCharEmbedBiLSTM.v1

Remove default `nM/nC` for `spacy.HashCharEmbedBiLSTM.v1`.

* Set random seed for textcat overfitting test
---
 examples/training/train_morphologizer.py   | 133 ++++++++++
 spacy/cli/evaluate.py                      |   4 +-
 spacy/cli/train.py                         |  14 +-
 .../defaults/morphologizer_defaults.cfg    |   1 -
 spacy/ml/models/tok2vec.py                 |   7 +-
 spacy/pipeline/morphologizer.pyx           | 237 +++++++++----------
 spacy/scorer.py                            |  63 ++++-
 spacy/tests/pipeline/test_morphologizer.py |  49 ++++
 spacy/tests/pipeline/test_textcat.py       |   2 +
 spacy/tests/test_scorer.py                 |  75 ++++++
 spacy/tokens/doc.pxd                       |   1 +
 11 files changed, 458 insertions(+), 128 deletions(-)
 create mode 100644 examples/training/train_morphologizer.py
 create mode 100644 spacy/tests/pipeline/test_morphologizer.py

diff --git a/examples/training/train_morphologizer.py b/examples/training/train_morphologizer.py
new file mode 100644
index 000000000..aec114de7
--- /dev/null
+++ b/examples/training/train_morphologizer.py
@@ -0,0 +1,133 @@
+#!/usr/bin/env python
+# coding: utf8
+"""
+A simple example for training a morphologizer. For more details, see
+the documentation:
+* Training: https://spacy.io/usage/training
+
+Compatible with: spaCy v3.0.0+
+Last tested with: v3.0.0
+"""
+from __future__ import unicode_literals, print_function
+
+import plac
+import random
+from pathlib import Path
+import spacy
+from spacy.util import minibatch, compounding
+from spacy.morphology import Morphology
+
+
+# Usually you'll read this in, of course. Data formats vary. Ensure your
+# strings are unicode and that the number of tags assigned matches spaCy's
+# tokenization. If not, you can always add a 'words' key to the annotations
+# that specifies the gold-standard tokenization, e.g.:
+# ("Eatblueham", {'words': ['Eat', 'blue', 'ham'], 'tags': ['V', 'J', 'N']})
+TRAIN_DATA = [
+    (
+        "I like green eggs",
+        {
+            "morphs": [
+                "PronType=Prs|Person=1",
+                "VerbForm=Fin",
+                "Degree=Pos",
+                "Number=Plur",
+            ],
+            "pos": ["PRON", "VERB", "ADJ", "NOUN"],
+        },
+    ),
+    (
+        "Eat blue ham",
+        {
+            "morphs": ["VerbForm=Inf", "Degree=Pos", "Number=Sing"],
+            "pos": ["VERB", "ADJ", "NOUN"],
+        },
+    ),
+    (
+        "She was blue",
+        {
+            "morphs": ["PronType=Prs|Person=3", "VerbForm=Fin", "Degree=Pos"],
+            "pos": ["PRON", "VERB", "ADJ"],
+        },
+    ),
+    (
+        "He was blue today",
+        {
+            "morphs": ["PronType=Prs|Person=3", "VerbForm=Fin", "Degree=Pos", ""],
+            "pos": ["PRON", "VERB", "ADJ", "ADV"],
+        },
+    ),
+]
+
+# The POS tags are optional, set `with_pos_tags = False` to omit them for
+# this example:
+with_pos_tags = True
+
+if not with_pos_tags:
+    for i in range(len(TRAIN_DATA)):
+        del TRAIN_DATA[i][1]["pos"]
+
+
+@plac.annotations(
+    lang=("ISO Code of language to use", "option", "l", str),
+    output_dir=("Optional output directory", "option", "o", Path),
+    n_iter=("Number of training iterations", "option", "n", int),
+)
+def main(lang="en", output_dir=None, n_iter=25):
+    """Create a new model, set up the pipeline and train the morphologizer.
+    The morphologizer predicts morphological features and coarse-grained
+    POS tags, i.e. `Token.morph` and `Token.pos`.
+    """
+    nlp = spacy.blank(lang)
+    # add the morphologizer to the pipeline
+    # nlp.create_pipe works for built-ins that are registered with spaCy
+    morphologizer = nlp.create_pipe("morphologizer")
+    nlp.add_pipe(morphologizer)
+
+    # add labels
+    for _, annotations in TRAIN_DATA:
+        morph_labels = annotations.get("morphs")
+        pos_labels = annotations.get("pos", [""] * len(annotations.get("morphs")))
+        assert len(morph_labels) == len(pos_labels)
+        for morph, pos in zip(morph_labels, pos_labels):
+            morph_dict = Morphology.feats_to_dict(morph)
+            if pos:
+                morph_dict["POS"] = pos
+            morph = Morphology.dict_to_feats(morph_dict)
+            morphologizer.add_label(morph)
+
+    optimizer = nlp.begin_training()
+    for i in range(n_iter):
+        random.shuffle(TRAIN_DATA)
+        losses = {}
+        # batch up the examples using spaCy's minibatch
+        batches = minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
+        for batch in batches:
+            nlp.update(batch, sgd=optimizer, losses=losses)
+        print("Losses", losses)
+
+    # test the trained model
+    test_text = "I like blue eggs"
+    doc = nlp(test_text)
+    print("Morphs", [(t.text, t.morph) for t in doc])
+
+    # save model to output directory
+    if output_dir is not None:
+        output_dir = Path(output_dir)
+        if not output_dir.exists():
+            output_dir.mkdir()
+        nlp.to_disk(output_dir)
+        print("Saved model to", output_dir)
+
+        # test the saved model
+        print("Loading from", output_dir)
+        nlp2 = spacy.load(output_dir)
+        doc = nlp2(test_text)
+        print("Morphs", [(t.text, t.morph) for t in doc])
+
+
+if __name__ == "__main__":
+    plac.call(main)
+
+# Expected output:
+# Morphs [('I', POS=PRON|Person=1|PronType=Prs), ('like', POS=VERB|VerbForm=Fin), ('blue', Degree=Pos|POS=ADJ), ('eggs', Number=Plur|POS=NOUN)]
diff --git a/spacy/cli/evaluate.py b/spacy/cli/evaluate.py
index e047f1283..94813e732 100644
--- a/spacy/cli/evaluate.py
+++ b/spacy/cli/evaluate.py
@@ -43,7 +43,9 @@ def evaluate(
         "Words": nwords,
         "Words/s": f"{nwords / (end - begin):.0f}",
         "TOK": f"{scorer.token_acc:.2f}",
-        "POS": f"{scorer.tags_acc:.2f}",
+        "TAG": f"{scorer.tags_acc:.2f}",
f"{scorer.pos_acc:.2f}", + "MORPH": f"{scorer.morphs_acc:.2f}", "UAS": f"{scorer.uas:.2f}", "LAS": f"{scorer.las:.2f}", "NER P": f"{scorer.ents_p:.2f}", diff --git a/spacy/cli/train.py b/spacy/cli/train.py index 56020e4ff..5fa09da78 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -221,6 +221,8 @@ def train( config_loc = default_dir / "parser_defaults.cfg" elif pipe == "tagger": config_loc = default_dir / "tagger_defaults.cfg" + elif pipe == "morphologizer": + config_loc = default_dir / "morphologizer_defaults.cfg" elif pipe == "ner": config_loc = default_dir / "ner_defaults.cfg" elif pipe == "textcat": @@ -590,6 +592,8 @@ def _score_for_model(meta): acc = meta["accuracy"] if "tagger" in pipes: mean_acc.append(acc["tags_acc"]) + if "morphologizer" in pipes: + mean_acc.append((acc["morphs_acc"] + acc["pos_acc"]) / 2) if "parser" in pipes: mean_acc.append((acc["uas"] + acc["las"]) / 2) if "ner" in pipes: @@ -672,13 +676,15 @@ def _find_best(experiment_dir, component): def _get_metrics(component): if component == "parser": - return ("las", "uas", "las_per_type", "token_acc", "sent_f") + return ("las", "uas", "las_per_type", "sent_f", "token_acc") elif component == "tagger": return ("tags_acc", "token_acc") + elif component == "morphologizer": + return ("morphs_acc", "pos_acc", "token_acc") elif component == "ner": return ("ents_f", "ents_p", "ents_r", "ents_per_type", "token_acc") elif component == "senter": - return ("sent_f", "sent_p", "sent_r") + return ("sent_f", "sent_p", "sent_r", "token_acc") elif component == "textcat": return ("textcat_score", "token_acc") return ("token_acc",) @@ -691,6 +697,9 @@ def _configure_training_output(pipeline, use_gpu, has_beam_widths): if pipe == "tagger": row_head.extend(["Tag Loss ", " Tag % "]) output_stats.extend(["tag_loss", "tags_acc"]) + elif pipe == "morphologizer" or pipe == "morphologizertagger": + row_head.extend(["Morph Loss ", " Morph % ", " POS % "]) + output_stats.extend(["morph_loss", "morphs_acc", "pos_acc"]) elif pipe == "parser": row_head.extend( ["Dep Loss ", " UAS ", " LAS ", "Sent P", "Sent R", "Sent F"] @@ -731,6 +740,7 @@ def _get_progress( scores["dep_loss"] = losses.get("parser", 0.0) scores["ner_loss"] = losses.get("ner", 0.0) scores["tag_loss"] = losses.get("tagger", 0.0) + scores["morph_loss"] = losses.get("morphologizer", 0.0) scores["textcat_loss"] = losses.get("textcat", 0.0) scores["senter_loss"] = losses.get("senter", 0.0) scores["cpu_wps"] = cpu_wps diff --git a/spacy/ml/models/defaults/morphologizer_defaults.cfg b/spacy/ml/models/defaults/morphologizer_defaults.cfg index 80e776c4f..150eca507 100644 --- a/spacy/ml/models/defaults/morphologizer_defaults.cfg +++ b/spacy/ml/models/defaults/morphologizer_defaults.cfg @@ -9,6 +9,5 @@ depth = 4 embed_size = 7000 window_size = 1 maxout_pieces = 3 -subword_features = true nM = 64 nC = 8 diff --git a/spacy/ml/models/tok2vec.py b/spacy/ml/models/tok2vec.py index 81820e56b..a2e8f589a 100644 --- a/spacy/ml/models/tok2vec.py +++ b/spacy/ml/models/tok2vec.py @@ -74,7 +74,6 @@ def hash_charembed_cnn( embed_size, maxout_pieces, window_size, - subword_features, nM, nC, ): @@ -87,7 +86,7 @@ def hash_charembed_cnn( bilstm_depth=0, maxout_pieces=maxout_pieces, window_size=window_size, - subword_features=subword_features, + subword_features=False, char_embed=True, nM=nM, nC=nC, @@ -116,7 +115,7 @@ def hash_embed_bilstm_v1( @registry.architectures.register("spacy.HashCharEmbedBiLSTM.v1") def hash_char_embed_bilstm_v1( - pretrained_vectors, width, depth, embed_size, 
+    pretrained_vectors, width, depth, embed_size, maxout_pieces, nM, nC
 ):
     # Allows using character embeddings by setting nC, nM and char_embed=True
     return build_Tok2Vec_model(
@@ -127,7 +126,7 @@ def hash_char_embed_bilstm_v1(
         conv_depth=0,
         maxout_pieces=maxout_pieces,
         window_size=1,
-        subword_features=subword_features,
+        subword_features=False,
         char_embed=True,
         nM=nM,
         nC=nC,
diff --git a/spacy/pipeline/morphologizer.pyx b/spacy/pipeline/morphologizer.pyx
index be9b166bf..7a2bc3b17 100644
--- a/spacy/pipeline/morphologizer.pyx
+++ b/spacy/pipeline/morphologizer.pyx
@@ -1,166 +1,169 @@
+# cython: infer_types=True, profile=True
 cimport numpy as np
 
 import numpy
-from collections import defaultdict
-from thinc.api import chain, list2array, to_categorical, get_array_module
-from thinc.util import copy_array
+import srsly
+from thinc.api import to_categorical
 
 from ..tokens.doc cimport Doc
 from ..vocab cimport Vocab
 from ..morphology cimport Morphology
+from ..parts_of_speech import IDS as POS_IDS
+from ..symbols import POS
 from .. import util
 from ..language import component
 from ..util import link_vectors_to_models, create_default_optimizer
 from ..errors import Errors, TempErrors
-from .pipes import Pipe
+from .pipes import Tagger, _load_cfg
+from .. import util
 
 
 @component("morphologizer", assigns=["token.morph", "token.pos"])
-class Morphologizer(Pipe):
+class Morphologizer(Tagger):
 
     def __init__(self, vocab, model, **cfg):
         self.vocab = vocab
         self.model = model
+        self._rehearsal_model = None
         self.cfg = dict(sorted(cfg.items()))
-        self._class_map = self.vocab.morphology.create_class_map()  # Morphology.create_class_map() ?
+        self.cfg.setdefault("labels", {})
+        self.cfg.setdefault("morph_pos", {})
 
     @property
     def labels(self):
-        return self.vocab.morphology.tag_names
+        return tuple(self.cfg["labels"].keys())
 
-    @property
-    def tok2vec(self):
-        if self.model in (None, True, False):
-            return None
-        else:
-            return chain(self.model.get_ref("tok2vec"), list2array())
-
-    def __call__(self, doc):
-        features, tokvecs = self.predict([doc])
-        self.set_annotations([doc], features, tensors=tokvecs)
-        return doc
-
-    def pipe(self, stream, batch_size=128, n_threads=-1):
-        for docs in util.minibatch(stream, size=batch_size):
-            docs = list(docs)
-            features, tokvecs = self.predict(docs)
-            self.set_annotations(docs, features, tensors=tokvecs)
-            yield from docs
+    def add_label(self, label):
+        if not isinstance(label, str):
+            raise ValueError(Errors.E187)
+        if label in self.labels:
+            return 0
+        morph = Morphology.feats_to_dict(label)
+        norm_morph_pos = self.vocab.strings[self.vocab.morphology.add(morph)]
+        pos = morph.get("POS", "")
+        if norm_morph_pos not in self.cfg["labels"]:
+            self.cfg["labels"][norm_morph_pos] = norm_morph_pos
+            self.cfg["morph_pos"][norm_morph_pos] = POS_IDS[pos]
+        return 1
 
     def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None,
                        **kwargs):
+        for example in get_examples():
+            for i, morph in enumerate(example.token_annotation.morphs):
+                pos = example.token_annotation.get_pos(i)
+                morph = Morphology.feats_to_dict(morph)
+                norm_morph = self.vocab.strings[self.vocab.morphology.add(morph)]
+                if pos:
+                    morph["POS"] = pos
+                norm_morph_pos = self.vocab.strings[self.vocab.morphology.add(morph)]
+                if norm_morph_pos not in self.cfg["labels"]:
+                    self.cfg["labels"][norm_morph_pos] = norm_morph
+                    self.cfg["morph_pos"][norm_morph_pos] = POS_IDS[pos]
         self.set_output(len(self.labels))
         self.model.initialize()
+        link_vectors_to_models(self.vocab)
         if sgd is None:
             sgd = self.create_optimizer()
         return sgd
 
-    def predict(self, docs):
-        if not any(len(doc) for doc in docs):
-            # Handle case where there are no tokens in any docs.
-            n_labels = self.model.get_dim("nO")
-            guesses = [self.model.ops.alloc((0, n_labels)) for doc in docs]
-            tokvecs = self.model.ops.alloc((0, self.model.get_ref("tok2vec").get_dim("nO")))
-            return guesses, tokvecs
-        tokvecs = self.model.get_ref("tok2vec")(docs)
-        scores = self.model.get_ref("softmax")(tokvecs)
-        return scores, tokvecs
-
-    def set_annotations(self, docs, batch_scores, tensors=None):
+    def set_annotations(self, docs, batch_tag_ids):
         if isinstance(docs, Doc):
             docs = [docs]
         cdef Doc doc
         cdef Vocab vocab = self.vocab
-        offsets = [self._class_map.get_field_offset(field)
-                   for field in self._class_map.fields]
         for i, doc in enumerate(docs):
-            doc_scores = batch_scores[i]
-            doc_guesses = scores_to_guesses(doc_scores, self.model.get_ref("softmax").attrs["nOs"])
-            # Convert the neuron indices into feature IDs.
-            doc_feat_ids = numpy.zeros((len(doc), len(self._class_map.fields)), dtype='i')
-            for j in range(len(doc)):
-                for k, offset in enumerate(offsets):
-                    if doc_guesses[j, k] == 0:
-                        doc_feat_ids[j, k] = 0
-                    else:
-                        doc_feat_ids[j, k] = offset + doc_guesses[j, k]
-                # Get the set of feature names.
-                feats = {self._class_map.col2info[f][2] for f in doc_feat_ids[j]}
-                if "NIL" in feats:
-                    feats.remove("NIL")
-                # Now add the analysis, and set the hash.
-                doc.c[j].morph = self.vocab.morphology.add(feats)
-                if doc[j].morph.pos != 0:
-                    doc.c[j].pos = doc[j].morph.pos
+            doc_tag_ids = batch_tag_ids[i]
+            if hasattr(doc_tag_ids, "get"):
+                doc_tag_ids = doc_tag_ids.get()
+            for j, tag_id in enumerate(doc_tag_ids):
+                morph = self.labels[tag_id]
+                doc.c[j].morph = self.vocab.morphology.add(self.cfg["labels"][morph])
+                doc.c[j].pos = self.cfg["morph_pos"][morph]
 
-    def update(self, examples, drop=0., sgd=None, losses=None):
-        if losses is not None and self.name not in losses:
-            losses[self.name] = 0.
-
-        docs = [self._get_doc(ex) for ex in examples]
-        tag_scores, bp_tag_scores = self.model.begin_update(docs, drop=drop)
-        loss, d_tag_scores = self.get_loss(examples, tag_scores)
-        bp_tag_scores(d_tag_scores, sgd=sgd)
-
-        if losses is not None:
-            losses[self.name] += loss
+            doc.is_morphed = True
 
     def get_loss(self, examples, scores):
-        guesses = []
-        for doc_scores in scores:
-            guesses.append(scores_to_guesses(doc_scores, self.model.get_ref("softmax").attrs["nOs"]))
-        guesses = self.model.ops.xp.vstack(guesses)
-        scores = self.model.ops.xp.vstack(scores)
-        if not isinstance(scores, numpy.ndarray):
-            scores = scores.get()
-        if not isinstance(guesses, numpy.ndarray):
-            guesses = guesses.get()
+        scores = self.model.ops.flatten(scores)
+        tag_index = {tag: i for i, tag in enumerate(self.labels)}
         cdef int idx = 0
-        # Do this on CPU, as we can't vectorize easily.
-        target = numpy.zeros(scores.shape, dtype='f')
-        field_sizes = self.model.get_ref("softmax").attrs["nOs"]
-        for example in examples:
-            doc = example.doc
-            gold = example.gold
-            for t, features in enumerate(gold.morphology):
-                if features is None:
-                    target[idx] = scores[idx]
+        correct = numpy.zeros((scores.shape[0],), dtype="i")
+        guesses = scores.argmax(axis=1)
+        known_labels = numpy.ones((scores.shape[0], 1), dtype="f")
+        for ex in examples:
+            gold = ex.gold
+            for i in range(len(gold.morphs)):
+                pos = gold.pos[i] if i < len(gold.pos) else ""
+                morph = gold.morphs[i]
+                feats = Morphology.feats_to_dict(morph)
+                if pos:
+                    feats["POS"] = pos
+                if len(feats) > 0:
+                    morph = self.vocab.strings[self.vocab.morphology.add(feats)]
+                if morph == "":
+                    morph = Morphology.EMPTY_MORPH
+                if morph is None:
+                    correct[idx] = guesses[idx]
+                elif morph in tag_index:
+                    correct[idx] = tag_index[morph]
                 else:
-                    gold_fields = {}
-                    for feature in features:
-                        field = self._class_map.feat2field[feature]
-                        gold_fields[field] = self._class_map.feat2offset[feature]
-                    for field in self._class_map.fields:
-                        field_id = self._class_map.field2id[field]
-                        col_offset = self._class_map.field2col[field]
-                        if field_id in gold_fields:
-                            target[idx, col_offset + gold_fields[field_id]] = 1.
-                        else:
-                            target[idx, col_offset] = 1.
-                    #print(doc[t])
-                    #for col, info in enumerate(self._class_map.col2info):
-                    #    print(col, info, scores[idx, col], target[idx, col])
+                    correct[idx] = 0
+                    known_labels[idx] = 0.
                 idx += 1
-        target = self.model.ops.asarray(target, dtype='f')
-        scores = self.model.ops.asarray(scores, dtype='f')
-        d_scores = scores - target
+        correct = self.model.ops.xp.array(correct, dtype="i")
+        d_scores = scores - to_categorical(correct, n_classes=scores.shape[1])
+        d_scores *= self.model.ops.asarray(known_labels)
         loss = (d_scores**2).sum()
-        docs = [self._get_doc(ex) for ex in examples]
+        docs = [ex.doc for ex in examples]
         d_scores = self.model.ops.unflatten(d_scores, [len(d) for d in docs])
         return float(loss), d_scores
 
-    def use_params(self, params):
-        with self.model.use_params(params):
-            yield
+    def to_bytes(self, exclude=tuple(), **kwargs):
+        serialize = {}
+        serialize["model"] = self.model.to_bytes
+        serialize["vocab"] = self.vocab.to_bytes
+        serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
+        exclude = util.get_serialization_exclude(serialize, exclude, kwargs)
+        return util.to_bytes(serialize, exclude)
 
-def scores_to_guesses(scores, out_sizes):
-    xp = get_array_module(scores)
-    guesses = xp.zeros((scores.shape[0], len(out_sizes)), dtype='i')
-    offset = 0
-    for i, size in enumerate(out_sizes):
-        slice_ = scores[:, offset : offset + size]
-        col_guesses = slice_.argmax(axis=1)
-        guesses[:, i] = col_guesses
-        offset += size
-    return guesses
+    def from_bytes(self, bytes_data, exclude=tuple(), **kwargs):
+        def load_model(b):
+            try:
+                self.model.from_bytes(b)
+            except AttributeError:
+                raise ValueError(Errors.E149)
+
+        deserialize = {
+            "vocab": lambda b: self.vocab.from_bytes(b),
+            "cfg": lambda b: self.cfg.update(srsly.json_loads(b)),
+            "model": lambda b: load_model(b),
+        }
+        exclude = util.get_serialization_exclude(deserialize, exclude, kwargs)
+        util.from_bytes(bytes_data, deserialize, exclude)
+        return self
+
+    def to_disk(self, path, exclude=tuple(), **kwargs):
+        serialize = {
+            "vocab": lambda p: self.vocab.to_disk(p),
+            "model": lambda p: p.open("wb").write(self.model.to_bytes()),
+            "cfg": lambda p: srsly.write_json(p, self.cfg),
+        }
+        exclude = util.get_serialization_exclude(serialize, exclude, kwargs)
+        util.to_disk(path, serialize, exclude)
+
+    def from_disk(self, path, exclude=tuple(), **kwargs):
+        def load_model(p):
+            with p.open("rb") as file_:
+                try:
+                    self.model.from_bytes(file_.read())
+                except AttributeError:
+                    raise ValueError(Errors.E149)
+
+        deserialize = {
+            "vocab": lambda p: self.vocab.from_disk(p),
+            "cfg": lambda p: self.cfg.update(_load_cfg(p)),
+            "model": load_model,
+        }
+        exclude = util.get_serialization_exclude(deserialize, exclude, kwargs)
+        util.from_disk(path, deserialize, exclude)
+        return self
diff --git a/spacy/scorer.py b/spacy/scorer.py
index 82b10a77d..7e2466be7 100644
--- a/spacy/scorer.py
+++ b/spacy/scorer.py
@@ -81,6 +81,9 @@ class Scorer(object):
         self.labelled = PRFScore()
         self.labelled_per_dep = dict()
         self.tags = PRFScore()
+        self.pos = PRFScore()
+        self.morphs = PRFScore()
+        self.morphs_per_feat = dict()
         self.sent_starts = PRFScore()
         self.ner = PRFScore()
         self.ner_per_ents = dict()
@@ -111,6 +114,29 @@ class Scorer(object):
         """
         return self.tags.fscore * 100
 
+    @property
+    def pos_acc(self):
+        """RETURNS (float): Part-of-speech tag accuracy (coarse grained pos,
+        i.e. `Token.pos`).
+        """
+        return self.pos.fscore * 100
+
+    @property
+    def morphs_acc(self):
+        """RETURNS (float): Morph tag accuracy (morphological features,
+        i.e. `Token.morph`).
+        """
+        return self.morphs.fscore * 100
+
+    @property
+    def morphs_per_type(self):
+        """RETURNS (dict): Scores per morphological feature type.
+        """
+        return {
+            k: {"p": v.precision * 100, "r": v.recall * 100, "f": v.fscore * 100}
+            for k, v in self.morphs_per_feat.items()
+        }
+
     @property
     def sent_p(self):
         """RETURNS (float): F-score for identification of sentence starts.
@@ -231,6 +257,9 @@ class Scorer(object):
             "ents_f": self.ents_f,
             "ents_per_type": self.ents_per_type,
             "tags_acc": self.tags_acc,
+            "pos_acc": self.pos_acc,
+            "morphs_acc": self.morphs_acc,
+            "morphs_per_type": self.morphs_per_type,
             "sent_p": self.sent_p,
             "sent_r": self.sent_r,
             "sent_f": self.sent_f,
@@ -264,12 +293,23 @@ class Scorer(object):
         gold_deps = set()
         gold_deps_per_dep = {}
         gold_tags = set()
+        gold_pos = set()
+        gold_morphs = set()
+        gold_morphs_per_feat = {}
         gold_sent_starts = set()
         gold_ents = set(tags_to_entities(orig.entities))
-        for id_, tag, head, dep, sent_start in zip(
-            orig.ids, orig.tags, orig.heads, orig.deps, orig.sent_starts
-        ):
+        for id_, tag, pos, morph, head, dep, sent_start in zip(orig.ids, orig.tags, orig.pos, orig.morphs, orig.heads, orig.deps, orig.sent_starts):
             gold_tags.add((id_, tag))
+            gold_pos.add((id_, pos))
+            gold_morphs.add((id_, morph))
+            if morph:
+                for feat in morph.split("|"):
+                    field, values = feat.split("=")
+                    if field not in self.morphs_per_feat:
+                        self.morphs_per_feat[field] = PRFScore()
+                    if field not in gold_morphs_per_feat:
+                        gold_morphs_per_feat[field] = set()
+                    gold_morphs_per_feat[field].add((id_, feat))
             if sent_start:
                 gold_sent_starts.add(id_)
             if dep not in (None, "") and dep.lower() not in punct_labels:
@@ -282,6 +322,9 @@ class Scorer(object):
         cand_deps = set()
         cand_deps_per_dep = {}
         cand_tags = set()
+        cand_pos = set()
+        cand_morphs = set()
+        cand_morphs_per_feat = {}
         cand_sent_starts = set()
         for token in doc:
             if token.orth_.isspace():
@@ -292,6 +335,16 @@ class Scorer(object):
             else:
                 self.tokens.tp += 1
                 cand_tags.add((gold_i, token.tag_))
+                cand_pos.add((gold_i, token.pos_))
+                cand_morphs.add((gold_i, token.morph_))
+                if token.morph_:
+                    for feat in token.morph_.split("|"):
+                        field, values = feat.split("=")
+                        if field not in self.morphs_per_feat:
+                            self.morphs_per_feat[field] = PRFScore()
+                        if field not in cand_morphs_per_feat:
+                            cand_morphs_per_feat[field] = set()
+                        cand_morphs_per_feat[field].add((gold_i, feat))
             if token.is_sent_start:
                 cand_sent_starts.add(gold_i)
             if token.dep_.lower() not in punct_labels and token.orth_.strip():
@@ -340,6 +393,10 @@ class Scorer(object):
         # Score for all ents
         self.ner.score_set(cand_ents, gold_ents)
         self.tags.score_set(cand_tags, gold_tags)
+        self.pos.score_set(cand_pos, gold_pos)
+        self.morphs.score_set(cand_morphs, gold_morphs)
+        for field in self.morphs_per_feat:
+            self.morphs_per_feat[field].score_set(cand_morphs_per_feat.get(field, set()), gold_morphs_per_feat.get(field, set()))
         self.sent_starts.score_set(cand_sent_starts, gold_sent_starts)
         self.labelled.score_set(cand_deps, gold_deps)
         for dep in self.labelled_per_dep:
diff --git a/spacy/tests/pipeline/test_morphologizer.py b/spacy/tests/pipeline/test_morphologizer.py
new file mode 100644
index 000000000..f9307afc2
--- /dev/null
+++ b/spacy/tests/pipeline/test_morphologizer.py
@@ -0,0 +1,49 @@
+import pytest
+
+from spacy import util
+from spacy.lang.en import English
+from spacy.language import Language
+from spacy.tests.util import make_tempdir
+
+
+def test_label_types():
+    nlp = Language()
+    nlp.add_pipe(nlp.create_pipe("morphologizer"))
+    nlp.get_pipe("morphologizer").add_label("Feat=A")
+    with pytest.raises(ValueError):
+        nlp.get_pipe("morphologizer").add_label(9)
+
+
+TRAIN_DATA = [
+    ("I like green eggs", {"morphs": ["Feat=N", "Feat=V", "Feat=J", "Feat=N"], "pos": ["NOUN", "VERB", "ADJ", "NOUN"]}),
+    ("Eat blue ham", {"morphs": ["Feat=V", "Feat=J", "Feat=N"], "pos": ["VERB", "ADJ", "NOUN"]}),
+]
+
+
+def test_overfitting_IO():
+    # Simple test to try and quickly overfit the morphologizer - ensuring the ML models work correctly
+    nlp = English()
+    morphologizer = nlp.create_pipe("morphologizer")
+    for inst in TRAIN_DATA:
+        for morph, pos in zip(inst[1]["morphs"], inst[1]["pos"]):
+            morphologizer.add_label(morph + "|POS=" + pos)
+    nlp.add_pipe(morphologizer)
+    optimizer = nlp.begin_training()
+
+    for i in range(50):
+        losses = {}
+        nlp.update(TRAIN_DATA, sgd=optimizer, losses=losses)
+    assert losses["morphologizer"] < 0.00001
+
+    # test the trained model
+    test_text = "I like blue eggs"
+    doc = nlp(test_text)
+    gold_morphs = ["Feat=N|POS=NOUN", "Feat=V|POS=VERB", "Feat=J|POS=ADJ", "Feat=N|POS=NOUN"]
+    assert gold_morphs == [t.morph_ for t in doc]
+
+    # Also test the results are still the same after IO
+    with make_tempdir() as tmp_dir:
+        nlp.to_disk(tmp_dir)
+        nlp2 = util.load_model_from_path(tmp_dir)
+        doc2 = nlp2(test_text)
+        assert gold_morphs == [t.morph_ for t in doc2]
diff --git a/spacy/tests/pipeline/test_textcat.py b/spacy/tests/pipeline/test_textcat.py
index 38c980428..b091ec0de 100644
--- a/spacy/tests/pipeline/test_textcat.py
+++ b/spacy/tests/pipeline/test_textcat.py
@@ -8,6 +8,7 @@ from spacy.language import Language
 from spacy.pipeline import TextCategorizer
 from spacy.tokens import Doc
 from spacy.gold import GoldParse
+from spacy.util import fix_random_seed
 
 from ..util import make_tempdir
 from ...ml.models.defaults import default_tok2vec
@@ -82,6 +83,7 @@ def test_label_types():
 
 def test_overfitting_IO():
     # Simple test to try and quickly overfit the textcat component - ensuring the ML models work correctly
+    fix_random_seed(0)
     nlp = English()
     textcat = nlp.create_pipe("textcat")
     for _, annotations in TRAIN_DATA:
diff --git a/spacy/tests/test_scorer.py b/spacy/tests/test_scorer.py
index efaf80b4f..d750a8202 100644
--- a/spacy/tests/test_scorer.py
+++ b/spacy/tests/test_scorer.py
@@ -5,6 +5,7 @@ from spacy.gold import Example, GoldParse
 from spacy.scorer import Scorer, ROCAUCScore
 from spacy.scorer import _roc_auc_score, _roc_curve
 from .util import get_doc
+from spacy.lang.en import English
 
 test_las_apple = [
     [
@@ -39,6 +40,43 @@ test_ner_apple = [
     ]
 ]
 
+@pytest.fixture
+def tagged_doc():
+    text = "Sarah's sister flew to Silicon Valley via London."
+    tags = ["NNP", "POS", "NN", "VBD", "IN", "NNP", "NNP", "IN", "NNP", "."]
+    pos = [
+        "PROPN",
+        "PART",
+        "NOUN",
+        "VERB",
+        "ADP",
+        "PROPN",
+        "PROPN",
+        "ADP",
+        "PROPN",
+        "PUNCT",
+    ]
+    morphs = [
+        "NounType=prop|Number=sing",
+        "Poss=yes",
+        "Number=sing",
+        "Tense=past|VerbForm=fin",
+        "",
+        "NounType=prop|Number=sing",
+        "NounType=prop|Number=sing",
+        "",
+        "NounType=prop|Number=sing",
+        "PunctType=peri",
+    ]
+    nlp = English()
+    doc = nlp(text)
+    for i in range(len(tags)):
+        doc[i].tag_ = tags[i]
+        doc[i].pos_ = pos[i]
+        doc[i].morph_ = morphs[i]
+    doc.is_tagged = True
+    return doc
+
 
 def test_las_per_type(en_vocab):
     # Gold and Doc are identical
@@ -139,6 +177,43 @@ def test_ner_per_type(en_vocab):
     assert results["ents_per_type"]["ORG"]["f"] == approx(66.66666)
 
 
+def test_tag_score(tagged_doc):
+    # Gold and Doc are identical
+    scorer = Scorer()
+    gold = GoldParse(
+        tagged_doc,
+        tags=[t.tag_ for t in tagged_doc],
+        pos=[t.pos_ for t in tagged_doc],
+        morphs=[t.morph_ for t in tagged_doc]
+    )
+    scorer.score((tagged_doc, gold))
+    results = scorer.scores
+
+    assert results["tags_acc"] == 100
+    assert results["pos_acc"] == 100
+    assert results["morphs_acc"] == 100
+    assert results["morphs_per_type"]["NounType"]["f"] == 100
+
+    # Gold annotation is modified
+    scorer = Scorer()
+    tags = [t.tag_ for t in tagged_doc]
+    tags[0] = "NN"
+    pos = [t.pos_ for t in tagged_doc]
+    pos[1] = "X"
+    morphs = [t.morph_ for t in tagged_doc]
+    morphs[1] = "Number=sing"
+    morphs[2] = "Number=plur"
+    gold = GoldParse(tagged_doc, tags=tags, pos=pos, morphs=morphs)
+    scorer.score((tagged_doc, gold))
+    results = scorer.scores
+
+    assert results["tags_acc"] == 90
+    assert results["pos_acc"] == 90
+    assert results["morphs_acc"] == approx(80)
+    assert results["morphs_per_type"]["Poss"]["f"] == 0.0
+    assert results["morphs_per_type"]["Number"]["f"] == approx(72.727272)
+
+
 def test_roc_auc_score():
     # Binary classification, toy tests from scikit-learn test suite
     y_true = [0, 1]
diff --git a/spacy/tokens/doc.pxd b/spacy/tokens/doc.pxd
index 7f231887f..050a6b898 100644
--- a/spacy/tokens/doc.pxd
+++ b/spacy/tokens/doc.pxd
@@ -50,6 +50,7 @@ cdef class Doc:
 
     cdef public bint is_tagged
    cdef public bint is_parsed
+    cdef public bint is_morphed
 
     cdef public float sentiment
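
Usage sketch (not part of the commit): the snippet below shows how the pieces
added in this patch fit together end to end. It only uses calls that appear in
the diff above (`create_pipe("morphologizer")`, `add_label`, `nlp.update`,
`GoldParse(..., morphs=..., pos=...)`, `Scorer.score` and the new
`pos_acc`/`morphs_acc`/`morphs_per_type` fields); the two-token training
sentence and its labels are invented for illustration.

    import spacy
    from spacy.gold import GoldParse
    from spacy.scorer import Scorer

    nlp = spacy.blank("en")
    morphologizer = nlp.create_pipe("morphologizer")
    # Labels combine the morph features with the coarse POS tag, as in
    # test_morphologizer.py above.
    for label in ("PronType=Prs|POS=PRON", "VerbForm=Fin|POS=VERB"):
        morphologizer.add_label(label)
    nlp.add_pipe(morphologizer)
    optimizer = nlp.begin_training()

    # Invented toy example, just to give the model something to fit.
    train_data = [("I like", {"morphs": ["PronType=Prs", "VerbForm=Fin"],
                              "pos": ["PRON", "VERB"]})]
    for i in range(50):
        losses = {}
        nlp.update(train_data, sgd=optimizer, losses=losses)

    # Score the predictions with the new Scorer fields: coarse POS accuracy,
    # full morph accuracy, and per-feature P/R/F.
    doc = nlp("I like")
    gold = GoldParse(doc, morphs=["PronType=Prs", "VerbForm=Fin"],
                     pos=["PRON", "VERB"])
    scorer = Scorer()
    scorer.score((doc, gold))
    print(scorer.scores["pos_acc"], scorer.scores["morphs_acc"])
    print(scorer.scores["morphs_per_type"])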