mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-23 15:54:13 +03:00
341 lines
11 KiB
Python
341 lines
11 KiB
Python
#!/usr/bin/env python
|
|
from __future__ import unicode_literals
|
|
from __future__ import print_function
|
|
import sys
|
|
|
|
import falcon
|
|
import json
|
|
from os import path
|
|
from collections import defaultdict
|
|
import pprint
|
|
import numpy
|
|
|
|
import spacy.en
|
|
from spacy.attrs import ORTH, SPACY, TAG, POS, ENT_IOB, ENT_TYPE
|
|
from spacy.parts_of_speech import NAMES as UNIV_POS_NAMES
|
|
|
|
# Python 2/3 compatibility: Python 3 has no builtin ``unicode``, so alias it
# to ``str`` there; on Python 2 the builtin ``unicode`` is used untouched.
if sys.version_info[0] >= 3:
    unicode = str
|
|
|
|
|
|
# Shared spaCy English pipeline, constructed once at import time and reused by
# all the functions and request handlers in this module (they use its
# tokenizer, tagger, matcher, parser, entity recognizer and vocab).
NLU = spacy.en.English()
|
|
|
|
|
|
def merge_entities(doc):
    """Merge every multi-token named entity in *doc* into a single token.

    Each merged token gets the entity label as both its tag and its entity
    type, and the entity's text as its lemma. This uses the same
    ``doc.merge(start_char, end_char, tag, lemma, ent_type)`` argument order
    as the other merge helpers in this module.

    Character offsets for all entities are collected before the first merge
    is applied, so the Doc is not modified while ``doc.ents`` is iterated.
    """
    ents = [(e[0].idx, e[-1].idx + len(e[-1]), e.label_, e.text)
            for e in doc.ents if len(e) >= 2]
    for start, end, label, lemma in ents:
        # The lemma slot receives the entity text bound by the loop above.
        merged = doc.merge(start, end, label, lemma, label)
        assert merged is not None
|
|
|
|
|
|
def merge_nps(doc):
    """Collapse every noun chunk of two or more tokens into one token.

    Each merged token is given the tag 'NP' and the chunk's text as its
    lemma. The chunk root's ``tag_`` is passed in the entity-type slot of
    ``doc.merge``.
    NOTE(review): passing a POS tag as ent_type looks deliberate (for
    display) — confirm.

    All chunk offsets are gathered before the first merge, so the Doc is not
    modified while ``doc.noun_chunks`` is being iterated.
    """
    pending = []
    for chunk in doc.noun_chunks:
        if len(chunk) < 2:
            continue
        first, last = chunk[0], chunk[-1]
        pending.append((first.idx, last.idx + len(last),
                        chunk.root.tag_, chunk.text))

    for start_char, end_char, root_tag, chunk_text in pending:
        doc.merge(start_char, end_char, u'NP', chunk_text, root_tag)
|
|
|
|
|
|
def merge_punct(tokens):
    """Merge each whitespace-free run of tokens into a single token.

    A run starts at a token with no trailing whitespace. It ends at the
    first following token that has trailing whitespace, or at the end of
    *tokens*. For example, in "Hello, world." the runs are ``Hello ,`` and
    ``world .``. Each merged token takes the tag, lemma and entity type of
    the run's syntactic root (``span.root``).

    Only runs of two or more tokens are merged, including a run that is cut
    off by the end of the document. All merges are collected before any is
    applied.
    """
    def _merge_args(span):
        # Arguments for tokens.merge(): (start_char, end_char, tag, lemma, ent_type).
        root = span.root
        return (span[0].idx, span[-1].idx + len(span[-1]),
                root.tag_, root.lemma_, root.ent_type_)

    merges = []
    start = None  # index of the first token in the current run, if any
    for word in tokens:
        if word.whitespace_:
            if start is not None:
                span = tokens[start:word.i + 1]
                if len(span) >= 2:
                    merges.append(_merge_args(span))
                start = None
        elif start is None:
            start = word.i

    if start is not None:
        span = tokens[start:len(tokens)]
        # Same minimum-length rule as inside the loop: a single trailing
        # token is left as it is.
        if len(span) >= 2:
            merges.append(_merge_args(span))

    for merge in merges:
        tokens.merge(*merge)
