Fix handling of non-projective deps

This commit is contained in:
Matthew Honnibal 2017-05-21 18:44:07 -05:00
parent 5738d373d5
commit 025d9bbc37
2 changed files with 8 additions and 7 deletions

View File

@ -168,10 +168,14 @@ class GoldCorpus(object):
n += 1 n += 1
return n return n
def train_docs(self, nlp, shuffle=0, gold_preproc=True): def train_docs(self, nlp, shuffle=0, gold_preproc=True,
projectivize=False):
if shuffle: if shuffle:
random.shuffle(self.train_locs) random.shuffle(self.train_locs)
gold_docs = self.iter_gold_docs(nlp, self.train_tuples, gold_preproc) if projectivize:
train_tuples = nonproj.PseudoProjectivity.preprocess_training_data(
self.train_tuples)
gold_docs = self.iter_gold_docs(nlp, train_tuples, gold_preproc)
if shuffle: if shuffle:
gold_docs = util.itershuffle(gold_docs, bufsize=shuffle*1000) gold_docs = util.itershuffle(gold_docs, bufsize=shuffle*1000)
gold_docs = nlp.preprocess_gold(gold_docs) gold_docs = nlp.preprocess_gold(gold_docs)
@ -184,7 +188,6 @@ class GoldCorpus(object):
@classmethod @classmethod
def iter_gold_docs(cls, nlp, tuples, gold_preproc=True): def iter_gold_docs(cls, nlp, tuples, gold_preproc=True):
tuples = nonproj.PseudoProjectivity.preprocess_training_data(tuples)
for raw_text, paragraph_tuples in tuples: for raw_text, paragraph_tuples in tuples:
docs = cls._make_docs(nlp, raw_text, paragraph_tuples, docs = cls._make_docs(nlp, raw_text, paragraph_tuples,
gold_preproc) gold_preproc)
@ -233,7 +236,7 @@ class GoldCorpus(object):
return locs return locs
def read_json_file(loc, docs_filter=None, limit=None): def read_json_file(loc, docs_filter=None, limit=1000):
loc = ensure_path(loc) loc = ensure_path(loc)
if loc.is_dir(): if loc.is_dir():
for filename in loc.iterdir(): for filename in loc.iterdir():

View File

@ -330,7 +330,7 @@ cdef class Parser:
backprops = [] backprops = []
cdef float loss = 0. cdef float loss = 0.
while todo: while len(todo) >= 3:
states, golds = zip(*todo) states, golds = zip(*todo)
token_ids = self.get_token_ids(states) token_ids = self.get_token_ids(states)
@ -445,8 +445,6 @@ cdef class Parser:
def preprocess_gold(self, docs_golds): def preprocess_gold(self, docs_golds):
for doc, gold in docs_golds: for doc, gold in docs_golds:
gold.heads, gold.labels = PseudoProjectivity.projectivize(
gold.heads, gold.labels)
yield doc, gold yield doc, gold
def use_params(self, params): def use_params(self, params):