Fix conllu script (#4579)

* force extensions to avoid clash between example scripts (see the sketch below)

* fix arg order and default file encoding

* add example config for conllu script

* newline

* move extension definitions to main function

* a few more encoding fixes
Sofie Van Landeghem 2019-11-04 20:31:26 +01:00 committed by Ines Montani
parent 4e43c0ba93
commit 4ec7623288
3 changed files with 20 additions and 25 deletions
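
The first and fifth bullets concern spaCy's extension registry: registering the same `Token` extension twice raises an error, which is exactly what happened when both UD example scripts were imported into one process. A minimal sketch of the clash and of the `force=True` escape hatch, assuming spaCy v2.x behavior:

from spacy.tokens import Token

# First registration succeeds:
Token.set_extension("begins_fused", default=False)

# A second script importing the same module hits the registry again:
try:
    Token.set_extension("begins_fused", default=False)
except ValueError as err:
    print("clash:", err)  # spaCy v2 raises [E090] "... already exists"

# force=True overwrites the existing extension instead of raising:
Token.set_extension("begins_fused", default=False, force=True)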

bin/ud/ud_train.py

@@ -7,7 +7,6 @@ from __future__ import unicode_literals
 import plac
 from pathlib import Path
 import re
-import sys
 import json
 import spacy
@@ -19,12 +18,9 @@ from spacy.util import compounding, minibatch, minibatch_by_words
 from spacy.syntax.nonproj import projectivize
 from spacy.matcher import Matcher
 from spacy import displacy
-from collections import defaultdict, Counter
-from timeit import default_timer as timer
-import itertools
+from collections import defaultdict
 import random
-import numpy.random
 from spacy import lang
 from spacy.lang import zh
@@ -323,10 +319,6 @@ def get_token_conllu(token, i):
     return "\n".join(lines)
 
-
-Token.set_extension("get_conllu_lines", method=get_token_conllu, force=True)
-Token.set_extension("begins_fused", default=False, force=True)
-Token.set_extension("inside_fused", default=False, force=True)
 
 ##################
 # Initialization #
@ -459,13 +451,13 @@ class TreebankPaths(object):
@plac.annotations( @plac.annotations(
ud_dir=("Path to Universal Dependencies corpus", "positional", None, Path), ud_dir=("Path to Universal Dependencies corpus", "positional", None, Path),
parses_dir=("Directory to write the development parses", "positional", None, Path),
corpus=( corpus=(
"UD corpus to train and evaluate on, e.g. en, es_ancora, etc", "UD corpus to train and evaluate on, e.g. UD_Spanish-AnCora",
"positional", "positional",
None, None,
str, str,
), ),
parses_dir=("Directory to write the development parses", "positional", None, Path),
config=("Path to json formatted config file", "option", "C", Path), config=("Path to json formatted config file", "option", "C", Path),
limit=("Size limit", "option", "n", int), limit=("Size limit", "option", "n", int),
gpu_device=("Use GPU", "option", "g", int), gpu_device=("Use GPU", "option", "g", int),
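
On the reordering above: plac matches each annotation to the parameter of the same name and consumes positionals in the order of the decorated function's signature, so the decorator listing itself does not change behavior; the fix aligns the listing with the signature so the source no longer suggests the wrong invocation order. A minimal sketch, with `demo` as a hypothetical stand-in for the script's `main`:

import plac
from pathlib import Path

@plac.annotations(
    # Annotations are matched to parameters by name; the command-line order
    # comes from the signature of the decorated function below.
    ud_dir=("Path to Universal Dependencies corpus", "positional", None, Path),
    parses_dir=("Directory to write the development parses", "positional", None, Path),
    corpus=("UD corpus to train on, e.g. UD_Spanish-AnCora", "positional", None, str),
)
def demo(ud_dir, parses_dir, corpus):
    # Invoked as: python demo.py <ud_dir> <parses_dir> <corpus>
    print(ud_dir, parses_dir, corpus)

if __name__ == "__main__":
    plac.call(demo)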
@@ -490,6 +482,10 @@ def main(
     # temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
     import tqdm
 
+    Token.set_extension("get_conllu_lines", method=get_token_conllu)
+    Token.set_extension("begins_fused", default=False)
+    Token.set_extension("inside_fused", default=False)
+
     spacy.util.fix_random_seed()
     lang.zh.Chinese.Defaults.use_jieba = False
     lang.ja.Japanese.Defaults.use_janome = False
@@ -506,8 +502,8 @@ def main(
     docs, golds = read_data(
         nlp,
-        paths.train.conllu.open(),
-        paths.train.text.open(),
+        paths.train.conllu.open(encoding="utf8"),
+        paths.train.text.open(encoding="utf8"),
         max_doc_length=config.max_doc_length,
         limit=limit,
     )
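
The `open(encoding="utf8")` change above matters because `Path.open()` without an encoding falls back to the platform default, `locale.getpreferredencoding(False)`, which is cp1252 on many Windows systems and can reject or mangle UTF-8 CoNLL-U files. A short illustration, using a hypothetical file path:

import locale
from pathlib import Path

# The default used when no encoding is passed; not utf-8 on every platform:
print(locale.getpreferredencoding(False))

conllu_path = Path("train.conllu")  # hypothetical path for illustration
with conllu_path.open(encoding="utf8") as f:
    # Decodes the same bytes as UTF-8 regardless of platform:
    first_line = f.readline()
    print(first_line)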

examples/training/conllu-config.json

@@ -0,0 +1 @@
+{"nr_epoch": 3, "batch_size": 24, "dropout": 0.001, "vectors": 0, "multitask_tag": 0, "multitask_sent": 0}

examples/training/conllu.py

@@ -13,8 +13,7 @@ import spacy.util
 from spacy.tokens import Token, Doc
 from spacy.gold import GoldParse
 from spacy.syntax.nonproj import projectivize
-from collections import defaultdict, Counter
-from timeit import default_timer as timer
+from collections import defaultdict
 from spacy.matcher import Matcher
 import itertools
@@ -290,11 +289,6 @@ def get_token_conllu(token, i):
     return "\n".join(lines)
 
-
-Token.set_extension("get_conllu_lines", method=get_token_conllu)
-Token.set_extension("begins_fused", default=False)
-Token.set_extension("inside_fused", default=False)
-
 
 ##################
 # Initialization #
 ##################
@@ -381,20 +375,24 @@ class TreebankPaths(object):
 
 @plac.annotations(
     ud_dir=("Path to Universal Dependencies corpus", "positional", None, Path),
+    parses_dir=("Directory to write the development parses", "positional", None, Path),
+    config=("Path to json formatted config file", "positional", None, Config.load),
     corpus=(
-        "UD corpus to train and evaluate on, e.g. en, es_ancora, etc",
+        "UD corpus to train and evaluate on, e.g. UD_Spanish-AnCora",
         "positional",
         None,
         str,
     ),
-    parses_dir=("Directory to write the development parses", "positional", None, Path),
-    config=("Path to json formatted config file", "positional", None, Config.load),
     limit=("Size limit", "option", "n", int),
 )
 def main(ud_dir, parses_dir, config, corpus, limit=0):
     # temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
     import tqdm
 
+    Token.set_extension("get_conllu_lines", method=get_token_conllu)
+    Token.set_extension("begins_fused", default=False)
+    Token.set_extension("inside_fused", default=False)
+
     paths = TreebankPaths(ud_dir, corpus)
     if not (parses_dir / corpus).exists():
         (parses_dir / corpus).mkdir()
@@ -403,8 +401,8 @@ def main(ud_dir, parses_dir, config, corpus, limit=0):
     docs, golds = read_data(
         nlp,
-        paths.train.conllu.open(),
-        paths.train.text.open(),
+        paths.train.conllu.open(encoding="utf8"),
+        paths.train.text.open(encoding="utf8"),
         max_doc_length=config.max_doc_length,
         limit=limit,
     )
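
Finally, the "move extension definitions to main function" bullet is visible in both scripts above: registration now happens when the script runs rather than when its module is imported, so merely importing both example scripts in one process (as a test suite might during collection) no longer triggers duplicate registration. A sketch of the pattern, with a stub standing in for the script's real `get_token_conllu`:

from spacy.tokens import Token

def get_token_conllu(token, i):
    # Stub standing in for the script's real CoNLL-U formatter.
    return ""

def main():
    # Registered at run time, not as an import-time side effect:
    Token.set_extension("get_conllu_lines", method=get_token_conllu)
    Token.set_extension("begins_fused", default=False)
    Token.set_extension("inside_fused", default=False)
    # ... training / parsing logic would follow ...

if __name__ == "__main__":
    main()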