diff --git a/spacy/cli/debug_data.py b/spacy/cli/debug_data.py
index 5d326f841..7264fc56a 100644
--- a/spacy/cli/debug_data.py
+++ b/spacy/cli/debug_data.py
@@ -673,35 +673,34 @@ def debug_data(
 
     if "trainable_lemmatizer" in factory_names:
         msg.divider("Trainable Lemmatizer")
-        # NOTE: label_list is the string version of the trees
-        label_list: Set[str] = gold_train_data["lemmatizer_trees"]
+        trees_train: Set[str] = gold_train_data["lemmatizer_trees"]
+        trees_dev: Set[str] = gold_dev_data["lemmatizer_trees"]
 
-        # NOTE: label_list_full is the set of strigified trees from
-        # iterating through all trees after all lemmas have been added
-        label_list_full: Set[str] = gold_train_data["lemmatizer_trees_full"]
-
-        # NOTE: model_labels is the tree IDs from iterating through the tree_id
-        # values in the component.
-        model_labels = _get_labels_from_trainable_lemmatizer(nlp)
-
-        msg.info(
-            f"{len(label_list)} label(s) in train data (created from trees.add and trees.tree_to_str)"
-        )
-        msg.info(
-            f"{len(label_list_full)} label(s) in train data (from iterating through tree_id)"
-        )
-        msg.info(
-            f"{len(model_labels)} label(s) in component (from iterating through tree_id on component.trees"
-        )
-
-        labels = set(label_list)
-        missing_labels = model_labels - labels
-
-        if missing_labels:
+        if gold_train_data["n_low_cardinality_lemmas"] > 0:
             msg.warn(
-                f"{len(missing_labels)} label(s) in model (component.tree)"
-                " and not present in the train data (trees.add and trees.tree_to_str)."
+                f"{gold_train_data['n_low_cardinality_lemmas']} docs with 1 or 0 unique lemmas."
             )
+        else:
+            msg.good("Training docs have sufficient unique lemmas")
+
+        train_not_dev = trees_train - trees_dev
+        if len(train_not_dev) != 0:
+            msg.warn(f"{len(train_not_dev)} labels were found only in the train data.")
+        else:
+            msg.good("All lemmatizer trees in train data are also present in dev data.")
+
+        if gold_dev_data["n_low_cardinality_lemmas"] > 0:
+            msg.warn(
+                f"{gold_dev_data['n_low_cardinality_lemmas']} docs with 1 or 0 unique lemmas."
+            )
+        else:
+            msg.good("Dev docs have sufficient unique lemmas")
+
+        dev_not_train = trees_dev - trees_train
+        if len(dev_not_train) != 0:
+            msg.warn(f"{len(dev_not_train)} labels were found only in the dev data.")
+        else:
+            msg.good("Trees in dev data present in training data.")
 
     msg.divider("Summary")
     good_counts = msg.counts[MESSAGES.GOOD]
@@ -765,9 +764,10 @@ def _compile_gold(
         "n_cats_bad_values": 0,
         "texts": set(),
         "lemmatizer_trees": set(),
-        "lemmatizer_trees_full": set(),
+        "n_low_cardinality_lemmas": 0,
     }
-    trees = EditTrees(nlp.vocab.strings)
+    if "trainable_lemmatizer" in factory_names:
+        trees = EditTrees(nlp.vocab.strings)
     for eg in examples:
         gold = eg.reference
         doc = eg.predicted
@@ -898,19 +898,16 @@ def _compile_gold(
             if nonproj.contains_cycle(aligned_heads):
                 data["n_cycles"] += 1
         if "trainable_lemmatizer" in factory_names:
-            # NOTE: From EditTreeLemmatizer._labels_from_data
-            for token in eg.reference:
+            # from EditTreeLemmatizer._labels_from_data
+            lemma_set = set()
+            for token in gold:
+                lemma_set.add(token.lemma)
                 if token.lemma != 0:
                     tree_id = trees.add(token.text, token.lemma_)
                     tree_str = trees.tree_to_str(tree_id)
                     data["lemmatizer_trees"].add(tree_str)
-    if "trainable_lemmatizer" in factory_names:
-        # After the edittree is built, let's try and capture all the
-        # string trees in that structure
-        all_trees = set()
-        for tree_id in range(len(trees)):
-            all_trees.add(trees.tree_to_str(tree_id))
-        data["lemmatizer_trees_full"] = all_trees
+            if len(lemma_set) < 2:
+                data["n_low_cardinality_lemmas"] += 1
     return data
 
 
@@ -1013,7 +1010,8 @@ def _get_labels_from_trainable_lemmatizer(nlp: Language) -> Set[str]:
     labels: Set[str] = set()
     for pipe_name in pipe_names:
         pipe = nlp.get_pipe(pipe_name)
-        for tree_id in range(len(pipe.trees)):
+        # This gets lemmatization trees and no partial subtrees
+        for tree_id in pipe.tree2label.keys():
             labels.add(pipe.trees.tree_to_str(tree_id))
     return labels
 