mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
* Use updated JSON format, with sentences below paragraphs. Allows use of gold preprocessing flag.
This commit is contained in:
parent
2d11739f28
commit
76300bbb1b
|
@ -81,21 +81,21 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', seed=0
|
|||
for itn in range(n_iter):
|
||||
scorer = Scorer()
|
||||
loss = 0
|
||||
for raw_text, annot_tuples, ctnt in gold_tuples:
|
||||
score_model(scorer, nlp, raw_text, annot_tuples)
|
||||
if raw_text is None:
|
||||
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
|
||||
else:
|
||||
tokens = nlp.tokenizer(raw_text)
|
||||
gold = GoldParse(tokens, annot_tuples)
|
||||
nlp.tagger(tokens)
|
||||
try:
|
||||
loss += nlp.parser.train(tokens, gold)
|
||||
except AssertionError:
|
||||
# TODO: Do something about non-projective sentences
|
||||
pass
|
||||
nlp.entity.train(tokens, gold)
|
||||
nlp.tagger.train(tokens, gold.tags)
|
||||
for raw_text, sents in gold_tuples:
|
||||
if not gold_preproc:
|
||||
sents = _merge_sents(sents)
|
||||
for annot_tuples, ctnt in sents:
|
||||
score_model(scorer, nlp, raw_text, annot_tuples)
|
||||
if raw_text is None or gold_preproc:
|
||||
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
|
||||
else:
|
||||
tokens = nlp.tokenizer(raw_text)
|
||||
gold = GoldParse(tokens, annot_tuples)
|
||||
nlp.tagger(tokens)
|
||||
if gold.is_projective:
|
||||
loss += nlp.parser.train(tokens, gold)
|
||||
nlp.entity.train(tokens, gold)
|
||||
nlp.tagger.train(tokens, gold.tags)
|
||||
random.shuffle(gold_tuples)
|
||||
print '%d:\t%d\t%.3f\t%.3f\t%.3f\t%.3f' % (itn, loss, scorer.uas, scorer.ents_f,
|
||||
scorer.tags_acc,
|
||||
|
@ -107,19 +107,21 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', seed=0
|
|||
|
||||
|
||||
def evaluate(Language, gold_tuples, model_dir, gold_preproc=False, verbose=True):
|
||||
assert not gold_preproc
|
||||
nlp = Language(data_dir=model_dir)
|
||||
scorer = Scorer()
|
||||
for raw_text, annot_tuples, brackets in gold_tuples:
|
||||
if raw_text is not None:
|
||||
tokens = nlp(raw_text, merge_mwes=False)
|
||||
else:
|
||||
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
|
||||
nlp.tagger(tokens)
|
||||
nlp.entity(tokens)
|
||||
nlp.parser(tokens)
|
||||
gold = GoldParse(tokens, annot_tuples)
|
||||
scorer.score(tokens, gold, verbose=verbose)
|
||||
for raw_text, sents in gold_tuples:
|
||||
for annot_tuples, brackets in sents:
|
||||
if raw_text is None or gold_preproc:
|
||||
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
|
||||
nlp.tagger(tokens)
|
||||
nlp.entity(tokens)
|
||||
nlp.parser(tokens)
|
||||
else:
|
||||
tokens = nlp(raw_text, merge_mwes=False)
|
||||
gold = GoldParse(tokens, annot_tuples)
|
||||
scorer.score(tokens, gold, verbose=verbose)
|
||||
for t in tokens:
|
||||
print t.orth_, t.dep_, t.head.orth_, t.ent_type_
|
||||
return scorer
|
||||
|
||||
|
||||
|
@ -141,6 +143,7 @@ def write_parses(Language, dev_loc, model_dir, out_loc):
|
|||
train_loc=("Location of training file or directory"),
|
||||
dev_loc=("Location of development file or directory"),
|
||||
corruption_level=("Amount of noise to add to training data", "option", "c", float),
|
||||
gold_preproc=("Use gold-standard sentence boundaries in training?", "flag", "g", bool),
|
||||
model_dir=("Location of output model directory",),
|
||||
out_loc=("Out location", "option", "o", str),
|
||||
n_sents=("Number of training sentences", "option", "n", int),
|
||||
|
@ -149,16 +152,16 @@ def write_parses(Language, dev_loc, model_dir, out_loc):
|
|||
debug=("Debug mode", "flag", "d", bool)
|
||||
)
|
||||
def main(train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc="", verbose=False,
|
||||
debug=False, corruption_level=0.0):
|
||||
debug=False, corruption_level=0.0, gold_preproc=False):
|
||||
gold_train = list(read_json_file(train_loc))
|
||||
train(English, gold_train, model_dir,
|
||||
feat_set='basic' if not debug else 'debug',
|
||||
gold_preproc=False, n_sents=n_sents,
|
||||
gold_preproc=gold_preproc, n_sents=n_sents,
|
||||
corruption_level=corruption_level, n_iter=n_iter)
|
||||
if out_loc:
|
||||
write_parses(English, dev_loc, model_dir, out_loc)
|
||||
#if out_loc:
|
||||
# write_parses(English, dev_loc, model_dir, out_loc)
|
||||
scorer = evaluate(English, list(read_json_file(dev_loc)),
|
||||
model_dir, gold_preproc=False, verbose=verbose)
|
||||
model_dir, gold_preproc=gold_preproc, verbose=verbose)
|
||||
print 'TOK', 100-scorer.token_acc
|
||||
print 'POS', scorer.tags_acc
|
||||
print 'UAS', scorer.uas
|
||||
|
|
|
@ -104,24 +104,25 @@ def read_json_file(loc):
|
|||
for doc in ijson.items(file_, 'item'):
|
||||
paragraphs = []
|
||||
for paragraph in doc['paragraphs']:
|
||||
words = []
|
||||
ids = []
|
||||
tags = []
|
||||
heads = []
|
||||
labels = []
|
||||
ner = []
|
||||
for token in paragraph['tokens']:
|
||||
words.append(token['orth'])
|
||||
ids.append(token['id'])
|
||||
tags.append(token['tag'])
|
||||
heads.append(token['head'] if token['head'] >= 0 else token['id'])
|
||||
labels.append(token['dep'])
|
||||
ner.append(token.get('ner', '-'))
|
||||
|
||||
yield (
|
||||
paragraph.get('raw', None),
|
||||
(ids, words, tags, heads, labels, ner),
|
||||
paragraph.get('brackets', []))
|
||||
sents = []
|
||||
for sent in paragraph['sentences']:
|
||||
words = []
|
||||
ids = []
|
||||
tags = []
|
||||
heads = []
|
||||
labels = []
|
||||
ner = []
|
||||
for i, token in enumerate(sent['tokens']):
|
||||
words.append(token['orth'])
|
||||
ids.append(i)
|
||||
tags.append(token['tag'])
|
||||
heads.append(token['head'] + i)
|
||||
labels.append(token['dep'])
|
||||
ner.append(token.get('ner', '-'))
|
||||
sents.append((
|
||||
(ids, words, tags, heads, labels, ner),
|
||||
sent.get('brackets', [])))
|
||||
yield (paragraph.get('raw', None), sents)
|
||||
|
||||
|
||||
def _iob_to_biluo(tags):
|
||||
|
@ -203,6 +204,19 @@ cdef class GoldParse:
|
|||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
@property
|
||||
def is_projective(self):
|
||||
heads = [head for (id_, word, tag, head, dep, ner) in self.orig_annot]
|
||||
deps = sorted([sorted(arc) for arc in enumerate(heads)])
|
||||
for w1, h1 in deps:
|
||||
for w2, h2 in deps:
|
||||
if w1 < w2 < h1 < h2:
|
||||
return False
|
||||
elif w1 < w2 == h2 < h1:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def is_punct_label(label):
|
||||
return label == 'P' or label.lower() == 'punct'
|
||||
|
|
|
@ -54,15 +54,16 @@ cdef class ArcEager(TransitionSystem):
|
|||
move_labels = {SHIFT: {'': True}, REDUCE: {'': True}, RIGHT: {},
|
||||
LEFT: {'ROOT': True}, BREAK: {'ROOT': True},
|
||||
CONSTITUENT: {}, ADJUST: {'': True}}
|
||||
for raw_text, (ids, words, tags, heads, labels, iob), ctnts in gold_parses:
|
||||
for child, head, label in zip(ids, heads, labels):
|
||||
if label != 'ROOT':
|
||||
if head < child:
|
||||
move_labels[RIGHT][label] = True
|
||||
elif head > child:
|
||||
move_labels[LEFT][label] = True
|
||||
for start, end, label in ctnts:
|
||||
move_labels[CONSTITUENT][label] = True
|
||||
for raw_text, sents in gold_parses:
|
||||
for (ids, words, tags, heads, labels, iob), ctnts in sents:
|
||||
for child, head, label in zip(ids, heads, labels):
|
||||
if label != 'ROOT':
|
||||
if head < child:
|
||||
move_labels[RIGHT][label] = True
|
||||
elif head > child:
|
||||
move_labels[LEFT][label] = True
|
||||
for start, end, label in ctnts:
|
||||
move_labels[CONSTITUENT][label] = True
|
||||
return move_labels
|
||||
|
||||
cdef int preprocess_gold(self, GoldParse gold) except -1:
|
||||
|
|
|
@ -73,15 +73,15 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
move_labels = {MISSING: {'': True}, BEGIN: {}, IN: {}, LAST: {}, UNIT: {},
|
||||
OUT: {'': True}}
|
||||
moves = ('M', 'B', 'I', 'L', 'U')
|
||||
for (raw_text, tuples, ctnt) in gold_tuples:
|
||||
ids, words, tags, heads, labels, biluo = tuples
|
||||
for i, ner_tag in enumerate(biluo):
|
||||
if ner_tag != 'O' and ner_tag != '-':
|
||||
if ner_tag.count('-') != 1:
|
||||
raise ValueError(ner_tag)
|
||||
_, label = ner_tag.split('-')
|
||||
for move_str in ('B', 'I', 'L', 'U'):
|
||||
move_labels[moves.index(move_str)][label] = True
|
||||
for raw_text, sents in gold_tuples:
|
||||
for (ids, words, tags, heads, labels, biluo), _ in sents:
|
||||
for i, ner_tag in enumerate(biluo):
|
||||
if ner_tag != 'O' and ner_tag != '-':
|
||||
if ner_tag.count('-') != 1:
|
||||
raise ValueError(ner_tag)
|
||||
_, label = ner_tag.split('-')
|
||||
for move_str in ('B', 'I', 'L', 'U'):
|
||||
move_labels[moves.index(move_str)][label] = True
|
||||
return move_labels
|
||||
|
||||
def move_name(self, int move, int label):
|
||||
|
|
Loading…
Reference in New Issue
Block a user