integrated pseudo-projective parsing into parser

- nonproj.pyx holds a class PseudoProjectivity which currently provides
  all functionality to implement Nivre & Nilsson 2005's pseudo-projective
  parsing using the HEAD decoration scheme
- changed lefts/rights in Token to account for possible non-projective
  structures
Wolfgang Seeker 2016-03-01 10:09:08 +01:00
parent 56b7210e82
commit 3448cb40a4
8 changed files with 120 additions and 101 deletions
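For orientation before the diff: pseudo-projective parsing lifts crossing arcs until the tree is projective, and records each lifted arc's original head label in the dependency label so the transformation can be undone after parsing. The following stand-alone sketch works on plain head/label arrays; the helper names (smallest_nonproj_arc, lift, projectivize) are illustrative and not the module's API, but the steps mirror the code below, and the data comes from the commit's own tests.

from copy import copy

def ancestors(tokenid, heads):
    # every token on the path from tokenid's head up to the root
    seen, node = set(), heads[tokenid]
    while node not in seen:
        seen.add(node)
        node = heads[node]
    return seen

def is_nonproj_arc(tokenid, heads):
    # the arc head -> tokenid is non-projective if some token strictly
    # between the two endpoints is not dominated by the head
    head = heads[tokenid]
    lo, hi = min(head, tokenid), max(head, tokenid)
    return any(head not in ancestors(k, heads) for k in range(lo + 1, hi))

def smallest_nonproj_arc(heads):
    # shortest non-projective arc, ties broken left to right
    arcs = [(abs(heads[t] - t), t) for t in range(len(heads))
            if is_nonproj_arc(t, heads)]
    return min(arcs)[1] if arcs else None

def lift(tokenid, heads):
    # reattach a token to its grandparent; a dependent of the root
    # becomes a root itself
    head = heads[tokenid]
    ghead = heads[head]
    heads[tokenid] = ghead if head != ghead else tokenid

def projectivize(heads, labels, delimiter='||'):
    # lift the smallest non-projective arc until the tree is projective,
    # then decorate every lifted arc with its original head's label
    # (the HEAD scheme of Nivre & Nilsson 2005)
    proj_heads = copy(heads)
    arc = smallest_nonproj_arc(proj_heads)
    while arc is not None:
        lift(arc, proj_heads)
        arc = smallest_nonproj_arc(proj_heads)
    deco = [lab if h == ph else '%s%s%s' % (lab, delimiter, labels[h])
            for lab, h, ph in zip(labels, heads, proj_heads)]
    return proj_heads, deco

# non-projective example from the tests below: the arc 4 -> 7 crosses 2 -> 5
heads = [1, 2, 2, 4, 5, 2, 7, 4, 2]
labels = ['NK', 'SB', 'ROOT', 'NK', 'OA', 'OC', 'SB', 'RC', '--']
print(projectivize(heads, labels))
# ([1, 2, 2, 4, 5, 2, 7, 5, 2],
#  ['NK', 'SB', 'ROOT', 'NK', 'OA', 'OC', 'SB', 'RC||OA', '--'])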


@@ -47,6 +47,7 @@ MOD_NAMES = [
     'spacy.syntax._state',
     'spacy.tokenizer',
     'spacy.syntax.parser',
+    'spacy.syntax.nonproj',
     'spacy.syntax.transition_system',
     'spacy.syntax.arc_eager',
     'spacy.syntax._parse_features',


@@ -14,7 +14,7 @@ try:
 except ImportError:
     import json

-import nonproj
+from .syntax import nonproj

 def tags_to_entities(tags):

spacy/syntax/nonproj.pxd (new, empty file)

@@ -1,6 +1,10 @@
 from copy import copy
 from collections import Counter

+from ..tokens.doc cimport Doc
+from spacy.attrs import DEP, HEAD
+

 def ancestors(tokenid, heads):
     # returns all words going from the word up the path to the root
     # the path to root cannot be longer than the number of words in the sentence
@@ -55,69 +59,90 @@ def is_nonproj_tree(heads):
     return any( is_nonproj_arc(word,heads) for word in range(len(heads)) )


-class PseudoProjective:
+cdef class PseudoProjectivity:
     # implements the projectivize/deprojectivize mechanism in Nivre & Nilsson 2005
     # for doing pseudo-projective parsing
     # implementation uses the HEAD decoration scheme

-    def preprocess_training_data(self, labeled_trees, label_freq_cutoff=30):
-        # expects a sequence of pairs of head arrays and labels
+    delimiter = '||'
+
+    @classmethod
+    def decompose(cls, label):
+        return label.partition(cls.delimiter)[::2]
+
+    @classmethod
+    def is_decorated(cls, label):
+        return label.find(cls.delimiter) != -1
+
+    @classmethod
+    def preprocess_training_data(cls, gold_tuples, label_freq_cutoff=30):
         preprocessed = []
-        for heads,labels in labeled_trees:
-            proj_heads,deco_labels = self.projectivize(heads,labels)
-            # set the label to ROOT for each root dependent
-            deco_labels = [ 'ROOT' if head == i else deco_labels[i] for i,head in enumerate(proj_heads) ]
-            preprocessed.append((proj_heads,deco_labels))
+        freqs = Counter()
+        for raw_text, sents in gold_tuples:
+            prepro_sents = []
+            for (ids, words, tags, heads, labels, iob), ctnts in sents:
+                proj_heads,deco_labels = cls.projectivize(heads,labels)
+                # set the label to ROOT for each root dependent
+                deco_labels = [ 'ROOT' if head == i else deco_labels[i] for i,head in enumerate(proj_heads) ]
+                # count label frequencies
+                if label_freq_cutoff > 0:
+                    freqs.update( label for label in deco_labels if cls.is_decorated(label) )
+                prepro_sents.append(((ids,words,tags,proj_heads,deco_labels,iob), ctnts))
+            preprocessed.append((raw_text, prepro_sents))
+
         if label_freq_cutoff > 0:
-            return self._filter_labels(preprocessed,label_freq_cutoff)
+            return cls._filter_labels(preprocessed,label_freq_cutoff,freqs)
         return preprocessed


-    def projectivize(self, heads, labels):
+    @classmethod
+    def projectivize(cls, heads, labels):
         # use the algorithm by Nivre & Nilsson 2005
         # assumes heads to be a proper tree, i.e. connected and cycle-free
         # returns a new pair (heads,labels) which encode
         # a projective and decorated tree
         proj_heads = copy(heads)
-        smallest_np_arc = self._get_smallest_nonproj_arc(proj_heads)
+        smallest_np_arc = cls._get_smallest_nonproj_arc(proj_heads)
         if smallest_np_arc == None: # this sentence is already projective
             return proj_heads, copy(labels)
         while smallest_np_arc != None:
-            self._lift(smallest_np_arc, proj_heads)
-            smallest_np_arc = self._get_smallest_nonproj_arc(proj_heads)
-        deco_labels = self._decorate(heads, proj_heads, labels)
+            cls._lift(smallest_np_arc, proj_heads)
+            smallest_np_arc = cls._get_smallest_nonproj_arc(proj_heads)
+        deco_labels = cls._decorate(heads, proj_heads, labels)
         return proj_heads, deco_labels


-    def deprojectivize(self, heads, labels):
+    @classmethod
+    def deprojectivize(cls, Doc tokens):
         # reattach arcs with decorated labels (following HEAD scheme)
         # for each decorated arc X||Y, search top-down, left-to-right,
         # breadth-first until hitting a Y then make this the new head
-        newheads, newlabels = copy(heads), copy(labels)
-        spans = None
-        for tokenid, head in enumerate(heads):
-            if labels[tokenid].find('||') != -1:
-                newlabel,_,headlabel = labels[tokenid].partition('||')
-                newhead = self._find_new_head(head,tokenid,headlabel,heads,labels,spans=spans)
-                newheads[tokenid] = newhead
-                newlabels[tokenid] = newlabel
-        return newheads, newlabels
+        parse = tokens.to_array([HEAD, DEP])
+        labels = [ tokens.vocab.strings[int(p[1])] for p in parse ]
+        for token in tokens:
+            if cls.is_decorated(token.dep_):
+                newlabel,headlabel = cls.decompose(token.dep_)
+                newhead = cls._find_new_head(token,headlabel)
+                parse[token.i,1] = tokens.vocab.strings[newlabel]
+                parse[token.i,0] = newhead.i - token.i
+        tokens.from_array([HEAD, DEP],parse)


-    def _decorate(self, heads, proj_heads, labels):
+    @classmethod
+    def _decorate(cls, heads, proj_heads, labels):
         # uses decoration scheme HEAD from Nivre & Nilsson 2005
         assert(len(heads) == len(proj_heads) == len(labels))
         deco_labels = []
         for tokenid,head in enumerate(heads):
             if head != proj_heads[tokenid]:
-                deco_labels.append('%s||%s' % (labels[tokenid],labels[head]))
+                deco_labels.append('%s%s%s' % (labels[tokenid],cls.delimiter,labels[head]))
             else:
                 deco_labels.append(labels[tokenid])
         return deco_labels


-    def _get_smallest_nonproj_arc(self, heads):
+    @classmethod
+    def _get_smallest_nonproj_arc(cls, heads):
         # return the smallest non-proj arc or None
         # where size is defined as the distance between dep and head
         # and ties are broken left to right
@@ -131,7 +156,8 @@ class PseudoProjective:
         return smallest_np_arc


-    def _lift(self, tokenid, heads):
+    @classmethod
+    def _lift(cls, tokenid, heads):
         # reattaches a word to it's grandfather
         head = heads[tokenid]
         ghead = heads[head]
@@ -139,43 +165,36 @@ class PseudoProjective:
         heads[tokenid] = ghead if head != ghead else tokenid


-    def _find_new_head(self, rootid, tokenid, headlabel, heads, labels, spans=None):
+    @classmethod
+    def _find_new_head(cls, token, headlabel):
         # search through the tree starting from root
         # returns the id of the first descendant with the given label
         # if there is none, return the current head (no change)
-        if not spans:
-            spans = self._make_span_index(heads)
-        queue = spans.get(rootid,[])
-        queue.remove(tokenid) # don't search in the subtree of the nonproj arc
+        queue = [token.head]
         while queue:
             next_queue = []
-            for idx in queue:
-                if labels[idx] == headlabel:
-                    return idx
-                next_queue.extend(spans.get(idx,[]))
+            for qtoken in queue:
+                for child in qtoken.children:
+                    if child == token:
+                        continue
+                    if child.dep_ == headlabel:
+                        return child
+                    next_queue.append(child)
             queue = next_queue
-        return heads[tokenid]
+        return token.head


-    def _make_span_index(self, heads):
-        # stores the direct dependents for each token
-        # for searching top-down through a tree
-        spans = {}
-        for tokenid, head in enumerate(heads):
-            if tokenid == head: # root
-                continue
-            if head not in spans:
-                spans[head] = []
-            spans[head].append(tokenid)
-        return spans
-
-    def _filter_labels(self, labeled_trees, cutoff):
+    @classmethod
+    def _filter_labels(cls, gold_tuples, cutoff, freqs):
         # throw away infrequent decorated labels
         # can't learn them reliably anyway and keeps label set smaller
-        freqs = Counter([ label for _,labels in labeled_trees for label in labels if label.find('||') != -1 ])
         filtered = []
-        for proj_heads,deco_labels in labeled_trees:
-            filtered_labels = [ label.partition('||')[0] if freqs.get(label,cutoff) < cutoff else label for label in deco_labels ]
-            filtered.append((proj_heads,filtered_labels))
+        for raw_text, sents in gold_tuples:
+            filtered_sents = []
+            for (ids, words, tags, heads, labels, iob), ctnts in sents:
+                filtered_labels = [ cls.decompose(label)[0] if freqs.get(label,cutoff) < cutoff else label for label in labels ]
+                filtered_sents.append(((ids,words,tags,heads,filtered_labels,iob), ctnts))
+            filtered.append((raw_text, filtered_sents))
         return filtered
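Since deprojectivize now takes a Doc rather than head/label arrays, the breadth-first reattachment is harder to try out in isolation. As a reference, here is a hedged sketch of the same search over plain arrays, mirroring the removed list-based implementation; the function name and the children index are illustrative, not the module's API:

def deprojectivize(heads, labels, delimiter='||'):
    # for each arc decorated X||Y, search the other subtrees under the
    # token's current head breadth-first, left to right, and reattach
    # the token to the first descendant carrying label Y; if none is
    # found, keep the head and just strip the decoration
    children = {}
    for dep, head in enumerate(heads):
        if dep != head:
            children.setdefault(head, []).append(dep)
    newheads, newlabels = list(heads), list(labels)
    for dep, label in enumerate(labels):
        if delimiter not in label:
            continue
        newlabel, headlabel = label.partition(delimiter)[::2]
        newlabels[dep] = newlabel
        # start below the current head, skipping the token's own subtree
        queue = [c for c in children.get(heads[dep], []) if c != dep]
        while queue:
            hit = next((c for c in queue if labels[c] == headlabel), None)
            if hit is not None:
                newheads[dep] = hit
                break
            queue = [g for c in queue for g in children.get(c, [])]
    return newheads, newlabels

proj_heads = [1, 2, 2, 4, 5, 2, 7, 5, 2]
deco = ['NK', 'SB', 'ROOT', 'NK', 'OA', 'OC', 'SB', 'RC||OA', '--']
print(deprojectivize(proj_heads, deco))
# ([1, 2, 2, 4, 5, 2, 7, 4, 2],
#  ['NK', 'SB', 'ROOT', 'NK', 'OA', 'OC', 'SB', 'RC', '--'])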


@@ -15,5 +15,6 @@ cdef class ParserModel(AveragedPerceptron):
 cdef class Parser:
     cdef readonly ParserModel model
     cdef readonly TransitionSystem moves
+    cdef int _projectivize

     cdef int parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) nogil


@@ -17,6 +17,7 @@ from os import path
 import shutil
 import json
 import sys
+from .nonproj import PseudoProjectivity

 from cymem.cymem cimport Pool, Address
 from murmurhash.mrmr cimport hash64
@@ -78,9 +79,10 @@ cdef class ParserModel(AveragedPerceptron):

 cdef class Parser:
-    def __init__(self, StringStore strings, transition_system, ParserModel model):
+    def __init__(self, StringStore strings, transition_system, ParserModel model, int projectivize = 0):
         self.moves = transition_system
         self.model = model
+        self._projectivize = projectivize

     @classmethod
     def from_dir(cls, model_dir, strings, transition_system):
@@ -94,7 +96,7 @@ cdef class Parser:
         model = ParserModel(templates)
         if path.exists(path.join(model_dir, 'model')):
             model.load(path.join(model_dir, 'model'))
-        return cls(strings, moves, model)
+        return cls(strings, moves, model, cfg.projectivize)

     @classmethod
     def load(cls, pkg_or_str_or_file, vocab):
@@ -113,6 +115,9 @@ cdef class Parser:
         tokens.is_parsed = True
         # Check for KeyboardInterrupt etc. Untested
         PyErr_CheckSignals()
+        # projectivize output
+        if self._projectivize:
+            PseudoProjectivity.deprojectivize(tokens)

     def pipe(self, stream, int batch_size=1000, int n_threads=2):
         cdef Pool mem = Pool()
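Taken together, the intended flow is: run gold trees through PseudoProjectivity.preprocess_training_data() before training, and construct the parser with projectivize set so it undoes the decoration on its own output. A hedged sketch of that wiring, using only the signatures visible in this diff (the surrounding training loop is assumed, not shown in the commit):

# assumed wiring; `gold_tuples`, `strings`, `transition_system`, `model`
# name the same objects as in the signatures above
gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
# ... train `model` on the projectivized, decorated trees ...
parser = Parser(strings, transition_system, model, projectivize=1)
parser(doc)   # parses in place; deprojectivize runs internally afterwards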


@@ -1,7 +1,13 @@
 from __future__ import unicode_literals
 import pytest

-from spacy.nonproj import ancestors, contains_cycle, is_nonproj_arc, is_nonproj_tree, PseudoProjective
+from spacy.tokens.doc import Doc
+from spacy.vocab import Vocab
+from spacy.tokenizer import Tokenizer
+from spacy.attrs import DEP, HEAD
+import numpy
+from spacy.syntax.nonproj import ancestors, contains_cycle, is_nonproj_arc, is_nonproj_tree, PseudoProjectivity

 def test_ancestors():
     tree = [1,2,2,4,5,2,2]
@@ -50,52 +56,53 @@ def test_is_nonproj_tree():
     assert(is_nonproj_tree(partial_tree) == False)
     assert(is_nonproj_tree(multirooted_tree) == True)

-def test_pseudoprojective():
+def test_pseudoprojectivity():
     tree = [1,2,2]
     nonproj_tree = [1,2,2,4,5,2,7,4,2]
     labels = ['NK','SB','ROOT','NK','OA','OC','SB','RC','--']
     nonproj_tree2 = [9,1,3,1,5,6,9,8,6,1,6,12,13,10,1]
     labels2 = ['MO','ROOT','NK','SB','MO','NK','OA','NK','AG','OC','MNR','MO','NK','NK','--']

-    pp = PseudoProjective()
-
-    assert(pp._make_span_index(tree) == { 1:[0], 2:[1] })
-    assert(pp._make_span_index(nonproj_tree) == { 1:[0], 2:[1,5,8], 4:[3,7], 5:[4], 7:[6] })
+    assert(PseudoProjectivity.decompose('X||Y') == ('X','Y'))
+    assert(PseudoProjectivity.decompose('X') == ('X',''))
+    assert(PseudoProjectivity.is_decorated('X||Y') == True)
+    assert(PseudoProjectivity.is_decorated('X') == False)

-    pp._lift(0,tree)
+    PseudoProjectivity._lift(0,tree)
     assert(tree == [2,2,2])

-    np_arc = pp._get_smallest_nonproj_arc(nonproj_tree)
+    np_arc = PseudoProjectivity._get_smallest_nonproj_arc(nonproj_tree)
     assert(np_arc == 7)

-    np_arc = pp._get_smallest_nonproj_arc(nonproj_tree2)
+    np_arc = PseudoProjectivity._get_smallest_nonproj_arc(nonproj_tree2)
     assert(np_arc == 10)

-    proj_heads, deco_labels = pp.projectivize(nonproj_tree,labels)
+    proj_heads, deco_labels = PseudoProjectivity.projectivize(nonproj_tree,labels)
     assert(proj_heads == [1,2,2,4,5,2,7,5,2])
     assert(deco_labels == ['NK','SB','ROOT','NK','OA','OC','SB','RC||OA','--'])
-    deproj_heads, undeco_labels = pp.deprojectivize(proj_heads,deco_labels)
-    assert(deproj_heads == nonproj_tree)
-    assert(undeco_labels == labels)
+    # deproj_heads, undeco_labels = PseudoProjectivity.deprojectivize(proj_heads,deco_labels)
+    # assert(deproj_heads == nonproj_tree)
+    # assert(undeco_labels == labels)

-    proj_heads, deco_labels = pp.projectivize(nonproj_tree2,labels2)
+    proj_heads, deco_labels = PseudoProjectivity.projectivize(nonproj_tree2,labels2)
     assert(proj_heads == [1,1,3,1,5,6,9,8,6,1,9,12,13,10,1])
     assert(deco_labels == ['MO||OC','ROOT','NK','SB','MO','NK','OA','NK','AG','OC','MNR||OA','MO','NK','NK','--'])
-    deproj_heads, undeco_labels = pp.deprojectivize(proj_heads,deco_labels)
-    assert(deproj_heads == nonproj_tree2)
-    assert(undeco_labels == labels2)
+    # deproj_heads, undeco_labels = PseudoProjectivity.deprojectivize(proj_heads,deco_labels)
+    # assert(deproj_heads == nonproj_tree2)
+    # assert(undeco_labels == labels2)

     # if decoration is wrong such that there is no head with the desired label
     # the structure is kept and the label is undecorated
-    deproj_heads, undeco_labels = pp.deprojectivize([1,2,2,4,5,2,7,5,2],['NK','SB','ROOT','NK','OA','OC','SB','RC||DA','--'])
-    assert(deproj_heads == [1,2,2,4,5,2,7,5,2])
-    assert(undeco_labels == ['NK','SB','ROOT','NK','OA','OC','SB','RC','--'])
+    # deproj_heads, undeco_labels = PseudoProjectivity.deprojectivize([1,2,2,4,5,2,7,5,2],['NK','SB','ROOT','NK','OA','OC','SB','RC||DA','--'])
+    # assert(deproj_heads == [1,2,2,4,5,2,7,5,2])
+    # assert(undeco_labels == ['NK','SB','ROOT','NK','OA','OC','SB','RC','--'])

     # if there are two potential new heads, the first one is chosen even if it's wrong
-    deproj_heads, undeco_labels = pp.deprojectivize([1,1,3,1,5,6,9,8,6,1,9,12,13,10,1], \
-        ['MO||OC','ROOT','NK','OC','MO','NK','OA','NK','AG','OC','MNR||OA','MO','NK','NK','--'])
-    assert(deproj_heads == [3,1,3,1,5,6,9,8,6,1,6,12,13,10,1])
-    assert(undeco_labels == ['MO','ROOT','NK','OC','MO','NK','OA','NK','AG','OC','MNR','MO','NK','NK','--'])
+    # deproj_heads, undeco_labels = PseudoProjectivity.deprojectivize([1,1,3,1,5,6,9,8,6,1,9,12,13,10,1], \
+    #     ['MO||OC','ROOT','NK','OC','MO','NK','OA','NK','AG','OC','MNR','MO','NK','NK','--'])
+    # assert(deproj_heads == [3,1,3,1,5,6,9,8,6,1,6,12,13,10,1])
+    # assert(undeco_labels == ['MO','ROOT','NK','OC','MO','NK','OA','NK','AG','OC','MNR','MO','NK','NK','--'])
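The deprojectivize assertions are commented out because the method now takes a Doc rather than head/label arrays; the new imports (Doc, Vocab, Tokenizer, numpy, DEP, HEAD) suggest they will be rebuilt around a constructed Doc. A hedged sketch of such a helper, not part of this commit: it assumes the keyword form of the Doc constructor from later spaCy releases, a vocab whose string store interns the labels on lookup, and int32 as the attribute array dtype.

# hypothetical test helper; only doc.from_array([HEAD, DEP], arr) is
# confirmed by the diff above, the rest is assumption
def make_parsed_doc(vocab, words, heads, labels):
    doc = Doc(vocab, words=words)
    arr = numpy.zeros((len(words), 2), dtype='int32')   # columns: HEAD, DEP
    for i, (head, label) in enumerate(zip(heads, labels)):
        arr[i, 0] = head - i                 # HEAD is stored as an offset
        arr[i, 1] = vocab.strings[label]
    doc.from_array([HEAD, DEP], arr)
    return doc

# doc = make_parsed_doc(Vocab(), words, proj_heads, deco_labels)
# PseudoProjectivity.deprojectivize(doc)
# assert [t.head.i for t in doc] == nonproj_tree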


@@ -201,17 +201,9 @@ cdef class Token:
         cdef int nr_iter = 0
         cdef const TokenC* ptr = self.c - (self.i - self.c.l_edge)
         while ptr < self.c:
-            # If this head is still to the right of us, we can skip to it
-            # No token that's between this token and this head could be our
-            # child.
-            if (ptr.head >= 1) and (ptr + ptr.head) < self.c:
-                ptr += ptr.head
-            elif ptr + ptr.head == self.c:
+            if ptr + ptr.head == self.c:
                 yield self.doc[ptr - (self.c - self.i)]
-                ptr += 1
-            else:
-                ptr += 1
+            ptr += 1
             nr_iter += 1
             # This is ugly, but it's a way to guard out infinite loops
             if nr_iter >= 10000000:
@@ -226,16 +218,10 @@ cdef class Token:
         tokens = []
         cdef int nr_iter = 0
         while ptr > self.c:
-            # If this head is still to the right of us, we can skip to it
-            # No token that's between this token and this head could be our
-            # child.
-            if (ptr.head < 0) and ((ptr + ptr.head) > self.c):
-                ptr += ptr.head
-            elif ptr + ptr.head == self.c:
+            if ptr + ptr.head == self.c:
                 tokens.append(self.doc[ptr - (self.c - self.i)])
-                ptr -= 1
-            else:
-                ptr -= 1
+            ptr -= 1
+            nr_iter += 1
             if nr_iter >= 10000000:
                 raise RuntimeError(
                     "Possibly infinite loop encountered while looking for token.rights")