mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 17:24:41 +03:00
Update conllu script
This commit is contained in:
parent
9c8a0f6eba
commit
a26e399f84
|
@ -28,6 +28,8 @@ def prevent_bad_sentences(doc):
|
||||||
token.is_sent_start = False
|
token.is_sent_start = False
|
||||||
elif not token.nbor(-1).is_punct:
|
elif not token.nbor(-1).is_punct:
|
||||||
token.is_sent_start = False
|
token.is_sent_start = False
|
||||||
|
elif token.nbor(-1).is_left_punct:
|
||||||
|
token.is_sent_start = False
|
||||||
return doc
|
return doc
|
||||||
|
|
||||||
|
|
||||||
|
@ -99,7 +101,7 @@ def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
|
||||||
# cs is conllu sent, ct is conllu token
|
# cs is conllu sent, ct is conllu token
|
||||||
docs = []
|
docs = []
|
||||||
golds = []
|
golds = []
|
||||||
for text, cd in zip(paragraphs, conllu):
|
for doc_id, (text, cd) in enumerate(zip(paragraphs, conllu)):
|
||||||
doc_words = []
|
doc_words = []
|
||||||
doc_tags = []
|
doc_tags = []
|
||||||
doc_heads = []
|
doc_heads = []
|
||||||
|
@ -140,7 +142,7 @@ def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
|
||||||
golds.append(GoldParse(docs[-1], words=doc_words, tags=doc_tags,
|
golds.append(GoldParse(docs[-1], words=doc_words, tags=doc_tags,
|
||||||
heads=doc_heads, deps=doc_deps,
|
heads=doc_heads, deps=doc_deps,
|
||||||
entities=doc_ents))
|
entities=doc_ents))
|
||||||
if limit and len(docs) >= limit:
|
if limit and doc_id >= limit:
|
||||||
break
|
break
|
||||||
return docs, golds
|
return docs, golds
|
||||||
|
|
||||||
|
@ -188,7 +190,14 @@ def parse_dev_data(nlp, text_loc, conllu_loc, oracle_segments=False,
|
||||||
with open(conllu_loc) as conllu_file:
|
with open(conllu_loc) as conllu_file:
|
||||||
docs, golds = read_data(nlp, conllu_file, text_file,
|
docs, golds = read_data(nlp, conllu_file, text_file,
|
||||||
oracle_segments=oracle_segments)
|
oracle_segments=oracle_segments)
|
||||||
if not joint_sbd:
|
if joint_sbd:
|
||||||
|
sbd = nlp.create_pipe('sentencizer')
|
||||||
|
for doc in docs:
|
||||||
|
doc = sbd(doc)
|
||||||
|
for sent in doc.sents:
|
||||||
|
sent[0].is_sent_start = True
|
||||||
|
#docs = (prevent_bad_sentences(doc) for doc in docs)
|
||||||
|
else:
|
||||||
sbd = nlp.create_pipe('sentencizer')
|
sbd = nlp.create_pipe('sentencizer')
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
doc = sbd(doc)
|
doc = sbd(doc)
|
||||||
|
@ -245,7 +254,8 @@ def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev
|
||||||
with open(conllu_train_loc) as conllu_file:
|
with open(conllu_train_loc) as conllu_file:
|
||||||
with open(text_train_loc) as text_file:
|
with open(text_train_loc) as text_file:
|
||||||
docs, golds = read_data(nlp, conllu_file, text_file,
|
docs, golds = read_data(nlp, conllu_file, text_file,
|
||||||
oracle_segments=False, raw_text=True)
|
oracle_segments=True, raw_text=True,
|
||||||
|
limit=None)
|
||||||
print("Create parser")
|
print("Create parser")
|
||||||
nlp.add_pipe(nlp.create_pipe('parser'))
|
nlp.add_pipe(nlp.create_pipe('parser'))
|
||||||
nlp.add_pipe(nlp.create_pipe('tagger'))
|
nlp.add_pipe(nlp.create_pipe('tagger'))
|
||||||
|
@ -266,7 +276,7 @@ def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev
|
||||||
print("Begin training")
|
print("Begin training")
|
||||||
# Batch size starts at 1 and grows, so that we make updates quickly
|
# Batch size starts at 1 and grows, so that we make updates quickly
|
||||||
# at the beginning of training.
|
# at the beginning of training.
|
||||||
batch_sizes = spacy.util.compounding(spacy.util.env_opt('batch_from', 1),
|
batch_sizes = spacy.util.compounding(spacy.util.env_opt('batch_from', 2),
|
||||||
spacy.util.env_opt('batch_to', 2),
|
spacy.util.env_opt('batch_to', 2),
|
||||||
spacy.util.env_opt('batch_compound', 1.001))
|
spacy.util.env_opt('batch_compound', 1.001))
|
||||||
for i in range(30):
|
for i in range(30):
|
||||||
|
@ -278,6 +288,7 @@ def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev
|
||||||
if not batch:
|
if not batch:
|
||||||
continue
|
continue
|
||||||
batch_docs, batch_gold = zip(*batch)
|
batch_docs, batch_gold = zip(*batch)
|
||||||
|
batch_docs = [prevent_bad_sentences(doc) for doc in batch_docs]
|
||||||
|
|
||||||
nlp.update(batch_docs, batch_gold, sgd=optimizer,
|
nlp.update(batch_docs, batch_gold, sgd=optimizer,
|
||||||
drop=0.2, losses=losses)
|
drop=0.2, losses=losses)
|
||||||
|
@ -296,6 +307,5 @@ def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev
|
||||||
print_conllu(dev_docs, file_)
|
print_conllu(dev_docs, file_)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
plac.call(main)
|
plac.call(main)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user