mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-10 16:22:29 +03:00
* Use updated JSON format, with sentences below paragraphs. Allows use of gold preprocessing flag.
This commit is contained in:
parent
2d11739f28
commit
76300bbb1b
|
@ -81,21 +81,21 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', seed=0
|
||||||
for itn in range(n_iter):
|
for itn in range(n_iter):
|
||||||
scorer = Scorer()
|
scorer = Scorer()
|
||||||
loss = 0
|
loss = 0
|
||||||
for raw_text, annot_tuples, ctnt in gold_tuples:
|
for raw_text, sents in gold_tuples:
|
||||||
score_model(scorer, nlp, raw_text, annot_tuples)
|
if not gold_preproc:
|
||||||
if raw_text is None:
|
sents = _merge_sents(sents)
|
||||||
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
|
for annot_tuples, ctnt in sents:
|
||||||
else:
|
score_model(scorer, nlp, raw_text, annot_tuples)
|
||||||
tokens = nlp.tokenizer(raw_text)
|
if raw_text is None or gold_preproc:
|
||||||
gold = GoldParse(tokens, annot_tuples)
|
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
|
||||||
nlp.tagger(tokens)
|
else:
|
||||||
try:
|
tokens = nlp.tokenizer(raw_text)
|
||||||
loss += nlp.parser.train(tokens, gold)
|
gold = GoldParse(tokens, annot_tuples)
|
||||||
except AssertionError:
|
nlp.tagger(tokens)
|
||||||
# TODO: Do something about non-projective sentences
|
if gold.is_projective:
|
||||||
pass
|
loss += nlp.parser.train(tokens, gold)
|
||||||
nlp.entity.train(tokens, gold)
|
nlp.entity.train(tokens, gold)
|
||||||
nlp.tagger.train(tokens, gold.tags)
|
nlp.tagger.train(tokens, gold.tags)
|
||||||
random.shuffle(gold_tuples)
|
random.shuffle(gold_tuples)
|
||||||
print '%d:\t%d\t%.3f\t%.3f\t%.3f\t%.3f' % (itn, loss, scorer.uas, scorer.ents_f,
|
print '%d:\t%d\t%.3f\t%.3f\t%.3f\t%.3f' % (itn, loss, scorer.uas, scorer.ents_f,
|
||||||
scorer.tags_acc,
|
scorer.tags_acc,
|
||||||
|
@ -107,19 +107,21 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', seed=0
|
||||||
|
|
||||||
|
|
||||||
def evaluate(Language, gold_tuples, model_dir, gold_preproc=False, verbose=True):
|
def evaluate(Language, gold_tuples, model_dir, gold_preproc=False, verbose=True):
|
||||||
assert not gold_preproc
|
|
||||||
nlp = Language(data_dir=model_dir)
|
nlp = Language(data_dir=model_dir)
|
||||||
scorer = Scorer()
|
scorer = Scorer()
|
||||||
for raw_text, annot_tuples, brackets in gold_tuples:
|
for raw_text, sents in gold_tuples:
|
||||||
if raw_text is not None:
|
for annot_tuples, brackets in sents:
|
||||||
tokens = nlp(raw_text, merge_mwes=False)
|
if raw_text is None or gold_preproc:
|
||||||
else:
|
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
|
||||||
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
|
nlp.tagger(tokens)
|
||||||
nlp.tagger(tokens)
|
nlp.entity(tokens)
|
||||||
nlp.entity(tokens)
|
nlp.parser(tokens)
|
||||||
nlp.parser(tokens)
|
else:
|
||||||
gold = GoldParse(tokens, annot_tuples)
|
tokens = nlp(raw_text, merge_mwes=False)
|
||||||
scorer.score(tokens, gold, verbose=verbose)
|
gold = GoldParse(tokens, annot_tuples)
|
||||||
|
scorer.score(tokens, gold, verbose=verbose)
|
||||||
|
for t in tokens:
|
||||||
|
print t.orth_, t.dep_, t.head.orth_, t.ent_type_
|
||||||
return scorer
|
return scorer
|
||||||
|
|
||||||
|
|
||||||
|
@ -141,6 +143,7 @@ def write_parses(Language, dev_loc, model_dir, out_loc):
|
||||||
train_loc=("Location of training file or directory"),
|
train_loc=("Location of training file or directory"),
|
||||||
dev_loc=("Location of development file or directory"),
|
dev_loc=("Location of development file or directory"),
|
||||||
corruption_level=("Amount of noise to add to training data", "option", "c", float),
|
corruption_level=("Amount of noise to add to training data", "option", "c", float),
|
||||||
|
gold_preproc=("Use gold-standard sentence boundaries in training?", "flag", "g", bool),
|
||||||
model_dir=("Location of output model directory",),
|
model_dir=("Location of output model directory",),
|
||||||
out_loc=("Out location", "option", "o", str),
|
out_loc=("Out location", "option", "o", str),
|
||||||
n_sents=("Number of training sentences", "option", "n", int),
|
n_sents=("Number of training sentences", "option", "n", int),
|
||||||
|
@ -149,16 +152,16 @@ def write_parses(Language, dev_loc, model_dir, out_loc):
|
||||||
debug=("Debug mode", "flag", "d", bool)
|
debug=("Debug mode", "flag", "d", bool)
|
||||||
)
|
)
|
||||||
def main(train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc="", verbose=False,
|
def main(train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc="", verbose=False,
|
||||||
debug=False, corruption_level=0.0):
|
debug=False, corruption_level=0.0, gold_preproc=False):
|
||||||
gold_train = list(read_json_file(train_loc))
|
gold_train = list(read_json_file(train_loc))
|
||||||
train(English, gold_train, model_dir,
|
train(English, gold_train, model_dir,
|
||||||
feat_set='basic' if not debug else 'debug',
|
feat_set='basic' if not debug else 'debug',
|
||||||
gold_preproc=False, n_sents=n_sents,
|
gold_preproc=gold_preproc, n_sents=n_sents,
|
||||||
corruption_level=corruption_level, n_iter=n_iter)
|
corruption_level=corruption_level, n_iter=n_iter)
|
||||||
if out_loc:
|
#if out_loc:
|
||||||
write_parses(English, dev_loc, model_dir, out_loc)
|
# write_parses(English, dev_loc, model_dir, out_loc)
|
||||||
scorer = evaluate(English, list(read_json_file(dev_loc)),
|
scorer = evaluate(English, list(read_json_file(dev_loc)),
|
||||||
model_dir, gold_preproc=False, verbose=verbose)
|
model_dir, gold_preproc=gold_preproc, verbose=verbose)
|
||||||
print 'TOK', 100-scorer.token_acc
|
print 'TOK', 100-scorer.token_acc
|
||||||
print 'POS', scorer.tags_acc
|
print 'POS', scorer.tags_acc
|
||||||
print 'UAS', scorer.uas
|
print 'UAS', scorer.uas
|
||||||
|
|
|
@ -104,24 +104,25 @@ def read_json_file(loc):
|
||||||
for doc in ijson.items(file_, 'item'):
|
for doc in ijson.items(file_, 'item'):
|
||||||
paragraphs = []
|
paragraphs = []
|
||||||
for paragraph in doc['paragraphs']:
|
for paragraph in doc['paragraphs']:
|
||||||
words = []
|
sents = []
|
||||||
ids = []
|
for sent in paragraph['sentences']:
|
||||||
tags = []
|
words = []
|
||||||
heads = []
|
ids = []
|
||||||
labels = []
|
tags = []
|
||||||
ner = []
|
heads = []
|
||||||
for token in paragraph['tokens']:
|
labels = []
|
||||||
words.append(token['orth'])
|
ner = []
|
||||||
ids.append(token['id'])
|
for i, token in enumerate(sent['tokens']):
|
||||||
tags.append(token['tag'])
|
words.append(token['orth'])
|
||||||
heads.append(token['head'] if token['head'] >= 0 else token['id'])
|
ids.append(i)
|
||||||
labels.append(token['dep'])
|
tags.append(token['tag'])
|
||||||
ner.append(token.get('ner', '-'))
|
heads.append(token['head'] + i)
|
||||||
|
labels.append(token['dep'])
|
||||||
yield (
|
ner.append(token.get('ner', '-'))
|
||||||
paragraph.get('raw', None),
|
sents.append((
|
||||||
(ids, words, tags, heads, labels, ner),
|
(ids, words, tags, heads, labels, ner),
|
||||||
paragraph.get('brackets', []))
|
sent.get('brackets', [])))
|
||||||
|
yield (paragraph.get('raw', None), sents)
|
||||||
|
|
||||||
|
|
||||||
def _iob_to_biluo(tags):
|
def _iob_to_biluo(tags):
|
||||||
|
@ -203,6 +204,19 @@ cdef class GoldParse:
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.length
|
return self.length
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_projective(self):
|
||||||
|
heads = [head for (id_, word, tag, head, dep, ner) in self.orig_annot]
|
||||||
|
deps = sorted([sorted(arc) for arc in enumerate(heads)])
|
||||||
|
for w1, h1 in deps:
|
||||||
|
for w2, h2 in deps:
|
||||||
|
if w1 < w2 < h1 < h2:
|
||||||
|
return False
|
||||||
|
elif w1 < w2 == h2 < h1:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def is_punct_label(label):
|
def is_punct_label(label):
|
||||||
return label == 'P' or label.lower() == 'punct'
|
return label == 'P' or label.lower() == 'punct'
|
||||||
|
|
|
@ -54,15 +54,16 @@ cdef class ArcEager(TransitionSystem):
|
||||||
move_labels = {SHIFT: {'': True}, REDUCE: {'': True}, RIGHT: {},
|
move_labels = {SHIFT: {'': True}, REDUCE: {'': True}, RIGHT: {},
|
||||||
LEFT: {'ROOT': True}, BREAK: {'ROOT': True},
|
LEFT: {'ROOT': True}, BREAK: {'ROOT': True},
|
||||||
CONSTITUENT: {}, ADJUST: {'': True}}
|
CONSTITUENT: {}, ADJUST: {'': True}}
|
||||||
for raw_text, (ids, words, tags, heads, labels, iob), ctnts in gold_parses:
|
for raw_text, sents in gold_parses:
|
||||||
for child, head, label in zip(ids, heads, labels):
|
for (ids, words, tags, heads, labels, iob), ctnts in sents:
|
||||||
if label != 'ROOT':
|
for child, head, label in zip(ids, heads, labels):
|
||||||
if head < child:
|
if label != 'ROOT':
|
||||||
move_labels[RIGHT][label] = True
|
if head < child:
|
||||||
elif head > child:
|
move_labels[RIGHT][label] = True
|
||||||
move_labels[LEFT][label] = True
|
elif head > child:
|
||||||
for start, end, label in ctnts:
|
move_labels[LEFT][label] = True
|
||||||
move_labels[CONSTITUENT][label] = True
|
for start, end, label in ctnts:
|
||||||
|
move_labels[CONSTITUENT][label] = True
|
||||||
return move_labels
|
return move_labels
|
||||||
|
|
||||||
cdef int preprocess_gold(self, GoldParse gold) except -1:
|
cdef int preprocess_gold(self, GoldParse gold) except -1:
|
||||||
|
|
|
@ -73,15 +73,15 @@ cdef class BiluoPushDown(TransitionSystem):
|
||||||
move_labels = {MISSING: {'': True}, BEGIN: {}, IN: {}, LAST: {}, UNIT: {},
|
move_labels = {MISSING: {'': True}, BEGIN: {}, IN: {}, LAST: {}, UNIT: {},
|
||||||
OUT: {'': True}}
|
OUT: {'': True}}
|
||||||
moves = ('M', 'B', 'I', 'L', 'U')
|
moves = ('M', 'B', 'I', 'L', 'U')
|
||||||
for (raw_text, tuples, ctnt) in gold_tuples:
|
for raw_text, sents in gold_tuples:
|
||||||
ids, words, tags, heads, labels, biluo = tuples
|
for (ids, words, tags, heads, labels, biluo), _ in sents:
|
||||||
for i, ner_tag in enumerate(biluo):
|
for i, ner_tag in enumerate(biluo):
|
||||||
if ner_tag != 'O' and ner_tag != '-':
|
if ner_tag != 'O' and ner_tag != '-':
|
||||||
if ner_tag.count('-') != 1:
|
if ner_tag.count('-') != 1:
|
||||||
raise ValueError(ner_tag)
|
raise ValueError(ner_tag)
|
||||||
_, label = ner_tag.split('-')
|
_, label = ner_tag.split('-')
|
||||||
for move_str in ('B', 'I', 'L', 'U'):
|
for move_str in ('B', 'I', 'L', 'U'):
|
||||||
move_labels[moves.index(move_str)][label] = True
|
move_labels[moves.index(move_str)][label] = True
|
||||||
return move_labels
|
return move_labels
|
||||||
|
|
||||||
def move_name(self, int move, int label):
|
def move_name(self, int move, int label):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user