From 025d9bbc3782e6e8a1a9680db944c913df12610d Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 21 May 2017 18:44:07 -0500 Subject: [PATCH] Fix handling of non-projective deps --- spacy/gold.pyx | 11 +++++++---- spacy/syntax/nn_parser.pyx | 4 +--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/spacy/gold.pyx b/spacy/gold.pyx index 45b96a159..7d8e44f79 100644 --- a/spacy/gold.pyx +++ b/spacy/gold.pyx @@ -168,10 +168,14 @@ class GoldCorpus(object): n += 1 return n - def train_docs(self, nlp, shuffle=0, gold_preproc=True): + def train_docs(self, nlp, shuffle=0, gold_preproc=True, + projectivize=False): if shuffle: random.shuffle(self.train_locs) - gold_docs = self.iter_gold_docs(nlp, self.train_tuples, gold_preproc) + if projectivize: + train_tuples = nonproj.PseudoProjectivity.preprocess_training_data( + self.train_tuples) + gold_docs = self.iter_gold_docs(nlp, train_tuples, gold_preproc) if shuffle: gold_docs = util.itershuffle(gold_docs, bufsize=shuffle*1000) gold_docs = nlp.preprocess_gold(gold_docs) @@ -184,7 +188,6 @@ class GoldCorpus(object): @classmethod def iter_gold_docs(cls, nlp, tuples, gold_preproc=True): - tuples = nonproj.PseudoProjectivity.preprocess_training_data(tuples) for raw_text, paragraph_tuples in tuples: docs = cls._make_docs(nlp, raw_text, paragraph_tuples, gold_preproc) @@ -233,7 +236,7 @@ class GoldCorpus(object): return locs -def read_json_file(loc, docs_filter=None, limit=None): +def read_json_file(loc, docs_filter=None, limit=1000): loc = ensure_path(loc) if loc.is_dir(): for filename in loc.iterdir(): diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index fb029cfe9..6cd2fea95 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -330,7 +330,7 @@ cdef class Parser: backprops = [] cdef float loss = 0. - while todo: + while len(todo) >= 3: states, golds = zip(*todo) token_ids = self.get_token_ids(states) @@ -445,8 +445,6 @@ cdef class Parser: def preprocess_gold(self, docs_golds): for doc, gold in docs_golds: - gold.heads, gold.labels = PseudoProjectivity.projectivize( - gold.heads, gold.labels) yield doc, gold def use_params(self, params):