|
|
|
|
|
|
def get_actions(parse_state, n_actions):
    """Describe the manual-parse controls and whether each is currently legal.

    Args:
        parse_state: the parser state (``state.stcls`` from
            ``NLU.parser.step_through``). It is passed to
            ``NLU.parser.moves.is_valid``, and its ``queue``/``stack`` are
            checked for emptiness.
        n_actions: number of actions applied so far; "undo" is only valid
            when this is non-zero.

    Returns:
        A list of dicts with keys ``label``, ``key`` (a short action code),
        ``binding`` and ``is_valid``. The bindings are presumably browser
        keyCodes (37-40 arrows, 32 space, 8 backspace) — confirm against the
        client.
    """
    is_valid = NLU.parser.moves.is_valid
    # "left"/"right" legality is probed with one representative label each
    # (L-det / R-dobj); no concrete label is chosen here.
    return [
        {'label': 'shift', 'key': 'S', 'binding': 38,
         'is_valid': is_valid(parse_state, 'S')},
        {'label': 'left', 'key': 'L', 'binding': 37,
         'is_valid': is_valid(parse_state, 'L-det')},
        {'label': 'predict', 'key': '_', 'binding': 32,
         'is_valid': bool(parse_state.queue or parse_state.stack)},
        {'label': 'right', 'key': 'R', 'binding': 39,
         'is_valid': is_valid(parse_state, 'R-dobj')},
        {'label': 'undo', 'key': '-', 'binding': 8,
         'is_valid': n_actions != 0},
        {'label': 'reduce', 'key': 'D', 'binding': 40,
         'is_valid': is_valid(parse_state, 'D')},
    ]
|
|
|
|
|
|
class Model(object):
    """Base class for objects that serialize themselves to JSON-ready data.

    ``to_json`` exports every instance attribute whose name does not start
    with an underscore, converting each value with ``_as_json``.
    """

    def to_json(self):
        exported = {}
        for attr, val in self.__dict__.items():
            if attr.startswith('_'):
                continue
            exported[attr] = _as_json(val)
        return exported
|
|
|
|
def _as_json(value):
|
|
if hasattr(value, 'to_json'):
|
|
return value.to_json()
|
|
elif isinstance(value, list):
|
|
return [_as_json(v) for v in value]
|
|
elif isinstance(value, set):
|
|
return {key: True for key in value}
|
|
else:
|
|
return value
|
|
|
|
|
|
def _parse_history(history):
|
|
if history and history.endswith(','):
|
|
history = history[:-1]
|
|
history = history.strip().split(',') if history else tuple()
|
|
new_hist = []
|
|
history_length = len(history)
|
|
for action in history:
|
|
if action == '-':
|
|
if new_hist:
|
|
new_hist.pop()
|
|
else:
|
|
new_hist.append(action)
|
|
return new_hist, history_length
|
|
|
|
|
|
def apply_edits(tokens, word_edits, tag_edits):
    """Rebuild *tokens* with user word/tag edits applied, then re-parse.

    Args:
        tokens: the current Doc.
        word_edits: maps ``str(token.i)`` to replacement word text.
        tag_edits: maps ``str(token.i)`` to a new tag.
            - A tag found in UNIV_POS_NAMES is stored as the token's POS and
              the token is marked as outside any entity (ENT_TYPE 0,
              ENT_IOB 2, i.e. "O").
            - Any other string is stored as the token's entity type with
              ENT_IOB 3, keeping the token's existing POS.
            - Unedited tokens go through the same logic using their current
              ``pos_``.

    Returns:
        A new Doc built from the (edited) words, with POS / ENT_TYPE / ENT_IOB
        set as above, and parsed by ``NLU.parser``.
    """
    attrs = (POS, ENT_TYPE, ENT_IOB)
    pos_col, type_col, iob_col = 0, 1, 2
    analysis = numpy.zeros(shape=(len(tokens), len(attrs)), dtype=numpy.int32)
    words = []

    for token in tokens:
        edit_key = str(token.i)
        words.append(word_edits.get(edit_key, token.orth_))
        new_tag = tag_edits.get(edit_key, token.pos_)
        if new_tag in UNIV_POS_NAMES:
            analysis[token.i, pos_col] = UNIV_POS_NAMES[new_tag]
            # Set ent_type=0 and IOB="O"
            analysis[token.i, type_col] = 0
            analysis[token.i, iob_col] = 2
        else:
            # Not a POS name: treat the tag as an entity label.
            analysis[token.i, pos_col] = token.pos
            analysis[token.i, type_col] = NLU.vocab.strings[new_tag]
            analysis[token.i, iob_col] = 3

    edited = NLU.tokenizer.tokens_from_list(words)
    edited.from_array(attrs, analysis)
    NLU.parser(edited)
    return edited
