mirror of
https://github.com/explosion/spaCy.git
synced 2025-06-10 08:03:15 +03:00
Set gold.sent_starts in ud_train
This commit is contained in:
parent
bf19f22340
commit
eddc0e0c74
|
@@ -124,13 +124,16 @@ def read_conllu(file_):
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
|
|
||||||
def _make_gold(nlp, text, sent_annots):
|
def _make_gold(nlp, text, sent_annots, drop_deps=0.0):
|
||||||
# Flatten the conll annotations, and adjust the head indices
|
# Flatten the conll annotations, and adjust the head indices
|
||||||
flat = defaultdict(list)
|
flat = defaultdict(list)
|
||||||
|
sent_starts = []
|
||||||
for sent in sent_annots:
|
for sent in sent_annots:
|
||||||
flat['heads'].extend(len(flat['words'])+head for head in sent['heads'])
|
flat['heads'].extend(len(flat['words'])+head for head in sent['heads'])
|
||||||
for field in ['words', 'tags', 'deps', 'entities', 'spaces']:
|
for field in ['words', 'tags', 'deps', 'entities', 'spaces']:
|
||||||
flat[field].extend(sent[field])
|
flat[field].extend(sent[field])
|
||||||
|
sent_starts.append(True)
|
||||||
|
sent_starts.extend([False] * (len(sent['words'])-1))
|
||||||
# Construct text if necessary
|
# Construct text if necessary
|
||||||
assert len(flat['words']) == len(flat['spaces'])
|
assert len(flat['words']) == len(flat['spaces'])
|
||||||
if text is None:
|
if text is None:
|
||||||
|
@@ -138,6 +141,12 @@ def _make_gold(nlp, text, sent_annots):
|
||||||
doc = nlp.make_doc(text)
|
doc = nlp.make_doc(text)
|
||||||
flat.pop('spaces')
|
flat.pop('spaces')
|
||||||
gold = GoldParse(doc, **flat)
|
gold = GoldParse(doc, **flat)
|
||||||
|
gold.sent_starts = sent_starts
|
||||||
|
for i in range(len(gold.heads)):
|
||||||
|
if random.random() < drop_deps:
|
||||||
|
gold.heads[i] = None
|
||||||
|
gold.labels[i] = None
|
||||||
|
|
||||||
return doc, gold
|
return doc, gold
|
||||||
|
|
||||||
#############################
|
#############################
|
||||||
|
|
Loading…
Reference in New Issue
Block a user