Merge pull request #280 from wbwseeker/german_parser

German parser
Matthew Honnibal 2016-03-04 03:27:42 +11:00
commit fcaa0ad7ce
20 changed files with 1687 additions and 134 deletions

View File

@@ -98,7 +98,7 @@ def _read_probs(loc):
     return probs, probs['-OOV-']
-def _read_freqs(loc, max_length=100, min_doc_freq=5, min_freq=200):
+def _read_freqs(loc, max_length=100, min_doc_freq=0, min_freq=200):
     if not loc.exists():
         print("Warning: Frequencies file not found")
         return {}, 0.0
@@ -125,7 +125,8 @@ def _read_freqs(loc, max_length=100, min_doc_freq=5, min_freq=200):
         doc_freq = int(doc_freq)
         freq = int(freq)
         if doc_freq >= min_doc_freq and freq >= min_freq and len(key) < max_length:
-            word = literal_eval(key)
+            # word = literal_eval(key)
+            word = key
             smooth_count = counts.smoother(int(freq))
             log_smooth_count = math.log(smooth_count)
             probs[word] = math.log(smooth_count) - log_total
@@ -165,7 +166,7 @@ def setup_vocab(get_lex_attr, tag_map, src_dir, dst_dir):
     clusters = _read_clusters(src_dir / 'clusters.txt')
     probs, oov_prob = _read_probs(src_dir / 'words.sgt.prob')
     if not probs:
-        probs, oov_prob = _read_freqs(src_dir / 'freqs.txt.gz')
+        probs, oov_prob = _read_freqs(src_dir / 'freqs.txt')
     if not probs:
         oov_prob = -20
     else:
@@ -223,7 +224,6 @@ def main(lang_id, lang_data_dir, corpora_dir, model_dir):
         copyfile(str(lang_data_dir / 'gazetteer.json'),
                  str(model_dir / 'vocab' / 'gazetteer.json'))
-    if (lang_data_dir / 'tag_map.json').exists():
     copyfile(str(lang_data_dir / 'tag_map.json'),
              str(model_dir / 'vocab' / 'tag_map.json'))
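
For orientation, the table built by _read_freqs() is just smoothed log-frequencies. A minimal stand-alone sketch of that shape of computation (the placeholder lambda stands in for counts.smoother, which comes from elsewhere in the script and is not shown in this hunk):

import math

def freqs_to_log_probs(freqs, smoother=lambda c: c + 0.5):
    # freqs: word -> raw corpus count, as read from freqs.txt
    # smoother: placeholder for the script's counts.smoother()
    log_total = math.log(sum(smoother(c) for c in freqs.values()))
    probs = {}
    for word, freq in freqs.items():
        # same shape as the script: log(smoothed count) - log(total)
        probs[word] = math.log(smoother(freq)) - log_total
    return probs

print(freqs_to_log_probs({'der': 1000, 'Haus': 40}))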

View File

@@ -14,6 +14,7 @@ import re
 import spacy.util
 from spacy.en import English
+from spacy.de import German
 from spacy.syntax.util import Config
 from spacy.gold import read_json_file
@@ -25,6 +26,7 @@ from spacy.syntax.arc_eager import ArcEager
 from spacy.syntax.ner import BiluoPushDown
 from spacy.tagger import Tagger
 from spacy.syntax.parser import Parser
+from spacy.syntax.nonproj import PseudoProjectivity

 def _corrupt(c, noise_level):
@@ -82,7 +84,7 @@ def _merge_sents(sents):
 def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
           seed=0, gold_preproc=False, n_sents=0, corruption_level=0,
           beam_width=1, verbose=False,
-          use_orig_arc_eager=False):
+          use_orig_arc_eager=False, pseudoprojective=False):
     dep_model_dir = path.join(model_dir, 'deps')
     ner_model_dir = path.join(model_dir, 'ner')
     pos_model_dir = path.join(model_dir, 'pos')
@@ -96,9 +98,13 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
     os.mkdir(ner_model_dir)
     os.mkdir(pos_model_dir)
+    if pseudoprojective:
+        # preprocess training data here before ArcEager.get_labels() is called
+        gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
     Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
                  labels=ArcEager.get_labels(gold_tuples),
-                 beam_width=beam_width)
+                 beam_width=beam_width,projectivize=pseudoprojective)
     Config.write(ner_model_dir, 'config', features='ner', seed=seed,
                  labels=BiluoPushDown.get_labels(gold_tuples),
                  beam_width=0)
@@ -107,6 +113,8 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
         gold_tuples = gold_tuples[:n_sents]
     nlp = Language(data_dir=model_dir, tagger=False, parser=False, entity=False)
+    if nlp.lang == 'de':
+        nlp.vocab.morphology.lemmatizer = lambda string,pos: set([string])
     nlp.tagger = Tagger.blank(nlp.vocab, Tagger.default_templates())
     nlp.parser = Parser.from_dir(dep_model_dir, nlp.vocab.strings, ArcEager)
     nlp.entity = Parser.from_dir(ner_model_dir, nlp.vocab.strings, BiluoPushDown)
@@ -131,12 +139,9 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
             raw_text = add_noise(raw_text, corruption_level)
             tokens = nlp.tokenizer(raw_text)
             nlp.tagger(tokens)
-            gold = GoldParse(tokens, annot_tuples, make_projective=True)
+            gold = GoldParse(tokens, annot_tuples)
             if not gold.is_projective:
-                raise Exception(
-                    "Non-projective sentence in training, after we should "
-                    "have enforced projectivity: %s" % annot_tuples
-                )
+                raise Exception("Non-projective sentence in training: %s" % annot_tuples)
             loss += nlp.parser.train(tokens, gold)
             nlp.entity.train(tokens, gold)
             nlp.tagger.train(tokens, gold.tags)
@@ -152,6 +157,8 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
 def evaluate(Language, gold_tuples, model_dir, gold_preproc=False, verbose=False,
              beam_width=None, cand_preproc=None):
     nlp = Language(data_dir=model_dir)
+    if nlp.lang == 'de':
+        nlp.vocab.morphology.lemmatizer = lambda string,pos: set([string])
     if beam_width is not None:
         nlp.parser.cfg.beam_width = beam_width
     scorer = Scorer()
@@ -200,6 +207,7 @@ def write_parses(Language, dev_loc, model_dir, out_loc):
 @plac.annotations(
+    language=("The language to train", "positional", None, str, ['en','de']),
     train_loc=("Location of training file or directory"),
     dev_loc=("Location of development file or directory"),
     model_dir=("Location of output model directory",),
@@ -211,19 +219,22 @@ def write_parses(Language, dev_loc, model_dir, out_loc):
     n_iter=("Number of training iterations", "option", "i", int),
     verbose=("Verbose error reporting", "flag", "v", bool),
     debug=("Debug mode", "flag", "d", bool),
+    pseudoprojective=("Use pseudo-projective parsing", "flag", "p", bool),
 )
-def main(train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc="", verbose=False,
-         debug=False, corruption_level=0.0, gold_preproc=False, eval_only=False):
+def main(language, train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc="", verbose=False,
+         debug=False, corruption_level=0.0, gold_preproc=False, eval_only=False, pseudoprojective=False):
+    lang = {'en':English, 'de':German}.get(language)
     if not eval_only:
         gold_train = list(read_json_file(train_loc))
-        train(English, gold_train, model_dir,
+        train(lang, gold_train, model_dir,
               feat_set='basic' if not debug else 'debug',
               gold_preproc=gold_preproc, n_sents=n_sents,
               corruption_level=corruption_level, n_iter=n_iter,
-              verbose=verbose)
+              verbose=verbose,pseudoprojective=pseudoprojective)
     if out_loc:
-        write_parses(English, dev_loc, model_dir, out_loc)
-    scorer = evaluate(English, list(read_json_file(dev_loc)),
+        write_parses(lang, dev_loc, model_dir, out_loc)
+    scorer = evaluate(lang, list(read_json_file(dev_loc)),
                       model_dir, gold_preproc=gold_preproc, verbose=verbose)
     print('TOK', scorer.token_acc)
     print('POS', scorer.tags_acc)
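
A hypothetical invocation of the updated script (script path and data locations are placeholders): python train.py de train.json dev.json /tmp/de_model -p trains a German model with the new pseudo-projective flag. The positional language argument is resolved by a plain dict dispatch:

from spacy.en import English
from spacy.de import German

lang = {'en': English, 'de': German}.get('de')
assert lang is German   # an unknown language code would give None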

View File

@ -0,0 +1,160 @@
#!/usr/bin/env python
from __future__ import division
from __future__ import unicode_literals
import os
from os import path
import shutil
import io
import random
import time
import gzip
import ujson
import plac
import cProfile
import pstats
import spacy.util
from spacy.de import German
from spacy.gold import GoldParse
from spacy.tagger import Tagger
from spacy.scorer import PRFScore
from spacy.tagger import P2_orth, P2_cluster, P2_shape, P2_prefix, P2_suffix, P2_pos, P2_lemma, P2_flags
from spacy.tagger import P1_orth, P1_cluster, P1_shape, P1_prefix, P1_suffix, P1_pos, P1_lemma, P1_flags
from spacy.tagger import W_orth, W_cluster, W_shape, W_prefix, W_suffix, W_pos, W_lemma, W_flags
from spacy.tagger import N1_orth, N1_cluster, N1_shape, N1_prefix, N1_suffix, N1_pos, N1_lemma, N1_flags
from spacy.tagger import N2_orth, N2_cluster, N2_shape, N2_prefix, N2_suffix, N2_pos, N2_lemma, N2_flags, N_CONTEXT_FIELDS
def default_templates():
return spacy.tagger.Tagger.default_templates()
def default_templates_without_clusters():
return (
(W_orth,),
(P1_lemma, P1_pos),
(P2_lemma, P2_pos),
(N1_orth,),
(N2_orth,),
(W_suffix,),
(W_prefix,),
(P1_pos,),
(P2_pos,),
(P1_pos, P2_pos),
(P1_pos, W_orth),
(P1_suffix,),
(N1_suffix,),
(W_shape,),
(W_flags,),
(N1_flags,),
(N2_flags,),
(P1_flags,),
(P2_flags,),
)
def make_tagger(vocab, templates):
model = spacy.tagger.TaggerModel(templates)
return spacy.tagger.Tagger(vocab,model)
def read_conll(file_):
def sentences():
words, tags = [], []
for line in file_:
line = line.strip()
if line:
word, tag = line.split('\t')[1::3][:2] # get column 1 and 4 (CoNLL09)
words.append(word)
tags.append(tag)
elif words:
yield words, tags
words, tags = [], []
if words:
yield words, tags
return [ s for s in sentences() ]
def score_model(score, nlp, words, gold_tags):
tokens = nlp.tokenizer.tokens_from_list(words)
assert(len(tokens) == len(gold_tags))
nlp.tagger(tokens)
for token, gold_tag in zip(tokens,gold_tags):
score.score_set(set([token.tag_]),set([gold_tag]))
def train(Language, train_sents, dev_sents, model_dir, n_iter=15, seed=21):
# make shuffling deterministic
random.seed(seed)
# set up directory for model
pos_model_dir = path.join(model_dir, 'pos')
if path.exists(pos_model_dir):
shutil.rmtree(pos_model_dir)
os.mkdir(pos_model_dir)
nlp = Language(data_dir=model_dir, tagger=False, parser=False, entity=False)
nlp.tagger = make_tagger(nlp.vocab,default_templates())
print("Itn.\ttrain acc %\tdev acc %")
for itn in range(n_iter):
# train on train set
#train_acc = PRFScore()
correct, total = 0., 0.
for words, gold_tags in train_sents:
tokens = nlp.tokenizer.tokens_from_list(words)
correct += nlp.tagger.train(tokens, gold_tags)
total += len(words)
train_acc = correct/total
# test on dev set
dev_acc = PRFScore()
for words, gold_tags in dev_sents:
score_model(dev_acc, nlp, words, gold_tags)
random.shuffle(train_sents)
print('%d:\t%6.2f\t%6.2f' % (itn, 100*train_acc, 100*dev_acc.precision))
print('end training')
nlp.end_training(model_dir)
print('done')
@plac.annotations(
train_loc=("Location of CoNLL 09 formatted training file"),
dev_loc=("Location of CoNLL 09 formatted development file"),
model_dir=("Location of output model directory"),
eval_only=("Skip training, and only evaluate", "flag", "e", bool),
n_iter=("Number of training iterations", "option", "i", int),
)
def main(train_loc, dev_loc, model_dir, eval_only=False, n_iter=15):
# training
if not eval_only:
with io.open(train_loc, 'r', encoding='utf8') as trainfile_, \
io.open(dev_loc, 'r', encoding='utf8') as devfile_:
train_sents = read_conll(trainfile_)
dev_sents = read_conll(devfile_)
train(German, train_sents, dev_sents, model_dir, n_iter=n_iter)
# testing
with io.open(dev_loc, 'r', encoding='utf8') as file_:
dev_sents = read_conll(file_)
nlp = German(data_dir=model_dir)
dev_acc = PRFScore()
for words, gold_tags in dev_sents:
score_model(dev_acc, nlp, words, gold_tags)
print('POS: %6.2f %%' % (100*dev_acc.precision))
if __name__ == '__main__':
plac.call(main)
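
A quick sanity check of read_conll() defined above, using an invented two-sentence CoNLL09 fragment; only columns 2 (FORM) and 5 (POS) are read:

import io

sample = (
    "1\tDas\tdas\tdas\tART\tART\n"
    "2\tHaus\thaus\thaus\tNN\tNN\n"
    "\n"
    "1\tJa\tja\tja\tADV\tADV\n"
)
sents = read_conll(io.StringIO(sample))
assert sents == [(['Das', 'Haus'], ['ART', 'NN']), (['Ja'], ['ADV'])]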

lang_data/de/abbrev.de.tab (new file, 319 lines)

@ -0,0 +1,319 @@
# surface form lemma pos
# multiple values are separated by |
# empty lines and lines starting with # are being ignored
'' ''
\") \")
\n \n <nl> SP
\t \t <tab> SP
<space> SP
# example: Wie geht's?
's 's es
'S 'S es
# example: Haste mal 'nen Euro?
'n 'n ein
'ne 'ne eine
'nen 'nen einen
# example: Kommen S nur herein!
s' s' sie
S' S' sie
# example: Da haben wir's!
ich's ich|'s ich|es
du's du|'s du|es
er's er|'s er|es
sie's sie|'s sie|es
wir's wir|'s wir|es
ihr's ihr|'s ihr|es
# example: Die katze auf'm dach.
auf'm auf|'m auf|dem
unter'm unter|'m unter|dem
über'm über|'m über|dem
vor'm vor|'m vor|dem
hinter'm hinter|'m hinter|dem
# persons
B.A. B.A.
B.Sc. B.Sc.
Dipl. Dipl.
Dipl.-Ing. Dipl.-Ing.
Dr. Dr.
Fr. Fr.
Frl. Frl.
Hr. Hr.
Hrn. Hrn.
Frl. Frl.
Prof. Prof.
St. St.
Hrgs. Hrgs.
Hg. Hg.
a.Z. a.Z.
a.D. a.D.
h.c. h.c.
Jr. Jr.
jr. jr.
jun. jun.
sen. sen.
rer. rer.
Ing. Ing.
M.A. M.A.
Mr. Mr.
M.Sc. M.Sc.
nat. nat.
phil. phil.
# companies
Co. Co.
co. co.
Cie. Cie.
A.G. A.G.
G.m.b.H. G.m.b.H.
i.G. i.G.
e.V. e.V.
# popular german abbreviations
Abb. Abb.
Abk. Abk.
Abs. Abs.
Abt. Abt.
abzgl. abzgl.
allg. allg.
a.M. a.M.
Bd. Bd.
betr. betr.
Betr. Betr.
Biol. Biol.
biol. biol.
Bf. Bf.
Bhf. Bhf.
Bsp. Bsp.
bspw. bspw.
bzgl. bzgl.
bzw. bzw.
d.h. d.h.
dgl. dgl.
ebd. ebd.
ehem. ehem.
eigtl. eigtl.
entspr. entspr.
erm. erm.
ev. ev.
evtl. evtl.
Fa. Fa.
Fam. Fam.
geb. geb.
Gebr. Gebr.
gem. gem.
ggf. ggf.
ggü. ggü.
ggfs. ggfs.
gegr. gegr.
Hbf. Hbf.
Hrsg. Hrsg.
hrsg. hrsg.
i.A. i.A.
i.d.R. i.d.R.
inkl. inkl.
insb. insb.
i.O. i.O.
i.Tr. i.Tr.
i.V. i.V.
jur. jur.
kath. kath.
K.O. K.O.
lt. lt.
max. max.
m.E. m.E.
m.M. m.M.
mtl. mtl.
min. min.
mind. mind.
MwSt. MwSt.
Nr. Nr.
o.a. o.a.
o.ä. o.ä.
o.Ä. o.Ä.
o.g. o.g.
o.k. o.k.
O.K. O.K.
Orig. Orig.
orig. orig.
pers. pers.
Pkt. Pkt.
Red. Red.
röm. röm.
s.o. s.o.
sog. sog.
std. std.
stellv. stellv.
Str. Str.
tägl. tägl.
Tel. Tel.
u.a. u.a.
usf. usf.
u.s.w. u.s.w.
usw. usw.
u.U. u.U.
u.v.m. u.v.m.
uvm. uvm.
v.a. v.a.
vgl. vgl.
vllt. vllt.
v.l.n.r. v.l.n.r.
vlt. vlt.
Vol. Vol.
wiss. wiss.
Univ. Univ.
z.B. z.B.
z.b. z.b.
z.Bsp. z.Bsp.
z.T. z.T.
z.Z. z.Z.
zzgl. zzgl.
z.Zt. z.Zt.
# popular latin abbreviations
vs. vs.
adv. adv.
Chr. Chr.
A.C. A.C.
A.D. A.D.
e.g. e.g.
i.e. i.e.
al. al.
p.a. p.a.
P.S. P.S.
q.e.d. q.e.d.
R.I.P. R.I.P.
etc. etc.
incl. incl.
ca. ca.
n.Chr. n.Chr.
p.s. p.s.
v.Chr. v.Chr.
# popular english abbreviations
D.C. D.C.
N.Y. N.Y.
N.Y.C. N.Y.C.
U.S. U.S.
U.S.A. U.S.A.
L.A. L.A.
U.S.S. U.S.S.
# dates & time
Jan. Jan.
Feb. Feb.
Mrz. Mrz.
Mär. Mär.
Apr. Apr.
Jun. Jun.
Jul. Jul.
Aug. Aug.
Sep. Sep.
Sept. Sept.
Okt. Okt.
Nov. Nov.
Dez. Dez.
Mo. Mo.
Di. Di.
Mi. Mi.
Do. Do.
Fr. Fr.
Sa. Sa.
So. So.
Std. Std.
Jh. Jh.
Jhd. Jhd.
# numbers
Tsd. Tsd.
Mio. Mio.
Mrd. Mrd.
# countries & languages
engl. engl.
frz. frz.
lat. lat.
österr. österr.
# smileys
:) :)
<3 <3
;) ;)
(: (:
:( :(
-_- -_-
=) =)
:/ :/
:> :>
;-) ;-)
:Y :Y
:P :P
:-P :-P
:3 :3
=3 =3
xD xD
^_^ ^_^
=] =]
=D =D
<333 <333
:)) :))
:0 :0
-__- -__-
xDD xDD
o_o o_o
o_O o_O
V_V V_V
=[[ =[[
<33 <33
;p ;p
;D ;D
;-p ;-p
;( ;(
:p :p
:] :]
:O :O
:-/ :-/
:-) :-)
:((( :(((
:(( :((
:') :')
(^_^) (^_^)
(= (=
o.O o.O
# single letters
a. a.
b. b.
c. c.
d. d.
e. e.
f. f.
g. g.
h. h.
i. i.
j. j.
k. k.
l. l.
m. m.
n. n.
o. o.
p. p.
q. q.
r. r.
s. s.
t. t.
u. u.
v. v.
w. w.
x. x.
y. y.
z. z.
ä. ä.
ö. ö.
ü. ü.

lang_data/de/gazetteer.json (new file, 194 lines)

@ -0,0 +1,194 @@
{
"Reddit": [
"PRODUCT",
{},
[
[{"lower": "reddit"}]
]
],
"SeptemberElevenAttacks": [
"EVENT",
{},
[
[
{"orth": "9/11"}
],
[
{"lower": "september"},
{"orth": "11"}
]
]
],
"Linux": [
"PRODUCT",
{},
[
[{"lower": "linux"}]
]
],
"Haskell": [
"PRODUCT",
{},
[
[{"lower": "haskell"}]
]
],
"HaskellCurry": [
"PERSON",
{},
[
[
{"lower": "haskell"},
{"lower": "curry"}
]
]
],
"Javascript": [
"PRODUCT",
{},
[
[{"lower": "javascript"}]
]
],
"CSS": [
"PRODUCT",
{},
[
[{"lower": "css"}],
[{"lower": "css3"}]
]
],
"displaCy": [
"PRODUCT",
{},
[
[{"lower": "displacy"}]
]
],
"spaCy": [
"PRODUCT",
{},
[
[{"orth": "spaCy"}]
]
],
"HTML": [
"PRODUCT",
{},
[
[{"lower": "html"}],
[{"lower": "html5"}]
]
],
"Python": [
"PRODUCT",
{},
[
[{"orth": "Python"}]
]
],
"Ruby": [
"PRODUCT",
{},
[
[{"orth": "Ruby"}]
]
],
"Digg": [
"PRODUCT",
{},
[
[{"lower": "digg"}]
]
],
"FoxNews": [
"ORG",
{},
[
[{"orth": "Fox"}],
[{"orth": "News"}]
]
],
"Google": [
"ORG",
{},
[
[{"lower": "google"}]
]
],
"Mac": [
"PRODUCT",
{},
[
[{"lower": "mac"}]
]
],
"Wikipedia": [
"PRODUCT",
{},
[
[{"lower": "wikipedia"}]
]
],
"Windows": [
"PRODUCT",
{},
[
[{"orth": "Windows"}]
]
],
"Dell": [
"ORG",
{},
[
[{"lower": "dell"}]
]
],
"Facebook": [
"ORG",
{},
[
[{"lower": "facebook"}]
]
],
"Blizzard": [
"ORG",
{},
[
[{"orth": "Blizzard"}]
]
],
"Ubuntu": [
"ORG",
{},
[
[{"orth": "Ubuntu"}]
]
],
"Youtube": [
"PRODUCT",
{},
[
[{"lower": "youtube"}]
]
],
"false_positives": [
null,
{},
[
[{"orth": "Shit"}],
[{"orth": "Weed"}],
[{"orth": "Cool"}],
[{"orth": "Btw"}],
[{"orth": "Bah"}],
[{"orth": "Bullshit"}],
[{"orth": "Lol"}],
[{"orth": "Yo"}, {"lower": "dawg"}],
[{"orth": "Yay"}],
[{"orth": "Ahh"}],
[{"orth": "Yea"}],
[{"orth": "Bah"}]
]
]
}

View File

@@ -1,5 +1,7 @@
 # coding=utf8
 import json
+import io
+import itertools

 contractions = {}
@@ -262,14 +264,30 @@ def get_token_properties(token, capitalize=False, remove_contractions=False):
         props["F"] = token
     return props

 def create_entry(token, endings, capitalize=False, remove_contractions=False):
     properties = []
     properties.append(get_token_properties(token, capitalize=capitalize, remove_contractions=remove_contractions))
     for e in endings:
         properties.append(get_token_properties(e, remove_contractions=remove_contractions))
     return properties

+FIELDNAMES = ['F','L','pos']
+
+def read_hardcoded(stream):
+    hc_specials = {}
+    for line in stream:
+        line = line.strip()
+        if line.startswith('#') or not line:
+            continue
+        key,_,rest = line.partition('\t')
+        values = []
+        for annotation in zip(*[ e.split('|') for e in rest.split('\t') ]):
+            values.append({ k:v for k,v in itertools.izip_longest(FIELDNAMES,annotation) if v })
+        hc_specials[key] = values
+    return hc_specials
+
 def generate_specials():
     specials = {}
@@ -303,7 +321,10 @@ def generate_specials():
             specials[special] = create_entry(token, endings, capitalize=True, remove_contractions=True)

     # add in hardcoded specials
-    specials = dict(specials, **hardcoded_specials)
+    # changed it so it generates them from a file
+    with io.open('abbrev.de.tab','r',encoding='utf8') as abbrev_:
+        hc_specials = read_hardcoded(abbrev_)
+    specials = dict(specials, **hc_specials)

     return specials
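
For illustration, this is roughly what read_hardcoded() yields for one line of abbrev.de.tab, matching the specials entries elsewhere in this commit (Python 3 spelling, itertools.zip_longest instead of izip_longest):

import io
import itertools

FIELDNAMES = ['F', 'L', 'pos']

def read_hardcoded(stream):
    hc_specials = {}
    for line in stream:
        line = line.strip()
        if line.startswith('#') or not line:
            continue
        key, _, rest = line.partition('\t')
        values = []
        # columns hold '|'-separated values, one per resulting sub-token
        for annotation in zip(*[e.split('|') for e in rest.split('\t')]):
            values.append({k: v for k, v in itertools.zip_longest(FIELDNAMES, annotation) if v})
        hc_specials[key] = values
    return hc_specials

print(read_hardcoded(io.StringIO("auf'm\tauf|'m\tauf|dem")))
# {"auf'm": [{'F': 'auf', 'L': 'auf'}, {'F': "'m", 'L': 'dem'}]}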

View File

@@ -1,3 +1,6 @@
 \.\.\.
 (?<=[a-z])\.(?=[A-Z])
-(?<=[a-zA-Z])-(?=[a-zA-z])
+(?<=[a-zöäüßA-ZÖÄÜ"]):(?=[a-zöäüßA-ZÖÄÜ])
+(?<=[a-zöäüßA-ZÖÄÜ"])>(?=[a-zöäüßA-ZÖÄÜ])
+(?<=[a-zöäüßA-ZÖÄÜ"])<(?=[a-zöäüßA-ZÖÄÜ])
+(?<=[a-zöäüßA-ZÖÄÜ"])=(?=[a-zöäüßA-ZÖÄÜ])
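
Not the tokenizer itself, just a regex-level check of the new colon infix rule: it fires between letters (including umlauts) but not inside clock times:

import re

colon_infix = re.compile(u'(?<=[a-zöäüßA-ZÖÄÜ"]):(?=[a-zöäüßA-ZÖÄÜ])')
assert colon_infix.search(u'Hinweis:siehe') is not None
assert colon_infix.search(u'10:30') is None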

View File

@ -5,6 +5,7 @@
{ {
* *
< <
>
$ $
£ £
@ -20,3 +21,7 @@ a-
.... ....
... ...
»
_
§

View File

@ -1,27 +1,4 @@
{ {
"\t": [
{
"F": "\t",
"pos": "SP"
}
],
"\n": [
{
"F": "\n",
"pos": "SP"
}
],
" ": [
{
"F": " ",
"pos": "SP"
}
],
"\")": [
{
"F": "\")"
}
],
"''": [ "''": [
{ {
"F": "''" "F": "''"
@ -217,6 +194,11 @@
"F": "<333" "F": "<333"
} }
], ],
"<space>": [
{
"F": "SP"
}
],
"=)": [ "=)": [
{ {
"F": "=)" "F": "=)"
@ -267,6 +249,16 @@
"F": "Abk." "F": "Abk."
} }
], ],
"Abs.": [
{
"F": "Abs."
}
],
"Abt.": [
{
"F": "Abt."
}
],
"Apr.": [ "Apr.": [
{ {
"F": "Apr." "F": "Apr."
@ -277,6 +269,26 @@
"F": "Aug." "F": "Aug."
} }
], ],
"B.A.": [
{
"F": "B.A."
}
],
"B.Sc.": [
{
"F": "B.Sc."
}
],
"Bd.": [
{
"F": "Bd."
}
],
"Betr.": [
{
"F": "Betr."
}
],
"Bf.": [ "Bf.": [
{ {
"F": "Bf." "F": "Bf."
@ -292,6 +304,11 @@
"F": "Biol." "F": "Biol."
} }
], ],
"Bsp.": [
{
"F": "Bsp."
}
],
"Chr.": [ "Chr.": [
{ {
"F": "Chr." "F": "Chr."
@ -342,6 +359,16 @@
"F": "Dr." "F": "Dr."
} }
], ],
"Fa.": [
{
"F": "Fa."
}
],
"Fam.": [
{
"F": "Fam."
}
],
"Feb.": [ "Feb.": [
{ {
"F": "Feb." "F": "Feb."
@ -387,6 +414,16 @@
"F": "Hrgs." "F": "Hrgs."
} }
], ],
"Hrn.": [
{
"F": "Hrn."
}
],
"Hrsg.": [
{
"F": "Hrsg."
}
],
"Ing.": [ "Ing.": [
{ {
"F": "Ing." "F": "Ing."
@ -397,11 +434,21 @@
"F": "Jan." "F": "Jan."
} }
], ],
"Jh.": [
{
"F": "Jh."
}
],
"Jhd.": [ "Jhd.": [
{ {
"F": "Jhd." "F": "Jhd."
} }
], ],
"Jr.": [
{
"F": "Jr."
}
],
"Jul.": [ "Jul.": [
{ {
"F": "Jul." "F": "Jul."
@ -412,21 +459,61 @@
"F": "Jun." "F": "Jun."
} }
], ],
"K.O.": [
{
"F": "K.O."
}
],
"L.A.": [
{
"F": "L.A."
}
],
"M.A.": [
{
"F": "M.A."
}
],
"M.Sc.": [
{
"F": "M.Sc."
}
],
"Mi.": [ "Mi.": [
{ {
"F": "Mi." "F": "Mi."
} }
], ],
"Mio.": [
{
"F": "Mio."
}
],
"Mo.": [ "Mo.": [
{ {
"F": "Mo." "F": "Mo."
} }
], ],
"Mr.": [
{
"F": "Mr."
}
],
"Mrd.": [
{
"F": "Mrd."
}
],
"Mrz.": [ "Mrz.": [
{ {
"F": "Mrz." "F": "Mrz."
} }
], ],
"MwSt.": [
{
"F": "MwSt."
}
],
"M\u00e4r.": [ "M\u00e4r.": [
{ {
"F": "M\u00e4r." "F": "M\u00e4r."
@ -452,16 +539,31 @@
"F": "Nr." "F": "Nr."
} }
], ],
"O.K.": [
{
"F": "O.K."
}
],
"Okt.": [ "Okt.": [
{ {
"F": "Okt." "F": "Okt."
} }
], ],
"Orig.": [
{
"F": "Orig."
}
],
"P.S.": [ "P.S.": [
{ {
"F": "P.S." "F": "P.S."
} }
], ],
"Pkt.": [
{
"F": "Pkt."
}
],
"Prof.": [ "Prof.": [
{ {
"F": "Prof." "F": "Prof."
@ -472,6 +574,11 @@
"F": "R.I.P." "F": "R.I.P."
} }
], ],
"Red.": [
{
"F": "Red."
}
],
"S'": [ "S'": [
{ {
"F": "S'", "F": "S'",
@ -503,6 +610,41 @@
"F": "St." "F": "St."
} }
], ],
"Std.": [
{
"F": "Std."
}
],
"Str.": [
{
"F": "Str."
}
],
"Tel.": [
{
"F": "Tel."
}
],
"Tsd.": [
{
"F": "Tsd."
}
],
"U.S.": [
{
"F": "U.S."
}
],
"U.S.A.": [
{
"F": "U.S.A."
}
],
"U.S.S.": [
{
"F": "U.S.S."
}
],
"Univ.": [ "Univ.": [
{ {
"F": "Univ." "F": "Univ."
@ -513,6 +655,30 @@
"F": "V_V" "F": "V_V"
} }
], ],
"Vol.": [
{
"F": "Vol."
}
],
"\\\")": [
{
"F": "\\\")"
}
],
"\\n": [
{
"F": "\\n",
"L": "<nl>",
"pos": "SP"
}
],
"\\t": [
{
"F": "\\t",
"L": "<tab>",
"pos": "SP"
}
],
"^_^": [ "^_^": [
{ {
"F": "^_^" "F": "^_^"
@ -528,6 +694,11 @@
"F": "a.D." "F": "a.D."
} }
], ],
"a.M.": [
{
"F": "a.M."
}
],
"a.Z.": [ "a.Z.": [
{ {
"F": "a.Z." "F": "a.Z."
@ -548,9 +719,15 @@
"F": "al." "F": "al."
} }
], ],
"allg.": [
{
"F": "allg."
}
],
"auf'm": [ "auf'm": [
{ {
"F": "auf" "F": "auf",
"L": "auf"
}, },
{ {
"F": "'m", "F": "'m",
@ -572,11 +749,31 @@
"F": "biol." "F": "biol."
} }
], ],
"bspw.": [
{
"F": "bspw."
}
],
"bzgl.": [
{
"F": "bzgl."
}
],
"bzw.": [
{
"F": "bzw."
}
],
"c.": [ "c.": [
{ {
"F": "c." "F": "c."
} }
], ],
"ca.": [
{
"F": "ca."
}
],
"co.": [ "co.": [
{ {
"F": "co." "F": "co."
@ -587,9 +784,20 @@
"F": "d." "F": "d."
} }
], ],
"d.h.": [
{
"F": "d.h."
}
],
"dgl.": [
{
"F": "dgl."
}
],
"du's": [ "du's": [
{ {
"F": "du" "F": "du",
"L": "du"
}, },
{ {
"F": "'s", "F": "'s",
@ -611,19 +819,35 @@
"F": "e.g." "F": "e.g."
} }
], ],
"ebd.": [
{
"F": "ebd."
}
],
"ehem.": [ "ehem.": [
{ {
"F": "ehem." "F": "ehem."
} }
], ],
"eigtl.": [
{
"F": "eigtl."
}
],
"engl.": [ "engl.": [
{ {
"F": "engl." "F": "engl."
} }
], ],
"entspr.": [
{
"F": "entspr."
}
],
"er's": [ "er's": [
{ {
"F": "er" "F": "er",
"L": "er"
}, },
{ {
"F": "'s", "F": "'s",
@ -640,11 +864,26 @@
"F": "etc." "F": "etc."
} }
], ],
"ev.": [
{
"F": "ev."
}
],
"evtl.": [
{
"F": "evtl."
}
],
"f.": [ "f.": [
{ {
"F": "f." "F": "f."
} }
], ],
"frz.": [
{
"F": "frz."
}
],
"g.": [ "g.": [
{ {
"F": "g." "F": "g."
@ -660,6 +899,11 @@
"F": "gegr." "F": "gegr."
} }
], ],
"gem.": [
{
"F": "gem."
}
],
"ggf.": [ "ggf.": [
{ {
"F": "ggf." "F": "ggf."
@ -687,23 +931,39 @@
], ],
"hinter'm": [ "hinter'm": [
{ {
"F": "hinter" "F": "hinter",
"L": "hinter"
}, },
{ {
"F": "'m", "F": "'m",
"L": "dem" "L": "dem"
} }
], ],
"hrsg.": [
{
"F": "hrsg."
}
],
"i.": [ "i.": [
{ {
"F": "i." "F": "i."
} }
], ],
"i.A.": [
{
"F": "i.A."
}
],
"i.G.": [ "i.G.": [
{ {
"F": "i.G." "F": "i.G."
} }
], ],
"i.O.": [
{
"F": "i.O."
}
],
"i.Tr.": [ "i.Tr.": [
{ {
"F": "i.Tr." "F": "i.Tr."
@ -714,6 +974,11 @@
"F": "i.V." "F": "i.V."
} }
], ],
"i.d.R.": [
{
"F": "i.d.R."
}
],
"i.e.": [ "i.e.": [
{ {
"F": "i.e." "F": "i.e."
@ -721,7 +986,8 @@
], ],
"ich's": [ "ich's": [
{ {
"F": "ich" "F": "ich",
"L": "ich"
}, },
{ {
"F": "'s", "F": "'s",
@ -730,7 +996,8 @@
], ],
"ihr's": [ "ihr's": [
{ {
"F": "ihr" "F": "ihr",
"L": "ihr"
}, },
{ {
"F": "'s", "F": "'s",
@ -757,6 +1024,11 @@
"F": "j." "F": "j."
} }
], ],
"jr.": [
{
"F": "jr."
}
],
"jun.": [ "jun.": [
{ {
"F": "jun." "F": "jun."
@ -772,11 +1044,21 @@
"F": "k." "F": "k."
} }
], ],
"kath.": [
{
"F": "kath."
}
],
"l.": [ "l.": [
{ {
"F": "l." "F": "l."
} }
], ],
"lat.": [
{
"F": "lat."
}
],
"lt.": [ "lt.": [
{ {
"F": "lt." "F": "lt."
@ -787,11 +1069,46 @@
"F": "m." "F": "m."
} }
], ],
"m.E.": [
{
"F": "m.E."
}
],
"m.M.": [
{
"F": "m.M."
}
],
"max.": [
{
"F": "max."
}
],
"min.": [
{
"F": "min."
}
],
"mind.": [
{
"F": "mind."
}
],
"mtl.": [
{
"F": "mtl."
}
],
"n.": [ "n.": [
{ {
"F": "n." "F": "n."
} }
], ],
"n.Chr.": [
{
"F": "n.Chr."
}
],
"nat.": [ "nat.": [
{ {
"F": "nat." "F": "nat."
@ -807,6 +1124,31 @@
"F": "o.O" "F": "o.O"
} }
], ],
"o.a.": [
{
"F": "o.a."
}
],
"o.g.": [
{
"F": "o.g."
}
],
"o.k.": [
{
"F": "o.k."
}
],
"o.\u00c4.": [
{
"F": "o.\u00c4."
}
],
"o.\u00e4.": [
{
"F": "o.\u00e4."
}
],
"o_O": [ "o_O": [
{ {
"F": "o_O" "F": "o_O"
@ -817,6 +1159,11 @@
"F": "o_o" "F": "o_o"
} }
], ],
"orig.": [
{
"F": "orig."
}
],
"p.": [ "p.": [
{ {
"F": "p." "F": "p."
@ -827,6 +1174,21 @@
"F": "p.a." "F": "p.a."
} }
], ],
"p.s.": [
{
"F": "p.s."
}
],
"pers.": [
{
"F": "pers."
}
],
"phil.": [
{
"F": "phil."
}
],
"q.": [ "q.": [
{ {
"F": "q." "F": "q."
@ -847,6 +1209,11 @@
"F": "rer." "F": "rer."
} }
], ],
"r\u00f6m.": [
{
"F": "r\u00f6m."
}
],
"s'": [ "s'": [
{ {
"F": "s'", "F": "s'",
@ -858,6 +1225,11 @@
"F": "s." "F": "s."
} }
], ],
"s.o.": [
{
"F": "s.o."
}
],
"sen.": [ "sen.": [
{ {
"F": "sen." "F": "sen."
@ -865,23 +1237,49 @@
], ],
"sie's": [ "sie's": [
{ {
"F": "sie" "F": "sie",
"L": "sie"
}, },
{ {
"F": "'s", "F": "'s",
"L": "es" "L": "es"
} }
], ],
"sog.": [
{
"F": "sog."
}
],
"std.": [
{
"F": "std."
}
],
"stellv.": [
{
"F": "stellv."
}
],
"t.": [ "t.": [
{ {
"F": "t." "F": "t."
} }
], ],
"t\u00e4gl.": [
{
"F": "t\u00e4gl."
}
],
"u.": [ "u.": [
{ {
"F": "u." "F": "u."
} }
], ],
"u.U.": [
{
"F": "u.U."
}
],
"u.a.": [ "u.a.": [
{ {
"F": "u.a." "F": "u.a."
@ -892,28 +1290,75 @@
"F": "u.s.w." "F": "u.s.w."
} }
], ],
"u.v.m.": [
{
"F": "u.v.m."
}
],
"unter'm": [ "unter'm": [
{ {
"F": "unter" "F": "unter",
"L": "unter"
}, },
{ {
"F": "'m", "F": "'m",
"L": "dem" "L": "dem"
} }
], ],
"usf.": [
{
"F": "usf."
}
],
"usw.": [
{
"F": "usw."
}
],
"uvm.": [
{
"F": "uvm."
}
],
"v.": [ "v.": [
{ {
"F": "v." "F": "v."
} }
], ],
"v.Chr.": [
{
"F": "v.Chr."
}
],
"v.a.": [
{
"F": "v.a."
}
],
"v.l.n.r.": [
{
"F": "v.l.n.r."
}
],
"vgl.": [ "vgl.": [
{ {
"F": "vgl." "F": "vgl."
} }
], ],
"vllt.": [
{
"F": "vllt."
}
],
"vlt.": [
{
"F": "vlt."
}
],
"vor'm": [ "vor'm": [
{ {
"F": "vor" "F": "vor",
"L": "vor"
}, },
{ {
"F": "'m", "F": "'m",
@ -932,13 +1377,19 @@
], ],
"wir's": [ "wir's": [
{ {
"F": "wir" "F": "wir",
"L": "wir"
}, },
{ {
"F": "'s", "F": "'s",
"L": "es" "L": "es"
} }
], ],
"wiss.": [
{
"F": "wiss."
}
],
"x.": [ "x.": [
{ {
"F": "x." "F": "x."
@ -969,19 +1420,60 @@
"F": "z.B." "F": "z.B."
} }
], ],
"z.Bsp.": [
{
"F": "z.Bsp."
}
],
"z.T.": [
{
"F": "z.T."
}
],
"z.Z.": [ "z.Z.": [
{ {
"F": "z.Z." "F": "z.Z."
} }
], ],
"z.Zt.": [
{
"F": "z.Zt."
}
],
"z.b.": [
{
"F": "z.b."
}
],
"zzgl.": [ "zzgl.": [
{ {
"F": "zzgl." "F": "zzgl."
} }
], ],
"\u00e4.": [
{
"F": "\u00e4."
}
],
"\u00f6.": [
{
"F": "\u00f6."
}
],
"\u00f6sterr.": [
{
"F": "\u00f6sterr."
}
],
"\u00fc.": [
{
"F": "\u00fc."
}
],
"\u00fcber'm": [ "\u00fcber'm": [
{ {
"F": "\u00fcber" "F": "\u00fcber",
"L": "\u00fcber"
}, },
{ {
"F": "'m", "F": "'m",

View File

@@ -13,14 +13,61 @@
 ;
 '
+«
+_
 ''
 's
 'S
 s
 S
+°
 \.\.
 \.\.\.
 \.\.\.\.
-(?<=[a-z0-9)\]"'%\)])\.
+(?<=[a-zäöüßÖÄÜ)\]"'´«‘’%\)²“”])\.
+\-\-
+´
+(?<=[0-9])km²
+(?<=[0-9])m²
+(?<=[0-9])cm²
+(?<=[0-9])mm²
+(?<=[0-9])km³
+(?<=[0-9])m³
+(?<=[0-9])cm³
+(?<=[0-9])mm³
+(?<=[0-9])ha
 (?<=[0-9])km
+(?<=[0-9])m
+(?<=[0-9])cm
+(?<=[0-9])mm
+(?<=[0-9])µm
+(?<=[0-9])nm
+(?<=[0-9])yd
+(?<=[0-9])in
+(?<=[0-9])ft
+(?<=[0-9])kg
+(?<=[0-9])g
+(?<=[0-9])mg
+(?<=[0-9])µg
+(?<=[0-9])t
+(?<=[0-9])lb
+(?<=[0-9])oz
+(?<=[0-9])m/s
+(?<=[0-9])km/h
+(?<=[0-9])mph
+(?<=[0-9])°C
+(?<=[0-9])°K
+(?<=[0-9])°F
+(?<=[0-9])hPa
+(?<=[0-9])Pa
+(?<=[0-9])mbar
+(?<=[0-9])mb
+(?<=[0-9])T
+(?<=[0-9])G
+(?<=[0-9])M
+(?<=[0-9])K
+(?<=[0-9])kb
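
Likewise a regex-level illustration of the new unit suffixes; the tokenizer applies these as suffix rules, here we only check the raw pattern:

import re

km = re.compile(u'(?<=[0-9])km')
assert km.search(u'100km') is not None   # unit preceded by a digit
assert km.search(u'Autokm') is None      # no digit before, so no suffix split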

View File

@@ -47,6 +47,7 @@ MOD_NAMES = [
     'spacy.syntax._state',
     'spacy.tokenizer',
     'spacy.syntax.parser',
+    'spacy.syntax.nonproj',
     'spacy.syntax.transition_system',
     'spacy.syntax.arc_eager',
     'spacy.syntax._parse_features',

View File

@@ -6,4 +6,4 @@ from ..language import Language

 class German(Language):
-    pass
+    lang = 'de'

View File

@@ -14,6 +14,8 @@ try:
 except ImportError:
     import json

+from .syntax import nonproj
+
 def tags_to_entities(tags):
     entities = []
@@ -237,33 +239,13 @@ cdef class GoldParse:
                     self.labels[i] = annot_tuples[4][gold_i]
                     self.ner[i] = annot_tuples[5][gold_i]

-        # If we have any non-projective arcs, i.e. crossing brackets, consider
-        # the heads for those words missing in the gold-standard.
-        # This way, we can train from these sentences
-        cdef int w1, w2, h1, h2
-        if make_projective:
-            heads = list(self.heads)
-            for w1 in range(self.length):
-                if heads[w1] is not None:
-                    h1 = heads[w1]
-                    for w2 in range(w1+1, self.length):
-                        if heads[w2] is not None:
-                            h2 = heads[w2]
-                            if _arcs_cross(w1, h1, w2, h2):
-                                self.heads[w1] = None
-                                self.labels[w1] = ''
-                                self.heads[w2] = None
-                                self.labels[w2] = ''
-        # Check there are no cycles in the dependencies, i.e. we are a tree
-        for w in range(self.length):
-            seen = set([w])
-            head = w
-            while self.heads[head] != head and self.heads[head] != None:
-                head = self.heads[head]
-                if head in seen:
-                    raise Exception("Cycle found: %s" % seen)
-                seen.add(head)
+        cycle = nonproj.contains_cycle(self.heads)
+        if cycle != None:
+            raise Exception("Cycle found: %s" % cycle)
+
+        if make_projective:
+            proj_heads,_ = nonproj.PseudoProjectivity.projectivize(self.heads,self.labels)
+            self.heads = proj_heads

         self.brackets = {}
         for (gold_start, gold_end, label_str) in brackets:
@@ -278,25 +260,18 @@ cdef class GoldParse:
     @property
     def is_projective(self):
-        heads = list(self.heads)
-        for w1 in range(self.length):
-            if heads[w1] is not None:
-                h1 = heads[w1]
-                for w2 in range(self.length):
-                    if heads[w2] is not None and _arcs_cross(w1, h1, w2, heads[w2]):
-                        return False
-        return True
+        return not nonproj.is_nonproj_tree(self.heads)

-cdef int _arcs_cross(int w1, int h1, int w2, int h2) except -1:
-    if w1 > h1:
-        w1, h1 = h1, w1
-    if w2 > h2:
-        w2, h2 = h2, w2
-    if w1 > w2:
-        w1, h1, w2, h2 = w2, h2, w1, h1
-    return w1 < w2 < h1 < h2 or w1 < w2 == h2 < h1

 def is_punct_label(label):
     return label == 'P' or label.lower() == 'punct'
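
The new sanity checks in GoldParse reduce to these two helpers; the example trees are the same ones used in the test file further down:

from spacy.syntax.nonproj import contains_cycle, is_nonproj_tree

assert contains_cycle([1, 2, 2, 4, 5, 3, 2]) == set([3, 4, 5])   # GoldParse raises on this
assert contains_cycle([1, 2, 2, 4, 5, 2, 2]) is None
assert is_nonproj_tree([1, 2, 2, 4, 5, 2, 7, 4, 2]) is True      # so is_projective is False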

spacy/syntax/nonproj.pxd (new file, empty)

spacy/syntax/nonproj.pyx (new file, 200 lines)

@ -0,0 +1,200 @@
from copy import copy
from collections import Counter
from ..tokens.doc cimport Doc
from spacy.attrs import DEP, HEAD
def ancestors(tokenid, heads):
# returns all words going from the word up the path to the root
# the path to root cannot be longer than the number of words in the sentence
# this function ends after at most len(heads) steps
# because it would otherwise loop indefinitely on cycles
head = tokenid
cnt = 0
while heads[head] != head and cnt < len(heads):
head = heads[head]
cnt += 1
yield head
if head == None:
break
def contains_cycle(heads):
# in an acyclic tree, the path from each word following
# the head relation upwards always ends at the root node
for tokenid in range(len(heads)):
seen = set([tokenid])
for ancestor in ancestors(tokenid,heads):
if ancestor in seen:
return seen
seen.add(ancestor)
return None
def is_nonproj_arc(tokenid, heads):
# definition (e.g. Havelka 2007): an arc h -> d, h < d is non-projective
# if there is a token k, h < k < d such that h is not
# an ancestor of k. Same for h -> d, h > d
head = heads[tokenid]
if head == tokenid: # root arcs cannot be non-projective
return False
elif head == None: # unattached tokens cannot be non-projective
return False
start, end = (head+1, tokenid) if head < tokenid else (tokenid+1, head)
for k in range(start,end):
for ancestor in ancestors(k,heads):
if ancestor == None: # for unattached tokens/subtrees
break
elif ancestor == head: # normal case: k dominated by h
break
else: # head not in ancestors: d -> h is non-projective
return True
return False
def is_nonproj_tree(heads):
# a tree is non-projective if at least one arc is non-projective
return any( is_nonproj_arc(word,heads) for word in range(len(heads)) )
cdef class PseudoProjectivity:
# implements the projectivize/deprojectivize mechanism in Nivre & Nilsson 2005
# for doing pseudo-projective parsing
# implementation uses the HEAD decoration scheme
delimiter = '||'
@classmethod
def decompose(cls, label):
return label.partition(cls.delimiter)[::2]
@classmethod
def is_decorated(cls, label):
return label.find(cls.delimiter) != -1
@classmethod
def preprocess_training_data(cls, gold_tuples, label_freq_cutoff=30):
preprocessed = []
freqs = Counter()
for raw_text, sents in gold_tuples:
prepro_sents = []
for (ids, words, tags, heads, labels, iob), ctnts in sents:
proj_heads,deco_labels = cls.projectivize(heads,labels)
# set the label to ROOT for each root dependent
deco_labels = [ 'ROOT' if head == i else deco_labels[i] for i,head in enumerate(proj_heads) ]
# count label frequencies
if label_freq_cutoff > 0:
freqs.update( label for label in deco_labels if cls.is_decorated(label) )
prepro_sents.append(((ids,words,tags,proj_heads,deco_labels,iob), ctnts))
preprocessed.append((raw_text, prepro_sents))
if label_freq_cutoff > 0:
return cls._filter_labels(preprocessed,label_freq_cutoff,freqs)
return preprocessed
@classmethod
def projectivize(cls, heads, labels):
# use the algorithm by Nivre & Nilsson 2005
# assumes heads to be a proper tree, i.e. connected and cycle-free
# returns a new pair (heads,labels) which encode
# a projective and decorated tree
proj_heads = copy(heads)
smallest_np_arc = cls._get_smallest_nonproj_arc(proj_heads)
if smallest_np_arc == None: # this sentence is already projective
return proj_heads, copy(labels)
while smallest_np_arc != None:
cls._lift(smallest_np_arc, proj_heads)
smallest_np_arc = cls._get_smallest_nonproj_arc(proj_heads)
deco_labels = cls._decorate(heads, proj_heads, labels)
return proj_heads, deco_labels
@classmethod
def deprojectivize(cls, Doc tokens):
# reattach arcs with decorated labels (following HEAD scheme)
# for each decorated arc X||Y, search top-down, left-to-right,
# breadth-first until hitting a Y then make this the new head
parse = tokens.to_array([HEAD, DEP])
labels = [ tokens.vocab.strings[int(p[1])] for p in parse ]
for token in tokens:
if cls.is_decorated(token.dep_):
newlabel,headlabel = cls.decompose(token.dep_)
newhead = cls._find_new_head(token,headlabel)
parse[token.i,1] = tokens.vocab.strings[newlabel]
parse[token.i,0] = newhead.i - token.i
tokens.from_array([HEAD, DEP],parse)
@classmethod
def _decorate(cls, heads, proj_heads, labels):
# uses decoration scheme HEAD from Nivre & Nilsson 2005
assert(len(heads) == len(proj_heads) == len(labels))
deco_labels = []
for tokenid,head in enumerate(heads):
if head != proj_heads[tokenid]:
deco_labels.append('%s%s%s' % (labels[tokenid],cls.delimiter,labels[head]))
else:
deco_labels.append(labels[tokenid])
return deco_labels
@classmethod
def _get_smallest_nonproj_arc(cls, heads):
# return the smallest non-proj arc or None
# where size is defined as the distance between dep and head
# and ties are broken left to right
smallest_size = float('inf')
smallest_np_arc = None
for tokenid,head in enumerate(heads):
size = abs(tokenid-head)
if size < smallest_size and is_nonproj_arc(tokenid,heads):
smallest_size = size
smallest_np_arc = tokenid
return smallest_np_arc
@classmethod
def _lift(cls, tokenid, heads):
# reattaches a word to it's grandfather
head = heads[tokenid]
ghead = heads[head]
# attach to ghead if head isn't attached to root else attach to root
heads[tokenid] = ghead if head != ghead else tokenid
@classmethod
def _find_new_head(cls, token, headlabel):
# search through the tree starting from root
# returns the id of the first descendant with the given label
# if there is none, return the current head (no change)
queue = [token.head]
while queue:
next_queue = []
for qtoken in queue:
for child in qtoken.children:
if child == token:
continue
if child.dep_ == headlabel:
return child
next_queue.append(child)
queue = next_queue
return token.head
@classmethod
def _filter_labels(cls, gold_tuples, cutoff, freqs):
# throw away infrequent decorated labels
# can't learn them reliably anyway and keeps label set smaller
filtered = []
for raw_text, sents in gold_tuples:
filtered_sents = []
for (ids, words, tags, heads, labels, iob), ctnts in sents:
filtered_labels = [ cls.decompose(label)[0] if freqs.get(label,cutoff) < cutoff else label for label in labels ]
filtered_sents.append(((ids,words,tags,heads,filtered_labels,iob), ctnts))
filtered.append((raw_text, filtered_sents))
return filtered
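
A worked example of the projectivize step, with numbers taken from the test file below: token 7 has head 4 and label 'acl' and is non-projective, so it is lifted to token 5 and its label is decorated with the old head's label under the HEAD scheme:

from spacy.syntax.nonproj import PseudoProjectivity

heads  = [1, 2, 2, 4, 5, 2, 7, 4, 2]
labels = ['det', 'nsubj', 'root', 'det', 'dobj', 'aux', 'nsubj', 'acl', 'punct']
proj_heads, deco_labels = PseudoProjectivity.projectivize(heads, labels)
assert proj_heads == [1, 2, 2, 4, 5, 2, 7, 5, 2]
assert deco_labels == ['det', 'nsubj', 'root', 'det', 'dobj', 'aux',
                       'nsubj', 'acl||dobj', 'punct']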

View File

@@ -15,5 +15,6 @@ cdef class ParserModel(AveragedPerceptron):
 cdef class Parser:
     cdef readonly ParserModel model
     cdef readonly TransitionSystem moves
+    cdef int _projectivize

     cdef int parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) nogil

View File

@@ -12,12 +12,12 @@ from cpython.exc cimport PyErr_CheckSignals
 from libc.stdint cimport uint32_t, uint64_t
 from libc.string cimport memset, memcpy
 from libc.stdlib cimport malloc, calloc, free
-import random
 import os.path
 from os import path
 import shutil
 import json
 import sys
+from .nonproj import PseudoProjectivity

 from cymem.cymem cimport Pool, Address
 from murmurhash.mrmr cimport hash64
@@ -79,9 +79,10 @@ cdef class ParserModel(AveragedPerceptron):

 cdef class Parser:
-    def __init__(self, StringStore strings, transition_system, ParserModel model):
+    def __init__(self, StringStore strings, transition_system, ParserModel model, int projectivize = 0):
         self.moves = transition_system
         self.model = model
+        self._projectivize = projectivize

     @classmethod
     def from_dir(cls, model_dir, strings, transition_system):
@@ -93,9 +94,10 @@ cdef class Parser:
         moves = transition_system(strings, cfg.labels)
         templates = get_templates(cfg.features)
         model = ParserModel(templates)
+        project = cfg.projectivize if hasattr(cfg,'projectivize') else False
         if path.exists(path.join(model_dir, 'model')):
             model.load(path.join(model_dir, 'model'))
-        return cls(strings, moves, model)
+        return cls(strings, moves, model, project)

     @classmethod
     def load(cls, pkg_or_str_or_file, vocab):
@@ -114,6 +116,9 @@ cdef class Parser:
         tokens.is_parsed = True
         # Check for KeyboardInterrupt etc. Untested
         PyErr_CheckSignals()
+        # projectivize output
+        if self._projectivize:
+            PseudoProjectivity.deprojectivize(tokens)

     def pipe(self, stream, int batch_size=1000, int n_threads=2):
         cdef Pool mem = Pool()
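
Sketch of how the flag travels end to end: training writes projectivize into the dep config, Parser.from_dir() reads it back with a hasattr() fallback for older models, and the parser then deprojectivizes its own output. The fallback behaves like this (the Cfg classes below are stand-ins, not spaCy objects):

class NewCfg(object):      # stand-in for a config written by the updated trainer
    projectivize = True

class OldCfg(object):      # stand-in for a config from an older model directory
    pass

for cfg in (NewCfg(), OldCfg()):
    project = cfg.projectivize if hasattr(cfg, 'projectivize') else False
    print(type(cfg).__name__, project)   # NewCfg True / OldCfg False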

View File

@@ -143,7 +143,7 @@ cdef class Tagger:
     @classmethod
     def blank(cls, vocab, templates):
-        model = TaggerModel(N_CONTEXT_FIELDS, templates)
+        model = TaggerModel(templates)
         return cls(vocab, model)

     @classmethod
@@ -153,10 +153,8 @@ cdef class Tagger:
     @classmethod
     def from_package(cls, pkg, vocab):
         # TODO: templates.json deprecated? not present in latest package
-        templates = cls.default_templates()
-        # templates = package.load_utf8(json.load,
-        #     'pos', 'templates.json',
-        #     default=cls.default_templates())
+        # templates = cls.default_templates()
+        templates = pkg.load_json(('pos', 'templates.json'), default=cls.default_templates())

         model = TaggerModel(templates)
         if pkg.has_file('pos', 'model'):
@@ -221,7 +219,7 @@ cdef class Tagger:
     def train(self, Doc tokens, object gold_tag_strs):
         assert len(tokens) == len(gold_tag_strs)
         for tag in gold_tag_strs:
-            if tag not in self.tag_names:
+            if tag != None and tag not in self.tag_names:
                 msg = ("Unrecognized gold tag: %s. tag_map.json must contain all"
                        "gold tags, to maintain coarse-grained mapping.")
                 raise ValueError(msg % tag)
@@ -234,10 +232,9 @@ cdef class Tagger:
                                     nr_feat=self.model.nr_feat)
         for i in range(tokens.length):
             self.model.set_featuresC(&eg.c, tokens.c, i)
-            eg.set_label(golds[i])
+            eg.costs = [ 1 if golds[i] not in (c, -1) else 0 for c in xrange(eg.nr_class) ]
             self.model.set_scoresC(eg.c.scores,
                 eg.c.features, eg.c.nr_feat)
             self.model.updateC(&eg.c)
             self.vocab.morphology.assign_tag(&tokens.c[i], eg.guess)
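
What the new cost assignment does per token: cost 0 for the gold class, cost 1 for every other class, and all-zero costs when the gold tag is missing (encoded as -1); a plain-Python rendering of the same list comprehension:

def tag_costs(gold, nr_class):
    return [1 if gold not in (c, -1) else 0 for c in range(nr_class)]

assert tag_costs(2, 5) == [1, 1, 0, 1, 1]    # only the gold class is cost-free
assert tag_costs(-1, 5) == [0, 0, 0, 0, 0]   # missing gold tag: nothing is penalized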

View File

@ -0,0 +1,136 @@
from __future__ import unicode_literals
import pytest
from spacy.attrs import DEP, HEAD
import numpy
from spacy.syntax.nonproj import ancestors, contains_cycle, is_nonproj_arc, is_nonproj_tree, PseudoProjectivity
def test_ancestors():
tree = [1,2,2,4,5,2,2]
cyclic_tree = [1,2,2,4,5,3,2]
partial_tree = [1,2,2,4,5,None,2]
multirooted_tree = [3,2,0,3,3,7,7,3,7,10,7,10,11,12,18,16,18,17,12,3]
assert([ a for a in ancestors(3,tree) ] == [4,5,2])
assert([ a for a in ancestors(3,cyclic_tree) ] == [4,5,3,4,5,3,4])
assert([ a for a in ancestors(3,partial_tree) ] == [4,5,None])
assert([ a for a in ancestors(17,multirooted_tree) ] == [])
def test_contains_cycle():
tree = [1,2,2,4,5,2,2]
cyclic_tree = [1,2,2,4,5,3,2]
partial_tree = [1,2,2,4,5,None,2]
multirooted_tree = [3,2,0,3,3,7,7,3,7,10,7,10,11,12,18,16,18,17,12,3]
assert(contains_cycle(tree) == None)
assert(contains_cycle(cyclic_tree) == set([3,4,5]))
assert(contains_cycle(partial_tree) == None)
assert(contains_cycle(multirooted_tree) == None)
def test_is_nonproj_arc():
nonproj_tree = [1,2,2,4,5,2,7,4,2]
partial_tree = [1,2,2,4,5,None,7,4,2]
multirooted_tree = [3,2,0,3,3,7,7,3,7,10,7,10,11,12,18,16,18,17,12,3]
assert(is_nonproj_arc(0,nonproj_tree) == False)
assert(is_nonproj_arc(1,nonproj_tree) == False)
assert(is_nonproj_arc(2,nonproj_tree) == False)
assert(is_nonproj_arc(3,nonproj_tree) == False)
assert(is_nonproj_arc(4,nonproj_tree) == False)
assert(is_nonproj_arc(5,nonproj_tree) == False)
assert(is_nonproj_arc(6,nonproj_tree) == False)
assert(is_nonproj_arc(7,nonproj_tree) == True)
assert(is_nonproj_arc(8,nonproj_tree) == False)
assert(is_nonproj_arc(7,partial_tree) == False)
assert(is_nonproj_arc(17,multirooted_tree) == False)
assert(is_nonproj_arc(16,multirooted_tree) == True)
def test_is_nonproj_tree():
proj_tree = [1,2,2,4,5,2,7,5,2]
nonproj_tree = [1,2,2,4,5,2,7,4,2]
partial_tree = [1,2,2,4,5,None,7,4,2]
multirooted_tree = [3,2,0,3,3,7,7,3,7,10,7,10,11,12,18,16,18,17,12,3]
assert(is_nonproj_tree(proj_tree) == False)
assert(is_nonproj_tree(nonproj_tree) == True)
assert(is_nonproj_tree(partial_tree) == False)
assert(is_nonproj_tree(multirooted_tree) == True)
def deprojectivize(proj_heads, deco_labels, EN):
slen = len(proj_heads)
sent = EN.tokenizer.tokens_from_list(['whatever'] * slen)
rel_proj_heads = [ head-i for i,head in enumerate(proj_heads) ]
labelids = [ EN.vocab.strings[label] for label in deco_labels ]
parse = numpy.asarray(zip(rel_proj_heads,labelids), dtype=numpy.int32)
sent.from_array([HEAD,DEP],parse)
PseudoProjectivity.deprojectivize(sent)
parse = sent.to_array([HEAD,DEP])
deproj_heads = [ i+head for i,head in enumerate(parse[:,0]) ]
undeco_labels = [ EN.vocab.strings[int(labelid)] for labelid in parse[:,1] ]
return deproj_heads, undeco_labels
@pytest.mark.models
def test_pseudoprojectivity(EN):
tree = [1,2,2]
nonproj_tree = [1,2,2,4,5,2,7,4,2]
labels = ['det','nsubj','root','det','dobj','aux','nsubj','acl','punct']
nonproj_tree2 = [9,1,3,1,5,6,9,8,6,1,6,12,13,10,1]
labels2 = ['advmod','root','det','nsubj','advmod','det','dobj','det','nmod','aux','nmod','advmod','det','amod','punct']
assert(PseudoProjectivity.decompose('X||Y') == ('X','Y'))
assert(PseudoProjectivity.decompose('X') == ('X',''))
assert(PseudoProjectivity.is_decorated('X||Y') == True)
assert(PseudoProjectivity.is_decorated('X') == False)
PseudoProjectivity._lift(0,tree)
assert(tree == [2,2,2])
np_arc = PseudoProjectivity._get_smallest_nonproj_arc(nonproj_tree)
assert(np_arc == 7)
np_arc = PseudoProjectivity._get_smallest_nonproj_arc(nonproj_tree2)
assert(np_arc == 10)
proj_heads, deco_labels = PseudoProjectivity.projectivize(nonproj_tree,labels)
assert(proj_heads == [1,2,2,4,5,2,7,5,2])
assert(deco_labels == ['det','nsubj','root','det','dobj','aux','nsubj','acl||dobj','punct'])
deproj_heads, undeco_labels = deprojectivize(proj_heads,deco_labels,EN)
assert(deproj_heads == nonproj_tree)
assert(undeco_labels == labels)
proj_heads, deco_labels = PseudoProjectivity.projectivize(nonproj_tree2,labels2)
assert(proj_heads == [1,1,3,1,5,6,9,8,6,1,9,12,13,10,1])
assert(deco_labels == ['advmod||aux','root','det','nsubj','advmod','det','dobj','det','nmod','aux','nmod||dobj','advmod','det','amod','punct'])
deproj_heads, undeco_labels = deprojectivize(proj_heads,deco_labels,EN)
assert(deproj_heads == nonproj_tree2)
assert(undeco_labels == labels2)
# if decoration is wrong such that there is no head with the desired label
# the structure is kept and the label is undecorated
proj_heads = [1,2,2,4,5,2,7,5,2]
deco_labels = ['det','nsubj','root','det','dobj','aux','nsubj','acl||iobj','punct']
deproj_heads, undeco_labels = deprojectivize(proj_heads,deco_labels,EN)
assert(deproj_heads == proj_heads)
assert(undeco_labels == ['det','nsubj','root','det','dobj','aux','nsubj','acl','punct'])
# if there are two potential new heads, the first one is chosen even if it's wrong
proj_heads = [1,1,3,1,5,6,9,8,6,1,9,12,13,10,1]
deco_labels = ['advmod||aux','root','det','aux','advmod','det','dobj','det','nmod','aux','nmod||dobj','advmod','det','amod','punct']
deproj_heads, undeco_labels = deprojectivize(proj_heads,deco_labels,EN)
assert(deproj_heads == [3,1,3,1,5,6,9,8,6,1,6,12,13,10,1])
assert(undeco_labels == ['advmod','root','det','aux','advmod','det','dobj','det','nmod','aux','nmod','advmod','det','amod','punct'])

View File

@@ -201,17 +201,9 @@ cdef class Token:
         cdef int nr_iter = 0
         cdef const TokenC* ptr = self.c - (self.i - self.c.l_edge)
         while ptr < self.c:
-            # If this head is still to the right of us, we can skip to it
-            # No token that's between this token and this head could be our
-            # child.
-            if (ptr.head >= 1) and (ptr + ptr.head) < self.c:
-                ptr += ptr.head
-            elif ptr + ptr.head == self.c:
+            if ptr + ptr.head == self.c:
                 yield self.doc[ptr - (self.c - self.i)]
-                ptr += 1
-            else:
-                ptr += 1
+            ptr += 1
             nr_iter += 1
             # This is ugly, but it's a way to guard out infinite loops
             if nr_iter >= 10000000:
@@ -226,16 +218,10 @@ cdef class Token:
         tokens = []
         cdef int nr_iter = 0
         while ptr > self.c:
-            # If this head is still to the right of us, we can skip to it
-            # No token that's between this token and this head could be our
-            # child.
-            if (ptr.head < 0) and ((ptr + ptr.head) > self.c):
-                ptr += ptr.head
-            elif ptr + ptr.head == self.c:
+            if ptr + ptr.head == self.c:
                 tokens.append(self.doc[ptr - (self.c - self.i)])
-                ptr -= 1
-            else:
-                ptr -= 1
+            ptr -= 1
+            nr_iter += 1
             if nr_iter >= 10000000:
                 raise RuntimeError(
                     "Possibly infinite loop encountered while looking for token.rights")