|
|
|
|
|
|
class Parse(Model):
    """JSON view of a (possibly partial) dependency parse for displaCy.

    Serialized attributes:
        actions: controls offered to the client (dicts, see ``get_actions``).
        words: list of ``Word``.
        states: list of ``State`` snapshots.
        notes: passed through unchanged from the ``notes`` kwarg.
    """

    def __init__(self, doc, states, actions, **kwargs):
        """Build the view from *doc*, applying user edits first if any.

        Recognised kwargs:
            words / tags: edits keyed by the token index as a string. This is
                the form JSON object keys arrive in, and the key
                ``apply_edits`` looks up. A missing key or None means no edits.
            notes: any value, echoed back in the JSON.
        """
        # ``or {}`` also covers an explicit JSON null.
        word_edits = kwargs.get('words') or {}
        tag_edits = kwargs.get('tags') or {}
        if word_edits or tag_edits:
            doc = apply_edits(doc, word_edits, tag_edits)
        notes = kwargs.get('notes', {})
        self.actions = actions
        # apply_edits keys edits by str(word.i); test membership the same way.
        self.words = [Word(w, str(w.i) in word_edits, str(w.i) in tag_edits)
                      for w in doc]
        self.states = states
        self.notes = notes

    @classmethod
    def from_text(cls, text, **kwargs):
        """Run the full pipeline over *text* and return its final parse.

        Multi-token noun chunks are merged into single tokens first.
        """
        tokens = NLU(text)
        # Optional merges, currently disabled:
        # merge_entities(tokens)
        merge_nps(tokens)
        # merge_punct(tokens)
        return cls(tokens, [State.from_doc(tokens)], [], **kwargs)

    @classmethod
    def from_history(cls, text, history, **kwargs):
        """Replay a comma-separated *history* of parser actions over *text*.

        *text* may be bytes (decoded as UTF-8). Occurrences of ``-SLASH-`` are
        turned back into ``/``. The result holds the single state reached
        after the replay, plus the actions currently available
        (``get_actions``).
        """
        if not isinstance(text, unicode):
            text = text.decode('utf8')
        text = text.replace('-SLASH-', '/')
        history, _ = _parse_history(history)

        tokens = NLU.tokenizer(text)
        NLU.tagger(tokens)
        NLU.matcher(tokens)

        with NLU.parser.step_through(tokens) as state:
            for action in history:
                state.transition(action)

        NLU.entity(tokens)
        actions = get_actions(state.stcls, len(history))
        return cls(tokens, [State(state.heads, state.deps, state.stack, state.queue)],
                   actions, **kwargs)

    @classmethod
    def with_history(cls, text):
        """Parse *text* one predicted transition at a time, keeping every state.

        A ``State`` snapshot is recorded after each transition, and
        prev/next controls are attached for stepping through them.
        """
        tokens = NLU.tokenizer(text)
        NLU.tagger(tokens)
        NLU.matcher(tokens)

        with NLU.parser.step_through(tokens) as state:
            states = []
            while not state.is_final:
                action = state.predict()
                state.transition(action)
                states.append(State(state.heads, state.deps, state.stack, state.queue))

        actions = [
            {'label': 'prev', 'key': 'P', 'binding': 37, 'is_valid': True},
            {'label': 'next', 'key': 'N', 'binding': 39, 'is_valid': True}
        ]
        return cls(state.doc, states, actions)
