mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-27 02:16:32 +03:00
* Allow parsers and taggers to be trained on text without gold pre-processing.
This commit is contained in:
parent
67d6e53a69
commit
ca7577d8a9
|
@ -9,6 +9,7 @@ import codecs
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
import gzip
|
import gzip
|
||||||
|
import nltk
|
||||||
|
|
||||||
import plac
|
import plac
|
||||||
import cProfile
|
import cProfile
|
||||||
|
@ -22,6 +23,10 @@ from spacy.syntax.parser import GreedyParser
|
||||||
from spacy.syntax.util import Config
|
from spacy.syntax.util import Config
|
||||||
|
|
||||||
|
|
||||||
|
def is_punct_label(label):
|
||||||
|
return label == 'P' or label.lower() == 'punct'
|
||||||
|
|
||||||
|
|
||||||
def read_tokenized_gold(file_):
|
def read_tokenized_gold(file_):
|
||||||
"""Read a standard CoNLL/MALT-style format"""
|
"""Read a standard CoNLL/MALT-style format"""
|
||||||
sents = []
|
sents = []
|
||||||
|
@ -96,20 +101,20 @@ def _parse_line(line):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
loss = 0
|
||||||
def _align_annotations_to_non_gold_tokens(tokens, words, annot):
|
def _align_annotations_to_non_gold_tokens(tokens, words, annot):
|
||||||
|
global loss
|
||||||
tags = []
|
tags = []
|
||||||
heads = []
|
heads = []
|
||||||
labels = []
|
labels = []
|
||||||
loss = 0
|
orig_words = list(words)
|
||||||
print [t.orth_ for t in tokens]
|
missed = []
|
||||||
print words
|
|
||||||
for token in tokens:
|
for token in tokens:
|
||||||
print token.orth_, words[0]
|
|
||||||
print token.idx, annot[0][0]
|
|
||||||
while annot and token.idx > annot[0][0]:
|
while annot and token.idx > annot[0][0]:
|
||||||
print 'pop', token.idx, annot[0][0]
|
miss_id, miss_tag, miss_head, miss_label = annot.pop(0)
|
||||||
annot.pop(0)
|
miss_w = words.pop(0)
|
||||||
words.pop(0)
|
if not is_punct_label(miss_label):
|
||||||
|
missed.append(miss_w)
|
||||||
loss += 1
|
loss += 1
|
||||||
if not annot:
|
if not annot:
|
||||||
tags.append(None)
|
tags.append(None)
|
||||||
|
@ -129,6 +134,11 @@ def _align_annotations_to_non_gold_tokens(tokens, words, annot):
|
||||||
labels.append(None)
|
labels.append(None)
|
||||||
else:
|
else:
|
||||||
raise StandardError
|
raise StandardError
|
||||||
|
#if missed:
|
||||||
|
# print orig_words
|
||||||
|
# print missed
|
||||||
|
# for t in tokens:
|
||||||
|
# print t.idx, t.orth_
|
||||||
return loss, tags, heads, labels
|
return loss, tags, heads, labels
|
||||||
|
|
||||||
|
|
||||||
|
@ -137,7 +147,8 @@ def iter_data(paragraphs, tokenizer, gold_preproc=False):
|
||||||
if not gold_preproc:
|
if not gold_preproc:
|
||||||
tokens = tokenizer(raw)
|
tokens = tokenizer(raw)
|
||||||
loss, tags, heads, labels = _align_annotations_to_non_gold_tokens(
|
loss, tags, heads, labels = _align_annotations_to_non_gold_tokens(
|
||||||
tokens, words, zip(ids, tags, heads, labels))
|
tokens, list(words),
|
||||||
|
zip(ids, tags, heads, labels))
|
||||||
ids = [t.idx for t in tokens]
|
ids = [t.idx for t in tokens]
|
||||||
heads = _map_indices_to_tokens(ids, heads)
|
heads = _map_indices_to_tokens(ids, heads)
|
||||||
yield tokens, tags, heads, labels
|
yield tokens, tags, heads, labels
|
||||||
|
@ -170,7 +181,7 @@ def get_labels(sents):
|
||||||
|
|
||||||
|
|
||||||
def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
|
def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
|
||||||
gold_preproc=True):
|
gold_preproc=False):
|
||||||
dep_model_dir = path.join(model_dir, 'deps')
|
dep_model_dir = path.join(model_dir, 'deps')
|
||||||
pos_model_dir = path.join(model_dir, 'pos')
|
pos_model_dir = path.join(model_dir, 'pos')
|
||||||
if path.exists(dep_model_dir):
|
if path.exists(dep_model_dir):
|
||||||
|
@ -194,10 +205,9 @@ def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
|
||||||
n_tokens = 0
|
n_tokens = 0
|
||||||
for tokens, tag_strs, heads, labels in iter_data(paragraphs, nlp.tokenizer,
|
for tokens, tag_strs, heads, labels in iter_data(paragraphs, nlp.tokenizer,
|
||||||
gold_preproc=gold_preproc):
|
gold_preproc=gold_preproc):
|
||||||
tags = [nlp.tagger.tag_names.index(tag) for tag in tag_strs]
|
|
||||||
nlp.tagger(tokens)
|
nlp.tagger(tokens)
|
||||||
heads_corr += nlp.parser.train_sent(tokens, heads, labels, force_gold=False)
|
heads_corr += nlp.parser.train_sent(tokens, heads, labels, force_gold=False)
|
||||||
pos_corr += nlp.tagger.train(tokens, tags)
|
pos_corr += nlp.tagger.train(tokens, tag_strs)
|
||||||
n_tokens += len(tokens)
|
n_tokens += len(tokens)
|
||||||
acc = float(heads_corr) / n_tokens
|
acc = float(heads_corr) / n_tokens
|
||||||
pos_acc = float(pos_corr) / n_tokens
|
pos_acc = float(pos_corr) / n_tokens
|
||||||
|
@ -223,12 +233,13 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
|
||||||
for i, token in enumerate(tokens):
|
for i, token in enumerate(tokens):
|
||||||
if heads[i] is None:
|
if heads[i] is None:
|
||||||
skipped += 1
|
skipped += 1
|
||||||
if labels[i] == 'P' or labels[i] == 'punct':
|
continue
|
||||||
|
if is_punct_label(labels[i]):
|
||||||
continue
|
continue
|
||||||
n_corr += token.head.i == heads[i]
|
n_corr += token.head.i == heads[i]
|
||||||
total += 1
|
total += 1
|
||||||
print skipped
|
print loss, skipped, (loss+skipped + total)
|
||||||
return float(n_corr) / total
|
return float(n_corr) / (total + loss)
|
||||||
|
|
||||||
|
|
||||||
def main(train_loc, dev_loc, model_dir):
|
def main(train_loc, dev_loc, model_dir):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user