Remove TokenAnnotation code from nonproj

This commit is contained in:
Matthew Honnibal 2020-06-15 18:14:47 +02:00
parent c95494739c
commit c66f93299e

View File

@ -7,7 +7,7 @@ from copy import copy
from ..tokens.doc cimport Doc, set_children_from_heads
from ..gold import Example
from ..errors import Errors
@ -90,31 +90,6 @@ def count_decorated_labels(gold_data):
return freqs
def preprocess_training_data(gold_data, label_freq_cutoff=30):
    """Projectivize the dependency annotations in *gold_data*.

    Returns a list of new Example objects whose head arrays are projective
    and whose dependency labels are decorated wherever an arc was lifted.
    When *label_freq_cutoff* is positive, decorated labels seen fewer than
    that many times across the data are filtered back to their plain form.
    """
    counts = {}
    output = []
    for example in gold_data:
        heads, labels = projectivize(example.token_annotation.heads,
                                     example.token_annotation.deps)
        # A token whose head is itself is a root: force the ROOT label.
        labels = [
            'ROOT' if head == i else labels[i]
            for i, head in enumerate(heads)
        ]
        if label_freq_cutoff > 0:
            # Tally decorated labels so rare ones can be pruned below.
            for label in labels:
                if is_decorated(label):
                    counts[label] = counts.get(label, 0) + 1
        annot = example.token_annotation.to_dict()
        annot["heads"] = heads
        annot["deps"] = labels
        new_example = Example(doc=example.doc)
        new_example.token_annotation = TokenAnnotation(**annot)
        output.append(new_example)
    if label_freq_cutoff > 0:
        return _filter_labels(output, label_freq_cutoff, counts)
    return output
def projectivize(heads, labels):
    # Use the algorithm by Nivre &amp; Nilsson 2005. Assumes heads to be a proper
    # tree, i.e. connected and cycle-free. Returns a new pair (heads, labels)
@ -200,22 +175,3 @@ def _find_new_head(token, headlabel):
next_queue.append(child)
queue = next_queue
return token.head
def _filter_labels(examples, cutoff, freqs):
    # Strip decorated labels that occur fewer than `cutoff` times back to
    # their plain form: they can't be learned reliably anyway and keeping
    # them out makes the label set smaller.
    result = []
    for example in examples:
        new_deps = [
            decompose(dep)[0]
            if is_decorated(dep) and freqs.get(dep, 0) < cutoff
            else dep
            for dep in example.token_annotation.deps
        ]
        annot = example.token_annotation.to_dict()
        annot["deps"] = new_deps
        filtered_example = Example(doc=example.doc)
        filtered_example.token_annotation = TokenAnnotation(**annot)
        result.append(filtered_example)
    return result