diff --git a/spacy/cli/ud_train.py b/spacy/cli/ud_train.py index 86fdabca4..7048d748b 100644 --- a/spacy/cli/ud_train.py +++ b/spacy/cli/ud_train.py @@ -124,13 +124,16 @@ def read_conllu(file_): return docs -def _make_gold(nlp, text, sent_annots): +def _make_gold(nlp, text, sent_annots, drop_deps=0.0): # Flatten the conll annotations, and adjust the head indices flat = defaultdict(list) + sent_starts = [] for sent in sent_annots: flat['heads'].extend(len(flat['words'])+head for head in sent['heads']) for field in ['words', 'tags', 'deps', 'entities', 'spaces']: flat[field].extend(sent[field]) + sent_starts.append(True) + sent_starts.extend([False] * (len(sent['words'])-1)) # Construct text if necessary assert len(flat['words']) == len(flat['spaces']) if text is None: @@ -138,6 +141,12 @@ def _make_gold(nlp, text, sent_annots): doc = nlp.make_doc(text) flat.pop('spaces') gold = GoldParse(doc, **flat) + gold.sent_starts = sent_starts + for i in range(len(gold.heads)): + if random.random() < drop_deps: + gold.heads[i] = None + gold.labels[i] = None + return doc, gold #############################