|
|
|
|
|
|
class Word(Model):
    """JSON view of a single token for the displaCy client.

    Attributes:
        word: the token text (``orth_``).
        tag: the entity label if the token has one, otherwise its POS.
        is_entity: True when ``ent_iob`` is 1 or 3. ``apply_edits`` writes
            2 for "O" and 3 for an edited entity; 1 is presumably "I".
        is_w_edit / is_t_edit: whether the user edited this word / tag.
        prob: ``token.prob`` as reported by spaCy.
    """

    def __init__(self, token, is_w_edit=False, is_t_edit=False):
        self.word = token.orth_
        self.tag = token.ent_type_ or token.pos_
        self.is_entity = token.ent_iob in (1, 3)
        self.is_w_edit = is_w_edit
        self.is_t_edit = is_t_edit
        self.prob = token.prob
|
|
|
|
|
|
class State(Model):
    """JSON snapshot of one parser state.

    Attributes:
        focus: the smallest non-negative index in the queue, or -1 if there
            is none. Negative queue entries are dropped first (presumably
            padding/sentinels).
        is_final: True when both the stack and the filtered queue are empty.
        stack: the set of token indices on the stack.
        arrows: dependency arcs grouped by length (see ``_get_arrows``).
    """

    def __init__(self, heads, deps, stack, queue):
        Model.__init__(self)

        pending = [w for w in queue if w >= 0]
        self.focus = min(pending) if pending else -1
        self.is_final = not (stack or pending)
        self.stack = set(stack)
        self.arrows = self._get_arrows(heads, deps)

    @classmethod
    def from_doc(cls, doc):
        """Snapshot of a finished parse: heads/labels from *doc*, empty stack and queue."""
        return cls([w.head.i for w in doc], [w.dep_ for w in doc], [], [])

    def _get_arrows(self, heads, deps):
        """Lay out arcs by length.

        Returns a list of levels. ``levels[k][start]`` is the Arrow joining
        tokens ``start`` and ``start + k + 1``, or None if no arc joins them.
        Trailing levels that contain no arcs are dropped. A token whose head
        is itself gets no arc.
        """
        by_length = defaultdict(dict)
        for child, (head, label) in enumerate(zip(heads, deps)):
            if child == head:
                continue
            by_length[abs(child - head)][min(child, head)] = Arrow(child, head, label)

        n_words = len(heads)
        levels = [[by_length[length].get(start) for start in range(n_words - length)]
                  for length in range(1, n_words)]
        while levels and all(arrow is None for arrow in levels[-1]):
            levels.pop()
        return levels
|
|
|
|
|
|
class Arrow(Model):
    """JSON view of one dependency arc.

    ``dir`` is 'left' when the head index is greater than the word index,
    otherwise 'right'. ``label`` is the dependency label. The indices
    themselves are not stored.
    """

    def __init__(self, word, head, label):
        if head > word:
            self.dir = 'left'
        else:
            self.dir = 'right'
        self.label = label
|
|
|
|
|
|
class Endpoint(object):
    """Base falcon resource that answers with a parse serialized as JSON.

    Subclasses implement ``get_parse(text)``, returning an object with a
    ``to_json()`` method (e.g. ``Parse``).
    """

    def set_header(self, resp):
        # NOTE(review): 'text/string' is not a registered MIME type and the
        # body is JSON; left as-is because existing clients may rely on it.
        resp.content_type = 'text/string'
        resp.append_header('Access-Control-Allow-Origin', "*")
        resp.status = falcon.HTTP_200

    def set_body(self, resp, parse):
        resp.body = json.dumps(parse.to_json(), indent=4)

    def on_get(self, req, resp, text):
        """Parse the *text* taken from the URL path."""
        if not isinstance(text, unicode):
            text = text.decode('utf8')
        self.set_body(resp, self.get_parse(text))
        self.set_header(resp)

    def on_post(self, req, resp):
        """Parse the ``text`` field of a JSON request body.

        A malformed body gets a 400 response with a JSON ``error`` message
        instead of being silently ignored. Malformed means: undecodable
        bytes, invalid JSON, not an object, or ``text`` missing / not a
        string. Exceptions raised by ``get_parse`` itself propagate to falcon.
        """
        try:
            body_bytes = req.stream.read()
            json_data = json.loads(body_bytes.decode('utf8'))
            text = json_data['text']
            if not isinstance(text, unicode):
                text = text.decode('utf8')
        except (ValueError, KeyError, TypeError, AttributeError) as err:
            self.set_header(resp)
            resp.status = falcon.HTTP_400
            resp.body = json.dumps({'error': str(err)})
            return
        self.set_body(resp, self.get_parse(text))
        self.set_header(resp)
