integrated pseudo-projective parsing into parser

- nonproj.pyx holds a class PseudoProjectivity which currently provides
  all functionality for Nivre & Nilsson (2005)'s pseudo-projective
  parsing, using the HEAD decoration scheme
- changed the lefts/rights iterators in Token to account for possible
  non-projective structures
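
For orientation, a minimal usage sketch of the new mechanism. The head array and the expected outputs are taken from the test file in this commit; the import path assumes the module builds as spacy.syntax.nonproj:

from spacy.syntax.nonproj import PseudoProjectivity

# heads[i] is the index of token i's head; the root points to itself
heads  = [1,2,2,4,5,2,7,4,2]     # the arc 7 -> 4 is non-projective
labels = ['NK','SB','ROOT','NK','OA','OC','SB','RC','--']

proj_heads, deco_labels = PseudoProjectivity.projectivize(heads, labels)
# proj_heads  == [1,2,2,4,5,2,7,5,2]   (the offending arc was lifted)
# deco_labels[7] == 'RC||OA', i.e. the original label plus the label of the
# original head; this is exactly what deprojectivize() needs to restore the
# non-projective attachment after parsing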
Wolfgang Seeker 2016-03-01 10:09:08 +01:00
parent 56b7210e82
commit 3448cb40a4
8 changed files with 120 additions and 101 deletions

setup.py

@@ -47,6 +47,7 @@ MOD_NAMES = [
'spacy.syntax._state',
'spacy.tokenizer',
'spacy.syntax.parser',
'spacy.syntax.nonproj',
'spacy.syntax.transition_system',
'spacy.syntax.arc_eager',
'spacy.syntax._parse_features',

spacy/gold.pyx

@@ -14,7 +14,7 @@ try:
except ImportError:
import json
import nonproj
from .syntax import nonproj
def tags_to_entities(tags):

spacy/syntax/nonproj.pxd (new file, 0 lines)

spacy/syntax/nonproj.pyx

@@ -1,6 +1,10 @@
from copy import copy
from collections import Counter
from ..tokens.doc cimport Doc
from spacy.attrs import DEP, HEAD
def ancestors(tokenid, heads):
# returns all words going from the word up the path to the root
# the path to root cannot be longer than the number of words in the sentence
@@ -55,69 +59,90 @@ def is_nonproj_tree(heads):
return any( is_nonproj_arc(word,heads) for word in range(len(heads)) )
class PseudoProjective:
cdef class PseudoProjectivity:
# implements the projectivize/deprojectivize mechanism in Nivre & Nilsson 2005
# for doing pseudo-projective parsing
# implementation uses the HEAD decoration scheme
def preprocess_training_data(self, labeled_trees, label_freq_cutoff=30):
# expects a sequence of pairs of head arrays and labels
delimiter = '||'
@classmethod
def decompose(cls, label):
return label.partition(cls.delimiter)[::2]
@classmethod
def is_decorated(cls, label):
return label.find(cls.delimiter) != -1
@classmethod
def preprocess_training_data(cls, gold_tuples, label_freq_cutoff=30):
preprocessed = []
for heads,labels in labeled_trees:
proj_heads,deco_labels = self.projectivize(heads,labels)
# set the label to ROOT for each root dependent
deco_labels = [ 'ROOT' if head == i else deco_labels[i] for i,head in enumerate(proj_heads) ]
preprocessed.append((proj_heads,deco_labels))
freqs = Counter()
for raw_text, sents in gold_tuples:
prepro_sents = []
for (ids, words, tags, heads, labels, iob), ctnts in sents:
proj_heads,deco_labels = cls.projectivize(heads,labels)
# set the label to ROOT for each root dependent
deco_labels = [ 'ROOT' if head == i else deco_labels[i] for i,head in enumerate(proj_heads) ]
# count label frequencies
if label_freq_cutoff > 0:
freqs.update( label for label in deco_labels if cls.is_decorated(label) )
prepro_sents.append(((ids,words,tags,proj_heads,deco_labels,iob), ctnts))
preprocessed.append((raw_text, prepro_sents))
if label_freq_cutoff > 0:
return self._filter_labels(preprocessed,label_freq_cutoff)
return cls._filter_labels(preprocessed,label_freq_cutoff,freqs)
return preprocessed
def projectivize(self, heads, labels):
@classmethod
def projectivize(cls, heads, labels):
# use the algorithm by Nivre & Nilsson 2005
# assumes heads to be a proper tree, i.e. connected and cycle-free
# returns a new pair (heads,labels) which encode
# a projective and decorated tree
proj_heads = copy(heads)
smallest_np_arc = self._get_smallest_nonproj_arc(proj_heads)
smallest_np_arc = cls._get_smallest_nonproj_arc(proj_heads)
if smallest_np_arc is None: # this sentence is already projective
return proj_heads, copy(labels)
while smallest_np_arc is not None:
self._lift(smallest_np_arc, proj_heads)
smallest_np_arc = self._get_smallest_nonproj_arc(proj_heads)
deco_labels = self._decorate(heads, proj_heads, labels)
cls._lift(smallest_np_arc, proj_heads)
smallest_np_arc = cls._get_smallest_nonproj_arc(proj_heads)
deco_labels = cls._decorate(heads, proj_heads, labels)
return proj_heads, deco_labels
def deprojectivize(self, heads, labels):
@classmethod
def deprojectivize(cls, Doc tokens):
# reattach arcs with decorated labels (following HEAD scheme)
# for each decorated arc X||Y, search top-down, left-to-right,
# breadth-first until hitting a Y then make this the new head
newheads, newlabels = copy(heads), copy(labels)
spans = None
for tokenid, head in enumerate(heads):
if labels[tokenid].find('||') != -1:
newlabel,_,headlabel = labels[tokenid].partition('||')
newhead = self._find_new_head(head,tokenid,headlabel,heads,labels,spans=spans)
newheads[tokenid] = newhead
newlabels[tokenid] = newlabel
return newheads, newlabels
parse = tokens.to_array([HEAD, DEP])
labels = [ tokens.vocab.strings[int(p[1])] for p in parse ]
for token in tokens:
if cls.is_decorated(token.dep_):
newlabel,headlabel = cls.decompose(token.dep_)
newhead = cls._find_new_head(token,headlabel)
parse[token.i,1] = tokens.vocab.strings[newlabel]
parse[token.i,0] = newhead.i - token.i
tokens.from_array([HEAD, DEP],parse)
def _decorate(self, heads, proj_heads, labels):
@classmethod
def _decorate(cls, heads, proj_heads, labels):
# uses decoration scheme HEAD from Nivre & Nilsson 2005
assert(len(heads) == len(proj_heads) == len(labels))
deco_labels = []
for tokenid,head in enumerate(heads):
if head != proj_heads[tokenid]:
deco_labels.append('%s||%s' % (labels[tokenid],labels[head]))
deco_labels.append('%s%s%s' % (labels[tokenid],cls.delimiter,labels[head]))
else:
deco_labels.append(labels[tokenid])
return deco_labels
def _get_smallest_nonproj_arc(self, heads):
@classmethod
def _get_smallest_nonproj_arc(cls, heads):
# return the smallest non-proj arc or None
# where size is defined as the distance between dep and head
# and ties are broken left to right
@@ -131,7 +156,8 @@ class PseudoProjective:
return smallest_np_arc
def _lift(self, tokenid, heads):
@classmethod
def _lift(cls, tokenid, heads):
# reattaches a word to its grandfather
head = heads[tokenid]
ghead = heads[head]
@@ -139,43 +165,36 @@ class PseudoProjective:
heads[tokenid] = ghead if head != ghead else tokenid
def _find_new_head(self, rootid, tokenid, headlabel, heads, labels, spans=None):
@classmethod
def _find_new_head(cls, token, headlabel):
# search through the tree starting from root
# returns the id of the first descendant with the given label
# if there is none, return the current head (no change)
if not spans:
spans = self._make_span_index(heads)
queue = spans.get(rootid,[])
queue.remove(tokenid) # don't search in the subtree of the nonproj arc
queue = [token.head]
while queue:
next_queue = []
for idx in queue:
if labels[idx] == headlabel:
return idx
next_queue.extend(spans.get(idx,[]))
for qtoken in queue:
for child in qtoken.children:
if child == token:
continue
if child.dep_ == headlabel:
return child
next_queue.append(child)
queue = next_queue
return heads[tokenid]
return token.head
def _make_span_index(self, heads):
# stores the direct dependents for each token
# for searching top-down through a tree
spans = {}
for tokenid, head in enumerate(heads):
if tokenid == head: # root
continue
if head not in spans:
spans[head] = []
spans[head].append(tokenid)
return spans
def _filter_labels(self, labeled_trees, cutoff):
@classmethod
def _filter_labels(cls, gold_tuples, cutoff, freqs):
# throw away infrequent decorated labels
# can't learn them reliably anyway and keeps label set smaller
freqs = Counter([ label for _,labels in labeled_trees for label in labels if label.find('||') != -1 ])
filtered = []
for proj_heads,deco_labels in labeled_trees:
filtered_labels = [ label.partition('||')[0] if freqs.get(label,cutoff) < cutoff else label for label in deco_labels ]
filtered.append((proj_heads,filtered_labels))
for raw_text, sents in gold_tuples:
filtered_sents = []
for (ids, words, tags, heads, labels, iob), ctnts in sents:
filtered_labels = [ cls.decompose(label)[0] if freqs.get(label,cutoff) < cutoff else label for label in labels ]
filtered_sents.append(((ids,words,tags,heads,filtered_labels,iob), ctnts))
filtered.append((raw_text, filtered_sents))
return filtered
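
projectivize() repeatedly picks the smallest non-projective arc and lifts it one level until the tree is projective. Below is a standalone plain-Python trace of that lifting step on the first test tree; this is a sketch mirroring _lift(), not the class itself:

heads = [1,2,2,4,5,2,7,4,2]      # _get_smallest_nonproj_arc(heads) == 7

def lift(tokenid, heads):
    # reattach the word to its grandparent, as PseudoProjectivity._lift does
    head = heads[tokenid]
    ghead = heads[head]
    heads[tokenid] = ghead if head != ghead else tokenid

lift(7, heads)
# heads == [1,2,2,4,5,2,7,5,2]: token 7 now hangs off token 5, the tree is
# projective, and the while-loop in projectivize() terminates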

spacy/syntax/parser.pxd

@@ -15,5 +15,6 @@ cdef class ParserModel(AveragedPerceptron):
cdef class Parser:
cdef readonly ParserModel model
cdef readonly TransitionSystem moves
cdef int _projectivize
cdef int parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) nogil

spacy/syntax/parser.pyx

@@ -17,6 +17,7 @@ from os import path
import shutil
import json
import sys
from .nonproj import PseudoProjectivity
from cymem.cymem cimport Pool, Address
from murmurhash.mrmr cimport hash64
@@ -78,9 +79,10 @@ cdef class ParserModel(AveragedPerceptron):
cdef class Parser:
def __init__(self, StringStore strings, transition_system, ParserModel model):
def __init__(self, StringStore strings, transition_system, ParserModel model, int projectivize = 0):
self.moves = transition_system
self.model = model
self._projectivize = projectivize
@classmethod
def from_dir(cls, model_dir, strings, transition_system):
@@ -94,7 +96,7 @@ cdef class Parser:
model = ParserModel(templates)
if path.exists(path.join(model_dir, 'model')):
model.load(path.join(model_dir, 'model'))
return cls(strings, moves, model)
return cls(strings, moves, model, cfg.projectivize)
@classmethod
def load(cls, pkg_or_str_or_file, vocab):
@@ -113,6 +115,9 @@ cdef class Parser:
tokens.is_parsed = True
# Check for KeyboardInterrupt etc. Untested
PyErr_CheckSignals()
# projectivize output
if self._projectivize:
PseudoProjectivity.deprojectivize(tokens)
def pipe(self, stream, int batch_size=1000, int n_threads=2):
cdef Pool mem = Pool()
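
With this wiring, deprojectivization is transparent to callers: the parser predicts a projective tree with decorated labels and immediately rewrites it back. A hedged construction sketch, where strings, moves and model stand for already-loaded components (from_dir() instead reads the flag from the model's cfg):

parser = Parser(strings, moves, model, projectivize=1)
parser(doc)   # parses the Doc in place; decorated labels are then undone
              # via PseudoProjectivity.deprojectivize(doc)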

spacy/tests/parser/test_nonproj.py

@@ -1,7 +1,13 @@
from __future__ import unicode_literals
import pytest
from spacy.nonproj import ancestors, contains_cycle, is_nonproj_arc, is_nonproj_tree, PseudoProjective
from spacy.tokens.doc import Doc
from spacy.vocab import Vocab
from spacy.tokenizer import Tokenizer
from spacy.attrs import DEP, HEAD
import numpy
from spacy.syntax.nonproj import ancestors, contains_cycle, is_nonproj_arc, is_nonproj_tree, PseudoProjectivity
def test_ancestors():
tree = [1,2,2,4,5,2,2]
@@ -50,52 +56,53 @@ def test_is_nonproj_tree():
assert(is_nonproj_tree(partial_tree) == False)
assert(is_nonproj_tree(multirooted_tree) == True)
def test_pseudoprojective():
def test_pseudoprojectivity():
tree = [1,2,2]
nonproj_tree = [1,2,2,4,5,2,7,4,2]
labels = ['NK','SB','ROOT','NK','OA','OC','SB','RC','--']
nonproj_tree2 = [9,1,3,1,5,6,9,8,6,1,6,12,13,10,1]
labels2 = ['MO','ROOT','NK','SB','MO','NK','OA','NK','AG','OC','MNR','MO','NK','NK','--']
pp = PseudoProjective()
assert(PseudoProjectivity.decompose('X||Y') == ('X','Y'))
assert(PseudoProjectivity.decompose('X') == ('X',''))
assert(pp._make_span_index(tree) == { 1:[0], 2:[1] })
assert(pp._make_span_index(nonproj_tree) == { 1:[0], 2:[1,5,8], 4:[3,7], 5:[4], 7:[6] })
assert(PseudoProjectivity.is_decorated('X||Y') == True)
assert(PseudoProjectivity.is_decorated('X') == False)
pp._lift(0,tree)
PseudoProjectivity._lift(0,tree)
assert(tree == [2,2,2])
np_arc = pp._get_smallest_nonproj_arc(nonproj_tree)
np_arc = PseudoProjectivity._get_smallest_nonproj_arc(nonproj_tree)
assert(np_arc == 7)
np_arc = pp._get_smallest_nonproj_arc(nonproj_tree2)
np_arc = PseudoProjectivity._get_smallest_nonproj_arc(nonproj_tree2)
assert(np_arc == 10)
proj_heads, deco_labels = pp.projectivize(nonproj_tree,labels)
proj_heads, deco_labels = PseudoProjectivity.projectivize(nonproj_tree,labels)
assert(proj_heads == [1,2,2,4,5,2,7,5,2])
assert(deco_labels == ['NK','SB','ROOT','NK','OA','OC','SB','RC||OA','--'])
deproj_heads, undeco_labels = pp.deprojectivize(proj_heads,deco_labels)
assert(deproj_heads == nonproj_tree)
assert(undeco_labels == labels)
# deproj_heads, undeco_labels = PseudoProjectivity.deprojectivize(proj_heads,deco_labels)
# assert(deproj_heads == nonproj_tree)
# assert(undeco_labels == labels)
proj_heads, deco_labels = pp.projectivize(nonproj_tree2,labels2)
proj_heads, deco_labels = PseudoProjectivity.projectivize(nonproj_tree2,labels2)
assert(proj_heads == [1,1,3,1,5,6,9,8,6,1,9,12,13,10,1])
assert(deco_labels == ['MO||OC','ROOT','NK','SB','MO','NK','OA','NK','AG','OC','MNR||OA','MO','NK','NK','--'])
deproj_heads, undeco_labels = pp.deprojectivize(proj_heads,deco_labels)
assert(deproj_heads == nonproj_tree2)
assert(undeco_labels == labels2)
# deproj_heads, undeco_labels = PseudoProjectivity.deprojectivize(proj_heads,deco_labels)
# assert(deproj_heads == nonproj_tree2)
# assert(undeco_labels == labels2)
# if decoration is wrong such that there is no head with the desired label
# the structure is kept and the label is undecorated
deproj_heads, undeco_labels = pp.deprojectivize([1,2,2,4,5,2,7,5,2],['NK','SB','ROOT','NK','OA','OC','SB','RC||DA','--'])
assert(deproj_heads == [1,2,2,4,5,2,7,5,2])
assert(undeco_labels == ['NK','SB','ROOT','NK','OA','OC','SB','RC','--'])
# deproj_heads, undeco_labels = PseudoProjectivity.deprojectivize([1,2,2,4,5,2,7,5,2],['NK','SB','ROOT','NK','OA','OC','SB','RC||DA','--'])
# assert(deproj_heads == [1,2,2,4,5,2,7,5,2])
# assert(undeco_labels == ['NK','SB','ROOT','NK','OA','OC','SB','RC','--'])
# if there are two potential new heads, the first one is chosen even if it's wrong
deproj_heads, undeco_labels = pp.deprojectivize([1,1,3,1,5,6,9,8,6,1,9,12,13,10,1], \
['MO||OC','ROOT','NK','OC','MO','NK','OA','NK','AG','OC','MNR||OA','MO','NK','NK','--'])
assert(deproj_heads == [3,1,3,1,5,6,9,8,6,1,6,12,13,10,1])
assert(undeco_labels == ['MO','ROOT','NK','OC','MO','NK','OA','NK','AG','OC','MNR','MO','NK','NK','--'])
# deproj_heads, undeco_labels = PseudoProjectivity.deprojectivize([1,1,3,1,5,6,9,8,6,1,9,12,13,10,1], \
# ['MO||OC','ROOT','NK','OC','MO','NK','OA','NK','AG','OC','MNR||OA','MO','NK','NK','--'])
# assert(deproj_heads == [3,1,3,1,5,6,9,8,6,1,6,12,13,10,1])
# assert(undeco_labels == ['MO','ROOT','NK','OC','MO','NK','OA','NK','AG','OC','MNR','MO','NK','NK','--'])
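
The round-trip assertions are commented out because deprojectivize() now works in place on a parsed Doc instead of on (heads, labels) arrays; after this commit the call shape is roughly:

# doc stands for a Doc that was parsed with decorated (projectivized) labels
PseudoProjectivity.deprojectivize(doc)
# doc's HEAD/DEP annotation now encodes the restored non-projective tree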

spacy/tokens/token.pyx

@@ -201,17 +201,9 @@ cdef class Token:
cdef int nr_iter = 0
cdef const TokenC* ptr = self.c - (self.i - self.c.l_edge)
while ptr < self.c:
# If this head is still to the right of us, we can skip to it
# No token that's between this token and this head could be our
# child.
if (ptr.head >= 1) and (ptr + ptr.head) < self.c:
ptr += ptr.head
elif ptr + ptr.head == self.c:
if ptr + ptr.head == self.c:
yield self.doc[ptr - (self.c - self.i)]
ptr += 1
else:
ptr += 1
ptr += 1
nr_iter += 1
# This is ugly, but it's a way to guard against infinite loops
if nr_iter >= 10000000:
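
The deleted branch (ptr += ptr.head) was a shortcut that is only sound for projective trees, where no token between a candidate and its distant head can be a child of the current token; with pseudo-projective output that invariant no longer holds, so the loop now inspects every position (the .rights iterator in the next hunk gets the same treatment). A plain-Python sketch of the simplified .lefts scan, assuming spaCy's Token.left_edge and Token.head API:

def lefts(doc, i):
    # walk every token from the left edge up to position i; with possibly
    # non-projective parses we can no longer skip ahead by ptr.head
    for j in range(doc[i].left_edge.i, i):
        if doc[j].head.i == i:
            yield doc[j]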
@@ -226,16 +218,10 @@ cdef class Token:
tokens = []
cdef int nr_iter = 0
while ptr > self.c:
# If this head is still to the right of us, we can skip to it
# No token that's between this token and this head could be our
# child.
if (ptr.head < 0) and ((ptr + ptr.head) > self.c):
ptr += ptr.head
elif ptr + ptr.head == self.c:
if ptr + ptr.head == self.c:
tokens.append(self.doc[ptr - (self.c - self.i)])
ptr -= 1
else:
ptr -= 1
ptr -= 1
nr_iter += 1
if nr_iter >= 10000000:
raise RuntimeError(
"Possibly infinite loop encountered while looking for token.rights")