This commit is contained in:
Peter Baumgartner 2022-09-01 13:00:18 -04:00
parent 19145c6c92
commit 64b57204e2

View File

@ -673,35 +673,34 @@ def debug_data(
if "trainable_lemmatizer" in factory_names: if "trainable_lemmatizer" in factory_names:
msg.divider("Trainable Lemmatizer") msg.divider("Trainable Lemmatizer")
# NOTE: label_list is the string version of the trees trees_train: Set[str] = gold_train_data["lemmatizer_trees"]
label_list: Set[str] = gold_train_data["lemmatizer_trees"] trees_dev: Set[str] = gold_dev_data["lemmatizer_trees"]
# NOTE: label_list_full is the set of strigified trees from if gold_train_data["n_low_cardinality_lemmas"] > 0:
# iterating through all trees after all lemmas have been added
label_list_full: Set[str] = gold_train_data["lemmatizer_trees_full"]
# NOTE: model_labels is the tree IDs from iterating through the tree_id
# values in the component.
model_labels = _get_labels_from_trainable_lemmatizer(nlp)
msg.info(
f"{len(label_list)} label(s) in train data (created from trees.add and trees.tree_to_str)"
)
msg.info(
f"{len(label_list_full)} label(s) in train data (from iterating through tree_id)"
)
msg.info(
f"{len(model_labels)} label(s) in component (from iterating through tree_id on component.trees"
)
labels = set(label_list)
missing_labels = model_labels - labels
if missing_labels:
msg.warn( msg.warn(
f"{len(missing_labels)} label(s) in model (component.tree)" f"{gold_train_data['n_low_cardinality_lemmas']} docs with 1 or 0 unique lemmas."
" and not present in the train data (trees.add and trees.tree_to_str)."
) )
else:
msg.good("Training docs have sufficient unique lemmas")
train_not_dev = trees_train - trees_dev
if len(train_not_dev) != 0:
msg.warn(f"{len(train_not_dev)} labels were found only in the train data.")
else:
msg.good("Training data contains all lemmatizer trees in dev set.")
if gold_train_data["n_low_cardinality_lemmas"] > 0:
msg.warn(
f"{gold_dev_data['n_low_cardinality_lemmas']} docs with 1 or 0 unique lemmas."
)
else:
msg.good("Dev docs have sufficient unique lemmas")
dev_not_train = trees_dev - trees_train
if len(dev_not_train) != 0:
msg.warn(f"{len(dev_not_train)} labels were found only in the dev data.")
else:
msg.good("Trees in dev data present in training data.")
msg.divider("Summary") msg.divider("Summary")
good_counts = msg.counts[MESSAGES.GOOD] good_counts = msg.counts[MESSAGES.GOOD]
@ -765,8 +764,9 @@ def _compile_gold(
"n_cats_bad_values": 0, "n_cats_bad_values": 0,
"texts": set(), "texts": set(),
"lemmatizer_trees": set(), "lemmatizer_trees": set(),
"lemmatizer_trees_full": set(), "n_low_cardinality_lemmas": 0,
} }
if "trainable_lemmatizer" in factory_names:
trees = EditTrees(nlp.vocab.strings) trees = EditTrees(nlp.vocab.strings)
for eg in examples: for eg in examples:
gold = eg.reference gold = eg.reference
@ -898,19 +898,16 @@ def _compile_gold(
if nonproj.contains_cycle(aligned_heads): if nonproj.contains_cycle(aligned_heads):
data["n_cycles"] += 1 data["n_cycles"] += 1
if "trainable_lemmatizer" in factory_names: if "trainable_lemmatizer" in factory_names:
# NOTE: From EditTreeLemmatizer._labels_from_data # from EditTreeLemmatizer._labels_from_data
for token in eg.reference: lemma_set = set()
for token in gold:
lemma_set.add(token.lemma)
if token.lemma != 0: if token.lemma != 0:
tree_id = trees.add(token.text, token.lemma_) tree_id = trees.add(token.text, token.lemma_)
tree_str = trees.tree_to_str(tree_id) tree_str = trees.tree_to_str(tree_id)
data["lemmatizer_trees"].add(tree_str) data["lemmatizer_trees"].add(tree_str)
if "trainable_lemmatizer" in factory_names: if len(lemma_set) < 2:
# After the edittree is built, let's try and capture all the data["n_low_cardinality_lemmas"] += 1
# string trees in that structure
all_trees = set()
for tree_id in range(len(trees)):
all_trees.add(trees.tree_to_str(tree_id))
data["lemmatizer_trees_full"] = all_trees
return data return data
@ -1013,7 +1010,8 @@ def _get_labels_from_trainable_lemmatizer(nlp: Language) -> Set[str]:
labels: Set[str] = set() labels: Set[str] = set()
for pipe_name in pipe_names: for pipe_name in pipe_names:
pipe = nlp.get_pipe(pipe_name) pipe = nlp.get_pipe(pipe_name)
for tree_id in range(len(pipe.trees)): # This gets lemmatization trees and no partial subtrees
for tree_id in pipe.tree2label.keys():
labels.add(pipe.trees.tree_to_str(tree_id)) labels.add(pipe.trees.tree_to_str(tree_id))
return labels return labels