|
|
|
|
|
|
class ParseEP(Endpoint):
    """Endpoint for /parse/: the pipeline's final analysis of the text."""

    def get_parse(self, text, **kwargs):
        parse = Parse.from_text(text, **kwargs)
        return parse
|
|
|
|
|
|
class StepsEP(Endpoint):
    """Endpoint for /steps/: every intermediate parser state for the text."""

    def get_parse(self, text):
        # Request log to stdout.
        print('Step=', repr(text))
        steps = Parse.with_history(text)
        return steps
|
|
|
|
|
|
class ManualEP(Endpoint):
    """Endpoint for /manual/: replay a user-supplied parser action history.

    The request text has the form ``<text>/<actions>``, where *actions* is
    the comma-separated history understood by ``_parse_history``. Text
    without a '/' is replayed with an empty history.
    """

    def get_parse(self, text, **kwargs):
        print('Manual=', repr(text))
        if '/' in text:
            text, actions = text.rsplit('/', 1)
        else:
            actions = ''
        return Parse.from_history(text, actions, **kwargs)

    def on_get(self, req, resp, text, actions=''):
        if not isinstance(text, unicode):
            text = text.decode('utf8')
        self.set_body(resp, self.get_parse(text + '/' + actions))
        self.set_header(resp)

    def on_post(self, req, resp):
        """Replay ``json['text']``, forwarding ``json['params']`` as kwargs.

        ``params`` (word/tag edits, notes) is optional. A missing key and an
        explicit JSON null are both treated as "no params".
        """
        self.set_header(resp)
        body_bytes = req.stream.read()
        json_data = json.loads(body_bytes.decode('utf8'))
        print(json_data)
        params = json_data.get('params') or {}
        self.set_body(resp, self.get_parse(json_data['text'], **params))
|
|
|
|
|
|
app = falcon.API()

remote_man = ManualEP()
remote_parse = ParseEP()
remote_steps = StepsEP()

# Each endpoint is mounted without a path parameter (presumably used for
# POST requests, whose text is in the body) and with the text in the URL
# (GET). The manual endpoint also accepts the action history as a second
# path segment.
for _route, _resource in (
    ('/api/displacy/parse/', remote_parse),
    ('/api/displacy/parse/{text}/', remote_parse),
    ('/api/displacy/steps/', remote_steps),
    ('/api/displacy/steps/{text}/', remote_steps),
    ('/api/displacy/manual/', remote_man),
    ('/api/displacy/manual/{text}/', remote_man),
    ('/api/displacy/manual/{text}/{actions}', remote_man),
):
    app.add_route(_route, _resource)
del _route, _resource
|
|
|
|
|
|
if __name__ == '__main__':
    # Command-line check: python <this file> INPUT_FILE
    # The first line of INPUT_FILE is parsed and the JSON view pretty-printed.
    # Any further lines are ignored; the original format carried an unused
    # second "actions" line.
    import io

    if len(sys.argv) < 2:
        sys.exit('Usage: %s INPUT_FILE' % sys.argv[0])
    # io.open yields text (not bytes) on both Python 2 and 3, and the
    # ``with`` block closes the file.
    with io.open(sys.argv[1], encoding='utf8') as input_file:
        lines = input_file.read().strip().split('\n')
    text = lines[0]
    parse = Parse.from_text(text)
    pprint.pprint(parse.to_json())
|