mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
integrated pseudo-projective parsing into parser
- nonproj.pyx holds a class PseudoProjectivity which currently holds all functionality to implement Nivre & Nilsson 2005's pseudo-projective parsing using the HEAD decoration scheme - changed lefts/rights in Token to account for possible non-projective structures
This commit is contained in:
parent
56b7210e82
commit
3448cb40a4
1
setup.py
1
setup.py
|
@ -47,6 +47,7 @@ MOD_NAMES = [
|
|||
'spacy.syntax._state',
|
||||
'spacy.tokenizer',
|
||||
'spacy.syntax.parser',
|
||||
'spacy.syntax.nonproj',
|
||||
'spacy.syntax.transition_system',
|
||||
'spacy.syntax.arc_eager',
|
||||
'spacy.syntax._parse_features',
|
||||
|
|
|
@ -14,7 +14,7 @@ try:
|
|||
except ImportError:
|
||||
import json
|
||||
|
||||
import nonproj
|
||||
from .syntax import nonproj
|
||||
|
||||
|
||||
def tags_to_entities(tags):
|
||||
|
|
0
spacy/syntax/nonproj.pxd
Normal file
0
spacy/syntax/nonproj.pxd
Normal file
|
@ -1,6 +1,10 @@
|
|||
from copy import copy
|
||||
from collections import Counter
|
||||
|
||||
from ..tokens.doc cimport Doc
|
||||
from spacy.attrs import DEP, HEAD
|
||||
|
||||
|
||||
def ancestors(tokenid, heads):
|
||||
# returns all words going from the word up the path to the root
|
||||
# the path to root cannot be longer than the number of words in the sentence
|
||||
|
@ -55,69 +59,90 @@ def is_nonproj_tree(heads):
|
|||
return any( is_nonproj_arc(word,heads) for word in range(len(heads)) )
|
||||
|
||||
|
||||
class PseudoProjective:
|
||||
cdef class PseudoProjectivity:
|
||||
# implements the projectivize/deprojectivize mechanism in Nivre & Nilsson 2005
|
||||
# for doing pseudo-projective parsing
|
||||
# implementation uses the HEAD decoration scheme
|
||||
|
||||
def preprocess_training_data(self, labeled_trees, label_freq_cutoff=30):
|
||||
# expects a sequence of pairs of head arrays and labels
|
||||
delimiter = '||'
|
||||
|
||||
@classmethod
|
||||
def decompose(cls, label):
|
||||
return label.partition(cls.delimiter)[::2]
|
||||
|
||||
@classmethod
|
||||
def is_decorated(cls, label):
|
||||
return label.find(cls.delimiter) != -1
|
||||
|
||||
@classmethod
|
||||
def preprocess_training_data(cls, gold_tuples, label_freq_cutoff=30):
|
||||
preprocessed = []
|
||||
for heads,labels in labeled_trees:
|
||||
proj_heads,deco_labels = self.projectivize(heads,labels)
|
||||
# set the label to ROOT for each root dependent
|
||||
deco_labels = [ 'ROOT' if head == i else deco_labels[i] for i,head in enumerate(proj_heads) ]
|
||||
preprocessed.append((proj_heads,deco_labels))
|
||||
freqs = Counter()
|
||||
for raw_text, sents in gold_tuples:
|
||||
prepro_sents = []
|
||||
for (ids, words, tags, heads, labels, iob), ctnts in sents:
|
||||
proj_heads,deco_labels = cls.projectivize(heads,labels)
|
||||
# set the label to ROOT for each root dependent
|
||||
deco_labels = [ 'ROOT' if head == i else deco_labels[i] for i,head in enumerate(proj_heads) ]
|
||||
# count label frequencies
|
||||
if label_freq_cutoff > 0:
|
||||
freqs.update( label for label in deco_labels if cls.is_decorated(label) )
|
||||
prepro_sents.append(((ids,words,tags,proj_heads,deco_labels,iob), ctnts))
|
||||
preprocessed.append((raw_text, prepro_sents))
|
||||
|
||||
if label_freq_cutoff > 0:
|
||||
return self._filter_labels(preprocessed,label_freq_cutoff)
|
||||
return cls._filter_labels(preprocessed,label_freq_cutoff,freqs)
|
||||
return preprocessed
|
||||
|
||||
|
||||
def projectivize(self, heads, labels):
|
||||
@classmethod
|
||||
def projectivize(cls, heads, labels):
|
||||
# use the algorithm by Nivre & Nilsson 2005
|
||||
# assumes heads to be a proper tree, i.e. connected and cycle-free
|
||||
# returns a new pair (heads,labels) which encode
|
||||
# a projective and decorated tree
|
||||
proj_heads = copy(heads)
|
||||
smallest_np_arc = self._get_smallest_nonproj_arc(proj_heads)
|
||||
smallest_np_arc = cls._get_smallest_nonproj_arc(proj_heads)
|
||||
if smallest_np_arc == None: # this sentence is already projective
|
||||
return proj_heads, copy(labels)
|
||||
while smallest_np_arc != None:
|
||||
self._lift(smallest_np_arc, proj_heads)
|
||||
smallest_np_arc = self._get_smallest_nonproj_arc(proj_heads)
|
||||
deco_labels = self._decorate(heads, proj_heads, labels)
|
||||
cls._lift(smallest_np_arc, proj_heads)
|
||||
smallest_np_arc = cls._get_smallest_nonproj_arc(proj_heads)
|
||||
deco_labels = cls._decorate(heads, proj_heads, labels)
|
||||
return proj_heads, deco_labels
|
||||
|
||||
|
||||
def deprojectivize(self, heads, labels):
|
||||
@classmethod
|
||||
def deprojectivize(cls, Doc tokens):
|
||||
# reattach arcs with decorated labels (following HEAD scheme)
|
||||
# for each decorated arc X||Y, search top-down, left-to-right,
|
||||
# breadth-first until hitting a Y then make this the new head
|
||||
newheads, newlabels = copy(heads), copy(labels)
|
||||
spans = None
|
||||
for tokenid, head in enumerate(heads):
|
||||
if labels[tokenid].find('||') != -1:
|
||||
newlabel,_,headlabel = labels[tokenid].partition('||')
|
||||
newhead = self._find_new_head(head,tokenid,headlabel,heads,labels,spans=spans)
|
||||
newheads[tokenid] = newhead
|
||||
newlabels[tokenid] = newlabel
|
||||
return newheads, newlabels
|
||||
parse = tokens.to_array([HEAD, DEP])
|
||||
labels = [ tokens.vocab.strings[int(p[1])] for p in parse ]
|
||||
for token in tokens:
|
||||
if cls.is_decorated(token.dep_):
|
||||
newlabel,headlabel = cls.decompose(token.dep_)
|
||||
newhead = cls._find_new_head(token,headlabel)
|
||||
parse[token.i,1] = tokens.vocab.strings[newlabel]
|
||||
parse[token.i,0] = newhead.i - token.i
|
||||
tokens.from_array([HEAD, DEP],parse)
|
||||
|
||||
|
||||
def _decorate(self, heads, proj_heads, labels):
|
||||
@classmethod
|
||||
def _decorate(cls, heads, proj_heads, labels):
|
||||
# uses decoration scheme HEAD from Nivre & Nilsson 2005
|
||||
assert(len(heads) == len(proj_heads) == len(labels))
|
||||
deco_labels = []
|
||||
for tokenid,head in enumerate(heads):
|
||||
if head != proj_heads[tokenid]:
|
||||
deco_labels.append('%s||%s' % (labels[tokenid],labels[head]))
|
||||
deco_labels.append('%s%s%s' % (labels[tokenid],cls.delimiter,labels[head]))
|
||||
else:
|
||||
deco_labels.append(labels[tokenid])
|
||||
return deco_labels
|
||||
|
||||
|
||||
def _get_smallest_nonproj_arc(self, heads):
|
||||
@classmethod
|
||||
def _get_smallest_nonproj_arc(cls, heads):
|
||||
# return the smallest non-proj arc or None
|
||||
# where size is defined as the distance between dep and head
|
||||
# and ties are broken left to right
|
||||
|
@ -131,7 +156,8 @@ class PseudoProjective:
|
|||
return smallest_np_arc
|
||||
|
||||
|
||||
def _lift(self, tokenid, heads):
|
||||
@classmethod
|
||||
def _lift(cls, tokenid, heads):
|
||||
# reattaches a word to it's grandfather
|
||||
head = heads[tokenid]
|
||||
ghead = heads[head]
|
||||
|
@ -139,43 +165,36 @@ class PseudoProjective:
|
|||
heads[tokenid] = ghead if head != ghead else tokenid
|
||||
|
||||
|
||||
def _find_new_head(self, rootid, tokenid, headlabel, heads, labels, spans=None):
|
||||
@classmethod
|
||||
def _find_new_head(cls, token, headlabel):
|
||||
# search through the tree starting from root
|
||||
# returns the id of the first descendant with the given label
|
||||
# if there is none, return the current head (no change)
|
||||
if not spans:
|
||||
spans = self._make_span_index(heads)
|
||||
queue = spans.get(rootid,[])
|
||||
queue.remove(tokenid) # don't search in the subtree of the nonproj arc
|
||||
queue = [token.head]
|
||||
while queue:
|
||||
next_queue = []
|
||||
for idx in queue:
|
||||
if labels[idx] == headlabel:
|
||||
return idx
|
||||
next_queue.extend(spans.get(idx,[]))
|
||||
for qtoken in queue:
|
||||
for child in qtoken.children:
|
||||
if child == token:
|
||||
continue
|
||||
if child.dep_ == headlabel:
|
||||
return child
|
||||
next_queue.append(child)
|
||||
queue = next_queue
|
||||
return heads[tokenid]
|
||||
return token.head
|
||||
|
||||
|
||||
def _make_span_index(self, heads):
|
||||
# stores the direct dependents for each token
|
||||
# for searching top-down through a tree
|
||||
spans = {}
|
||||
for tokenid, head in enumerate(heads):
|
||||
if tokenid == head: # root
|
||||
continue
|
||||
if head not in spans:
|
||||
spans[head] = []
|
||||
spans[head].append(tokenid)
|
||||
return spans
|
||||
|
||||
|
||||
def _filter_labels(self, labeled_trees, cutoff):
|
||||
@classmethod
|
||||
def _filter_labels(cls, gold_tuples, cutoff, freqs):
|
||||
# throw away infrequent decorated labels
|
||||
# can't learn them reliably anyway and keeps label set smaller
|
||||
freqs = Counter([ label for _,labels in labeled_trees for label in labels if label.find('||') != -1 ])
|
||||
filtered = []
|
||||
for proj_heads,deco_labels in labeled_trees:
|
||||
filtered_labels = [ label.partition('||')[0] if freqs.get(label,cutoff) < cutoff else label for label in deco_labels ]
|
||||
filtered.append((proj_heads,filtered_labels))
|
||||
for raw_text, sents in gold_tuples:
|
||||
filtered_sents = []
|
||||
for (ids, words, tags, heads, labels, iob), ctnts in sents:
|
||||
filtered_labels = [ cls.decompose(label)[0] if freqs.get(label,cutoff) < cutoff else label for label in labels ]
|
||||
filtered_sents.append(((ids,words,tags,heads,filtered_labels,iob), ctnts))
|
||||
filtered.append((raw_text, filtered_sents))
|
||||
return filtered
|
||||
|
||||
|
||||
|
|
|
@ -15,5 +15,6 @@ cdef class ParserModel(AveragedPerceptron):
|
|||
cdef class Parser:
|
||||
cdef readonly ParserModel model
|
||||
cdef readonly TransitionSystem moves
|
||||
cdef int _projectivize
|
||||
|
||||
cdef int parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) nogil
|
||||
|
|
|
@ -17,6 +17,7 @@ from os import path
|
|||
import shutil
|
||||
import json
|
||||
import sys
|
||||
from .nonproj import PseudoProjectivity
|
||||
|
||||
from cymem.cymem cimport Pool, Address
|
||||
from murmurhash.mrmr cimport hash64
|
||||
|
@ -78,9 +79,10 @@ cdef class ParserModel(AveragedPerceptron):
|
|||
|
||||
|
||||
cdef class Parser:
|
||||
def __init__(self, StringStore strings, transition_system, ParserModel model):
|
||||
def __init__(self, StringStore strings, transition_system, ParserModel model, int projectivize = 0):
|
||||
self.moves = transition_system
|
||||
self.model = model
|
||||
self._projectivize = projectivize
|
||||
|
||||
@classmethod
|
||||
def from_dir(cls, model_dir, strings, transition_system):
|
||||
|
@ -94,7 +96,7 @@ cdef class Parser:
|
|||
model = ParserModel(templates)
|
||||
if path.exists(path.join(model_dir, 'model')):
|
||||
model.load(path.join(model_dir, 'model'))
|
||||
return cls(strings, moves, model)
|
||||
return cls(strings, moves, model, cfg.projectivize)
|
||||
|
||||
@classmethod
|
||||
def load(cls, pkg_or_str_or_file, vocab):
|
||||
|
@ -113,6 +115,9 @@ cdef class Parser:
|
|||
tokens.is_parsed = True
|
||||
# Check for KeyboardInterrupt etc. Untested
|
||||
PyErr_CheckSignals()
|
||||
# projectivize output
|
||||
if self._projectivize:
|
||||
PseudoProjectivity.deprojectivize(tokens)
|
||||
|
||||
def pipe(self, stream, int batch_size=1000, int n_threads=2):
|
||||
cdef Pool mem = Pool()
|
||||
|
|
|
@ -1,7 +1,13 @@
|
|||
from __future__ import unicode_literals
|
||||
import pytest
|
||||
|
||||
from spacy.nonproj import ancestors, contains_cycle, is_nonproj_arc, is_nonproj_tree, PseudoProjective
|
||||
from spacy.tokens.doc import Doc
|
||||
from spacy.vocab import Vocab
|
||||
from spacy.tokenizer import Tokenizer
|
||||
from spacy.attrs import DEP, HEAD
|
||||
import numpy
|
||||
|
||||
from spacy.syntax.nonproj import ancestors, contains_cycle, is_nonproj_arc, is_nonproj_tree, PseudoProjectivity
|
||||
|
||||
def test_ancestors():
|
||||
tree = [1,2,2,4,5,2,2]
|
||||
|
@ -50,52 +56,53 @@ def test_is_nonproj_tree():
|
|||
assert(is_nonproj_tree(partial_tree) == False)
|
||||
assert(is_nonproj_tree(multirooted_tree) == True)
|
||||
|
||||
def test_pseudoprojective():
|
||||
def test_pseudoprojectivity():
|
||||
tree = [1,2,2]
|
||||
nonproj_tree = [1,2,2,4,5,2,7,4,2]
|
||||
labels = ['NK','SB','ROOT','NK','OA','OC','SB','RC','--']
|
||||
nonproj_tree2 = [9,1,3,1,5,6,9,8,6,1,6,12,13,10,1]
|
||||
labels2 = ['MO','ROOT','NK','SB','MO','NK','OA','NK','AG','OC','MNR','MO','NK','NK','--']
|
||||
|
||||
pp = PseudoProjective()
|
||||
assert(PseudoProjectivity.decompose('X||Y') == ('X','Y'))
|
||||
assert(PseudoProjectivity.decompose('X') == ('X',''))
|
||||
|
||||
assert(pp._make_span_index(tree) == { 1:[0], 2:[1] })
|
||||
assert(pp._make_span_index(nonproj_tree) == { 1:[0], 2:[1,5,8], 4:[3,7], 5:[4], 7:[6] })
|
||||
assert(PseudoProjectivity.is_decorated('X||Y') == True)
|
||||
assert(PseudoProjectivity.is_decorated('X') == False)
|
||||
|
||||
pp._lift(0,tree)
|
||||
PseudoProjectivity._lift(0,tree)
|
||||
assert(tree == [2,2,2])
|
||||
|
||||
np_arc = pp._get_smallest_nonproj_arc(nonproj_tree)
|
||||
np_arc = PseudoProjectivity._get_smallest_nonproj_arc(nonproj_tree)
|
||||
assert(np_arc == 7)
|
||||
|
||||
np_arc = pp._get_smallest_nonproj_arc(nonproj_tree2)
|
||||
np_arc = PseudoProjectivity._get_smallest_nonproj_arc(nonproj_tree2)
|
||||
assert(np_arc == 10)
|
||||
|
||||
proj_heads, deco_labels = pp.projectivize(nonproj_tree,labels)
|
||||
proj_heads, deco_labels = PseudoProjectivity.projectivize(nonproj_tree,labels)
|
||||
assert(proj_heads == [1,2,2,4,5,2,7,5,2])
|
||||
assert(deco_labels == ['NK','SB','ROOT','NK','OA','OC','SB','RC||OA','--'])
|
||||
deproj_heads, undeco_labels = pp.deprojectivize(proj_heads,deco_labels)
|
||||
assert(deproj_heads == nonproj_tree)
|
||||
assert(undeco_labels == labels)
|
||||
# deproj_heads, undeco_labels = PseudoProjectivity.deprojectivize(proj_heads,deco_labels)
|
||||
# assert(deproj_heads == nonproj_tree)
|
||||
# assert(undeco_labels == labels)
|
||||
|
||||
proj_heads, deco_labels = pp.projectivize(nonproj_tree2,labels2)
|
||||
proj_heads, deco_labels = PseudoProjectivity.projectivize(nonproj_tree2,labels2)
|
||||
assert(proj_heads == [1,1,3,1,5,6,9,8,6,1,9,12,13,10,1])
|
||||
assert(deco_labels == ['MO||OC','ROOT','NK','SB','MO','NK','OA','NK','AG','OC','MNR||OA','MO','NK','NK','--'])
|
||||
deproj_heads, undeco_labels = pp.deprojectivize(proj_heads,deco_labels)
|
||||
assert(deproj_heads == nonproj_tree2)
|
||||
assert(undeco_labels == labels2)
|
||||
# deproj_heads, undeco_labels = PseudoProjectivity.deprojectivize(proj_heads,deco_labels)
|
||||
# assert(deproj_heads == nonproj_tree2)
|
||||
# assert(undeco_labels == labels2)
|
||||
|
||||
# if decoration is wrong such that there is no head with the desired label
|
||||
# the structure is kept and the label is undecorated
|
||||
deproj_heads, undeco_labels = pp.deprojectivize([1,2,2,4,5,2,7,5,2],['NK','SB','ROOT','NK','OA','OC','SB','RC||DA','--'])
|
||||
assert(deproj_heads == [1,2,2,4,5,2,7,5,2])
|
||||
assert(undeco_labels == ['NK','SB','ROOT','NK','OA','OC','SB','RC','--'])
|
||||
# deproj_heads, undeco_labels = PseudoProjectivity.deprojectivize([1,2,2,4,5,2,7,5,2],['NK','SB','ROOT','NK','OA','OC','SB','RC||DA','--'])
|
||||
# assert(deproj_heads == [1,2,2,4,5,2,7,5,2])
|
||||
# assert(undeco_labels == ['NK','SB','ROOT','NK','OA','OC','SB','RC','--'])
|
||||
|
||||
# if there are two potential new heads, the first one is chosen even if it's wrong
|
||||
deproj_heads, undeco_labels = pp.deprojectivize([1,1,3,1,5,6,9,8,6,1,9,12,13,10,1], \
|
||||
['MO||OC','ROOT','NK','OC','MO','NK','OA','NK','AG','OC','MNR||OA','MO','NK','NK','--'])
|
||||
assert(deproj_heads == [3,1,3,1,5,6,9,8,6,1,6,12,13,10,1])
|
||||
assert(undeco_labels == ['MO','ROOT','NK','OC','MO','NK','OA','NK','AG','OC','MNR','MO','NK','NK','--'])
|
||||
# deproj_heads, undeco_labels = PseudoProjectivity.deprojectivize([1,1,3,1,5,6,9,8,6,1,9,12,13,10,1], \
|
||||
# ['MO||OC','ROOT','NK','OC','MO','NK','OA','NK','AG','OC','MNR||OA','MO','NK','NK','--'])
|
||||
# assert(deproj_heads == [3,1,3,1,5,6,9,8,6,1,6,12,13,10,1])
|
||||
# assert(undeco_labels == ['MO','ROOT','NK','OC','MO','NK','OA','NK','AG','OC','MNR','MO','NK','NK','--'])
|
||||
|
||||
|
||||
|
|
@ -201,17 +201,9 @@ cdef class Token:
|
|||
cdef int nr_iter = 0
|
||||
cdef const TokenC* ptr = self.c - (self.i - self.c.l_edge)
|
||||
while ptr < self.c:
|
||||
# If this head is still to the right of us, we can skip to it
|
||||
# No token that's between this token and this head could be our
|
||||
# child.
|
||||
if (ptr.head >= 1) and (ptr + ptr.head) < self.c:
|
||||
ptr += ptr.head
|
||||
|
||||
elif ptr + ptr.head == self.c:
|
||||
if ptr + ptr.head == self.c:
|
||||
yield self.doc[ptr - (self.c - self.i)]
|
||||
ptr += 1
|
||||
else:
|
||||
ptr += 1
|
||||
ptr += 1
|
||||
nr_iter += 1
|
||||
# This is ugly, but it's a way to guard out infinite loops
|
||||
if nr_iter >= 10000000:
|
||||
|
@ -226,16 +218,10 @@ cdef class Token:
|
|||
tokens = []
|
||||
cdef int nr_iter = 0
|
||||
while ptr > self.c:
|
||||
# If this head is still to the right of us, we can skip to it
|
||||
# No token that's between this token and this head could be our
|
||||
# child.
|
||||
if (ptr.head < 0) and ((ptr + ptr.head) > self.c):
|
||||
ptr += ptr.head
|
||||
elif ptr + ptr.head == self.c:
|
||||
if ptr + ptr.head == self.c:
|
||||
tokens.append(self.doc[ptr - (self.c - self.i)])
|
||||
ptr -= 1
|
||||
else:
|
||||
ptr -= 1
|
||||
ptr -= 1
|
||||
nr_iter += 1
|
||||
if nr_iter >= 10000000:
|
||||
raise RuntimeError(
|
||||
"Possibly infinite loop encountered while looking for token.rights")
|
||||
|
|
Loading…
Reference in New Issue
Block a user