2017-03-15 17:29:42 +03:00
|
|
|
# cython: profile=True
|
2017-04-15 13:05:47 +03:00
|
|
|
# coding: utf8
|
2016-10-09 13:24:24 +03:00
|
|
|
from __future__ import unicode_literals, print_function
|
|
|
|
|
2015-05-24 22:50:48 +03:00
|
|
|
import re
|
2017-05-21 17:06:17 +03:00
|
|
|
import random
|
2017-05-26 01:15:09 +03:00
|
|
|
import cytoolz
|
2017-09-14 14:37:41 +03:00
|
|
|
import itertools
|
2018-03-27 20:23:02 +03:00
|
|
|
import numpy
|
|
|
|
import tempfile
|
|
|
|
import shutil
|
|
|
|
from pathlib import Path
|
|
|
|
import msgpack
|
|
|
|
|
|
|
|
import ujson
|
2015-07-25 19:11:36 +03:00
|
|
|
|
2018-03-27 20:23:02 +03:00
|
|
|
from . import _align
|
2016-03-01 12:09:08 +03:00
|
|
|
from .syntax import nonproj
|
2017-05-21 17:06:17 +03:00
|
|
|
from .tokens import Doc
|
2018-04-03 16:50:31 +03:00
|
|
|
from .errors import Errors
|
2017-10-27 22:07:59 +03:00
|
|
|
from . import util
|
2018-03-27 20:23:02 +03:00
|
|
|
from .util import minibatch, itershuffle
|
|
|
|
from .compat import json_dumps
|
2016-02-22 16:40:40 +03:00
|
|
|
|
2018-03-27 20:23:02 +03:00
|
|
|
from libc.stdio cimport FILE, fopen, fclose, fread, fwrite, feof, fseek
|
2015-02-21 19:06:58 +03:00
|
|
|
|
2015-06-08 01:54:13 +03:00
|
|
|
def tags_to_entities(tags):
    """Collect `(label, start, end)` triples from a sequence of BILUO tags.

    tags (iterable): Per-token BILUO tag strings. `None` and "-" entries are
        skipped; an "O" entry discards any entity that is currently open.
    RETURNS (list): One `(label, start, end)` tuple per entity, where `start`
        and `end` are inclusive token indices. A "U" tag gives
        `start == end`. An "L" tag with no preceding "B" is emitted with
        `start` set to None.
    RAISES (ValueError): On an "I" tag with no open entity (E067), or on a
        tag whose prefix is not one of "B", "I", "L", "U", "O" (E068).
    """
    found = []
    begin = None
    for position, tag in enumerate(tags):
        if tag is None or tag == '-':
            continue
        if tag.startswith('O'):
            # TODO: We shouldn't be getting these malformed inputs. Fix this.
            begin = None
        elif tag.startswith('I'):
            if begin is None:
                raise ValueError(Errors.E067.format(tags=tags[:position]))
        elif tag.startswith('U'):
            found.append((tag[2:], position, position))
        elif tag.startswith('B'):
            begin = position
        elif tag.startswith('L'):
            found.append((tag[2:], begin, position))
            begin = None
        else:
            raise ValueError(Errors.E068.format(tag=tag))
    return found
|
|
|
|
|
|
|
|
|
2016-10-13 04:24:29 +03:00
|
|
|
def merge_sents(sents):
    """Concatenate a paragraph's sentence annotations into a single sentence.

    sents (iterable): `((ids, words, tags, heads, labels, ner), brackets)`
        pairs, where `brackets` is a sequence of dicts with 'first', 'last'
        and 'label' keys.
    RETURNS (list): A one-element list holding `(annots, brackets)`.
        `annots` is a list of six lists. Token ids, heads and bracket
        boundaries are shifted by the number of ids in all preceding
        sentences, so they index into the merged sequence.
    """
    ids_all, words_all, tags_all = [], [], []
    heads_all, labels_all, ner_all = [], [], []
    spans = []
    offset = 0
    for annots, brackets in sents:
        ids, words, tags, heads, labels, ner = annots
        ids_all.extend(token_id + offset for token_id in ids)
        words_all.extend(words)
        tags_all.extend(tags)
        heads_all.extend(head + offset for head in heads)
        labels_all.extend(labels)
        ner_all.extend(ner)
        spans.extend((bracket['first'] + offset, bracket['last'] + offset,
                      bracket['label'])
                     for bracket in brackets)
        offset += len(ids)
    merged = [ids_all, words_all, tags_all, heads_all, labels_all, ner_all]
    return [(merged, spans)]
|
|
|
|
|
2015-06-08 01:54:13 +03:00
|
|
|
|
2015-05-24 22:50:48 +03:00
|
|
|
# Compiled pattern matching a single non-word character (`\W`).
# NOTE(review): not referenced in the visible part of this module; presumably
# kept for external importers -- verify before removing.
punct_re = re.compile(r'\W')
|
2018-03-27 20:23:02 +03:00
|
|
|
def align(cand_words, gold_words):
    """Align candidate (predicted) tokens to gold-standard tokens.

    cand_words (list): The candidate token strings.
    gold_words (list): The gold-standard token strings.
    RETURNS (tuple): `(cost, i2j, j2i, i2j_multi, j2i_multi)`.
        - Identical inputs: the alignment is a single `numpy.arange`,
          returned as both `i2j` and `j2i`, with zero cost and empty
          multi-alignment dicts.
        - Otherwise: spaces are stripped from every word. `_align.align`
          computes the one-to-one alignment and `_align.multi_align` the
          many-to-one alignments. Any multi-alignment entry whose target is
          not shared with an adjacent key is moved into the one-to-one
          mapping.
    """
    if cand_words == gold_words:
        identity = numpy.arange(len(cand_words))
        return 0, identity, identity, {}, {}
    cand = [word.replace(' ', '') for word in cand_words]
    gold = [word.replace(' ', '') for word in gold_words]
    cost, i2j, j2i, _matrix = _align.align(cand, gold)
    cand_lengths = [len(word) for word in cand]
    gold_lengths = [len(word) for word in gold]
    i2j_multi, j2i_multi = _align.multi_align(i2j, j2i, cand_lengths,
                                              gold_lengths)
    _promote_isolated(i2j_multi, i2j)
    _promote_isolated(j2i_multi, j2i)
    return cost, i2j, j2i, i2j_multi, j2i_multi


def _promote_isolated(multi, one_to_one):
    """Move entries of `multi` whose neighbours map elsewhere into `one_to_one`.

    An entry `key -> target` is moved when neither `key - 1` nor `key + 1`
    maps to the same `target` in `multi`. Both arguments are mutated.
    """
    for key, target in list(multi.items()):
        shared_left = multi.get(key - 1) == target
        shared_right = multi.get(key + 1) == target
        if not (shared_left or shared_right):
            one_to_one[key] = target
            multi.pop(key)
|
2015-05-24 22:50:48 +03:00
|
|
|
|
2015-05-27 20:13:11 +03:00
|
|
|
|
2017-05-21 17:06:17 +03:00
|
|
|
class GoldCorpus(object):
    """An annotated corpus, using the JSON file format. Manages
    annotations for tagging, dependency parsing and NER.

    The constructor writes the training and development data into a
    temporary directory, one msgpack file per document, so the documents
    can be shuffled and streamed. The directory is removed in `__del__`.
    """
    def __init__(self, train, dev, gold_preproc=False, limit=None):
        """Create a GoldCorpus.

        train (unicode, Path or iterable): File or directory of training
            data, or an iterable of already-loaded gold tuples.
        dev (unicode, Path or iterable): File or directory of development
            data, or an iterable of already-loaded gold tuples.
        gold_preproc (bool): Not used by the constructor.
        limit (int): If truthy, the approximate maximum number of sentences
            read per split by `train_tuples`, `dev_tuples` and `train_docs`.
        RETURNS (GoldCorpus): The newly created object.
        """
        self.limit = limit
        # Assigned first so __del__ can always look it up, even if a later
        # step raises.
        self.tmp_dir = None
        # Each split is converted independently. Previously `dev` was only
        # read from disk when `train` was a path, so mixing a path with an
        # in-memory iterable failed.
        train = self._read_if_path(train)
        dev = self._read_if_path(dev)
        # Write temp directory with one doc per file, so we can shuffle
        # and stream
        self.tmp_dir = Path(tempfile.mkdtemp())
        self.write_msgpack(self.tmp_dir / 'train', train)
        self.write_msgpack(self.tmp_dir / 'dev', dev)

    def __del__(self):
        """Remove the temporary directory, if one was created.

        Best-effort and idempotent: `tmp_dir` may be missing or None when
        __init__ did not complete, and removal errors are ignored.
        """
        tmp_dir = getattr(self, 'tmp_dir', None)
        if tmp_dir is not None:
            shutil.rmtree(str(tmp_dir), ignore_errors=True)
            self.tmp_dir = None

    @classmethod
    def _read_if_path(cls, data):
        """Return a lazy reader over `data` if it is a path, else `data` itself."""
        if isinstance(data, (str, Path)):
            return cls.read_tuples(cls.walk_corpus(data))
        return data

    @staticmethod
    def write_msgpack(directory, doc_tuples):
        """Write each doc tuple to `directory/{i}.msg` as a one-item list.

        directory (Path): Output directory; created if it does not exist.
        doc_tuples (iterable): The gold tuples to write, one file each.
        """
        if not directory.exists():
            directory.mkdir()
        for i, doc_tuple in enumerate(doc_tuples):
            with open(directory / '{}.msg'.format(i), 'wb') as file_:
                msgpack.dump([doc_tuple], file_, use_bin_type=True, encoding='utf8')

    @staticmethod
    def walk_corpus(path):
        """Find the JSON files under `path`.

        path (unicode or Path): A file or a directory.
        RETURNS (list): `[path]` if `path` is not a directory. Otherwise
            every `.json` file found by walking the directory tree, skipping
            entries whose name starts with ".".
        """
        path = util.ensure_path(path)
        if not path.is_dir():
            return [path]
        paths = [path]
        locs = []
        seen = set()
        for path in paths:
            if str(path) in seen:
                continue
            seen.add(str(path))
            if path.parts[-1].startswith('.'):
                continue
            elif path.is_dir():
                # Extending the list being iterated makes this a
                # breadth-first walk.
                paths.extend(path.iterdir())
            elif path.parts[-1].endswith('.json'):
                locs.append(path)
        return locs

    @staticmethod
    def read_tuples(locs, limit=0):
        """Yield gold tuples from a sequence of `.json` or `.msg` files.

        locs (iterable): Paths of the files to read.
        limit (int): If truthy, stop once at least this many sentences
            (summed over `len(item[1])`) have been yielded.
        YIELDS: The gold tuples stored in each file, in order.
        RAISES (ValueError): If a file has an unsupported extension.
        """
        i = 0
        for loc in locs:
            loc = util.ensure_path(loc)
            if loc.parts[-1].endswith('json'):
                gold_tuples = read_json_file(loc)
            elif loc.parts[-1].endswith('msg'):
                with loc.open('rb') as file_:
                    gold_tuples = msgpack.load(file_, encoding='utf8')
            else:
                msg = "Cannot read from file: %s. Supported formats: .json, .msg"
                raise ValueError(msg % loc)
            for item in gold_tuples:
                yield item
                i += len(item[1])
                if limit and i >= limit:
                    # Stop the whole generator. A bare `break` only left
                    # the inner loop, so each remaining file still yielded
                    # an item and the limit was never enforced.
                    return

    @property
    def dev_tuples(self):
        """Yield the development gold tuples from the temporary directory."""
        locs = (self.tmp_dir / 'dev').iterdir()
        yield from self.read_tuples(locs, limit=self.limit)

    @property
    def train_tuples(self):
        """Yield the training gold tuples from the temporary directory."""
        locs = (self.tmp_dir / 'train').iterdir()
        yield from self.read_tuples(locs, limit=self.limit)

    def count_train(self):
        """Count the words in the training data.

        Sums the lengths of each sentence's word list; stops early when
        `self.limit` is set.
        """
        n = 0
        i = 0
        for raw_text, paragraph_tuples in self.train_tuples:
            for sent_tuples, brackets in paragraph_tuples:
                n += len(sent_tuples[1])
            if self.limit and i >= self.limit:
                break
            i += len(paragraph_tuples)
        return n

    def train_docs(self, nlp, gold_preproc=False, max_length=None,
                   noise_level=0.0):
        """Yield `(doc, gold)` training pairs, with files in shuffled order.

        Gold parses are projectivized (`make_projective=True`).
        """
        locs = list((self.tmp_dir / 'train').iterdir())
        random.shuffle(locs)
        train_tuples = self.read_tuples(locs, limit=self.limit)
        gold_docs = self.iter_gold_docs(nlp, train_tuples, gold_preproc,
                                        max_length=max_length,
                                        noise_level=noise_level,
                                        make_projective=True)
        yield from gold_docs

    def dev_docs(self, nlp, gold_preproc=False):
        """Yield `(doc, gold)` development pairs."""
        gold_docs = self.iter_gold_docs(nlp, self.dev_tuples,
                                        gold_preproc=gold_preproc)
        yield from gold_docs

    @classmethod
    def iter_gold_docs(cls, nlp, tuples, gold_preproc, max_length=None,
                       noise_level=0.0, make_projective=False):
        """Yield `(doc, gold)` pairs built from `(raw_text, paragraph_tuples)`.

        With `gold_preproc`, one Doc is built per sentence from the gold
        words. Otherwise the sentences are merged and the raw text is
        tokenized with `nlp.make_doc`. When `max_length` is truthy, docs of
        that length or longer are skipped.
        """
        for raw_text, paragraph_tuples in tuples:
            if gold_preproc:
                raw_text = None
            else:
                paragraph_tuples = merge_sents(paragraph_tuples)
            docs = cls._make_docs(nlp, raw_text, paragraph_tuples,
                                  gold_preproc, noise_level=noise_level)
            golds = cls._make_golds(docs, paragraph_tuples, make_projective)
            for doc, gold in zip(docs, golds):
                if (not max_length) or len(doc) < max_length:
                    yield doc, gold

    @classmethod
    def _make_docs(cls, nlp, raw_text, paragraph_tuples, gold_preproc,
                   noise_level=0.0):
        """Build the Doc objects for one paragraph, applying `add_noise`."""
        if raw_text is not None:
            raw_text = add_noise(raw_text, noise_level)
            return [nlp.make_doc(raw_text)]
        else:
            return [Doc(nlp.vocab,
                        words=add_noise(sent_tuples[1], noise_level))
                    for (sent_tuples, brackets) in paragraph_tuples]

    @classmethod
    def _make_golds(cls, docs, paragraph_tuples, make_projective):
        """Build one GoldParse per Doc.

        RAISES (ValueError): E070 if the number of docs and annotations
            differ.
        """
        if len(docs) != len(paragraph_tuples):
            raise ValueError(Errors.E070.format(n_docs=len(docs),
                                                n_annots=len(paragraph_tuples)))
        if len(docs) == 1:
            return [GoldParse.from_annot_tuples(docs[0],
                                                paragraph_tuples[0][0],
                                                make_projective=make_projective)]
        else:
            return [GoldParse.from_annot_tuples(doc, sent_tuples,
                                                make_projective=make_projective)
                    for doc, (sent_tuples, brackets)
                    in zip(docs, paragraph_tuples)]
|
2017-05-21 17:06:17 +03:00
|
|
|
|
|
|
|
|
2017-06-05 04:16:57 +03:00
|
|
|
def add_noise(orig, noise_level):
|
|
|
|
if random.random() >= noise_level:
|
|
|
|
return orig
|
|
|
|
elif type(orig) == list:
|
|
|
|
corrupted = [_corrupt(word, noise_level) for word in orig]
|
|
|
|
corrupted = [w for w in corrupted if w]
|
|
|
|
return corrupted
|
|
|
|
else:
|
|
|
|
return ''.join(_corrupt(c, noise_level) for c in orig)
|
|
|
|
|
|
|
|
|
|
|
|
def _corrupt(c, noise_level):
|
|
|
|
if random.random() >= noise_level:
|
|
|
|
return c
|
|
|
|
elif c == ' ':
|
|
|
|
return '\n'
|
|
|
|
elif c == '\n':
|
|
|
|
return ' '
|
|
|
|
elif c in ['.', "'", "!", "?"]:
|
|
|
|
return ''
|
|
|
|
else:
|
|
|
|
return c.lower()
|
|
|
|
|
|
|
|
|
2017-05-22 12:48:02 +03:00
|
|
|
def read_json_file(loc, docs_filter=None, limit=None):
    """Read gold-standard annotations from a JSON file, or a directory of them.

    loc (unicode or Path): A JSON file, or a directory whose entries are read
        recursively.
    docs_filter (callable): Optional predicate on each raw JSON doc; docs for
        which it returns a falsy value are skipped.
    limit: Forwarded to recursive calls but not otherwise used.
    YIELDS (list): `[raw_text, sents]` for every paragraph with at least one
        sentence. `raw_text` is the paragraph's 'raw' value or None. Each
        sentence is `[[ids, words, tags, heads, labels, ner], brackets]`,
        where `heads` are absolute token indices within the sentence.
    """
    loc = util.ensure_path(loc)
    if loc.is_dir():
        for path in loc.iterdir():
            # iterdir() yields paths that already include `loc`; joining them
            # onto `loc` again doubled relative prefixes (data/data/x.json).
            # The filter is now propagated to nested files as well.
            yield from read_json_file(path, docs_filter=docs_filter,
                                      limit=limit)
        return
    for doc in _json_iterate(loc):
        if docs_filter is not None and not docs_filter(doc):
            continue
        for paragraph in doc['paragraphs']:
            sents = []
            for sent in paragraph['sentences']:
                words = []
                ids = []
                tags = []
                heads = []
                labels = []
                ner = []
                for i, token in enumerate(sent['tokens']):
                    words.append(token['orth'])
                    ids.append(i)
                    tags.append(token.get('tag', '-'))
                    # The file stores relative head offsets; convert to
                    # absolute indices.
                    heads.append(token.get('head', 0) + i)
                    labels.append(token.get('dep', ''))
                    # Ensure ROOT label is case-insensitive
                    if labels[-1].lower() == 'root':
                        labels[-1] = 'ROOT'
                    ner.append(token.get('ner', '-'))
                sents.append([
                    [ids, words, tags, heads, labels, ner],
                    sent.get('brackets', [])])
            if sents:
                yield [paragraph.get('raw', None), sents]
|
2015-05-06 17:27:31 +03:00
|
|
|
|
|
|
|
|
2018-03-27 20:23:02 +03:00
|
|
|
def _json_iterate(loc):
    """Yield the objects of a JSON file's top-level array one at a time.

    The raw bytes are scanned while tracking square/curly bracket depth and
    whether the scanner is inside a string literal. Each object found
    directly inside a depth-one array is decoded and parsed on its own with
    `ujson.loads`, so the file is never parsed as a single structure.

    loc (unicode or Path): The JSON file to read.
    YIELDS (dict): Each object found directly inside the top-level array.
    """
    # We should've made these files jsonl...But since we didn't, parse out
    # the docs one-by-one to reduce memory usage.
    # It's okay to read in the whole file -- just don't parse it into JSON.
    cdef bytes py_raw
    loc = util.ensure_path(loc)
    with loc.open('rb') as file_:
        py_raw = file_.read()
    raw = <char*>py_raw
    cdef int square_depth = 0
    cdef int curly_depth = 0
    cdef int inside_string = 0
    cdef int escape = 0
    cdef int start = -1
    cdef char c
    cdef char quote = ord('"')
    cdef char backslash = ord('\\')
    cdef char open_square = ord('[')
    cdef char close_square = ord(']')
    cdef char open_curly = ord('{')
    cdef char close_curly = ord('}')
    for i in range(len(py_raw)):
        c = raw[i]
        if escape:
            # The character after a backslash always belongs to the escape
            # sequence. This test must come before the backslash test:
            # otherwise an escaped backslash ("\\") re-arms `escape` and the
            # next character (e.g. a string's closing quote) is swallowed,
            # desynchronising the whole scan.
            escape = False
            continue
        if c == backslash:
            escape = True
            continue
        if c == quote:
            inside_string = not inside_string
            continue
        if inside_string:
            continue
        if c == open_square:
            square_depth += 1
        elif c == close_square:
            square_depth -= 1
        elif c == open_curly:
            if square_depth == 1 and curly_depth == 0:
                start = i
            curly_depth += 1
        elif c == close_curly:
            curly_depth -= 1
            if square_depth == 1 and curly_depth == 0:
                py_str = py_raw[start : i+1].decode('utf8')
                yield ujson.loads(py_str)
                start = -1
|
|
|
|
|
|
|
|
|
2017-05-26 19:32:55 +03:00
|
|
|
def iob_to_biluo(tags):
    """Convert a sequence of IOB tags into the BILUO scheme.

    tags (iterable): IOB tag strings such as "O", "B-LOC" and "I-LOC". An
        entity may begin with either a "B" or an "I" tag.
    RETURNS (list): The BILUO tags, one per input tag.
    """
    out = []
    tags = list(tags)
    while tags:
        out.extend(_consume_os(tags))
        out.extend(_consume_ent(tags))
    return out


def _consume_os(tags):
    """Pop and yield the leading "O" tags from the front of the list `tags`."""
    while tags and tags[0] == 'O':
        yield tags.pop(0)


def _consume_ent(tags):
    """Pop one entity from the front of the list `tags`; return its BILUO tags.

    The entity is the first tag plus every directly following tag equal to
    its "I-" form. A one-token entity becomes "U-label"; longer entities
    become "B-label", "I-label"..., "L-label".
    """
    if not tags:
        return []
    first = tags.pop(0)
    # Rewrite only the prefix. The previous `replace('B', 'I')` also
    # rewrote every "B" inside the label (e.g. "B-ORB" -> label "ORI").
    target = 'I' + first[1:] if first.startswith('B') else first
    length = 1
    while tags and tags[0] == target:
        length += 1
        tags.pop(0)
    label = target[2:]
    if length == 1:
        return ['U-' + label]
    else:
        start = 'B-' + label
        end = 'L-' + label
        middle = ['I-%s' % label for _ in range(1, length - 1)]
        return [start] + middle + [end]
|
|
|
|
|
|
|
|
|
2015-03-09 08:46:22 +03:00
|
|
|
cdef class GoldParse:
    """Collection for training annotations."""
    @classmethod
    def from_annot_tuples(cls, doc, annot_tuples, make_projective=False):
        """Create a GoldParse from an annotation tuple.

        doc (Doc): The document the annotations refer to.
        annot_tuples (tuple): `(ids, words, tags, heads, deps, entities)`;
            the ids element is ignored.
        make_projective (bool): Passed through to the constructor.
        RETURNS (GoldParse): The newly constructed object.
        """
        _, words, tags, heads, deps, entities = annot_tuples
        return cls(doc, words=words, tags=tags, heads=heads, deps=deps,
                   entities=entities, make_projective=make_projective)

    def __init__(self, doc, annot_tuples=None, words=None, tags=None,
                 heads=None, deps=None, entities=None, make_projective=False,
                 cats=None, **_):
        """Create a GoldParse.

        doc (Doc): The document the annotations refer to.
        annot_tuples: Unused; the value passed in is never read.
        words (iterable): A sequence of unicode word strings.
        tags (iterable): A sequence of strings, representing tag annotations.
        heads (iterable): A sequence of integers, representing the index (in
            `words`) of each word's syntactic head, or None where unknown.
        deps (iterable): A sequence of strings, representing the syntactic
            relation types.
        entities (iterable): A sequence of named entity annotations, either as
            BILUO tag strings, or as `(start_char, end_char, label)` tuples,
            representing the entity positions.
        make_projective (bool): If True, `heads` and `deps` are passed
            through `nonproj.projectivize` before alignment.
        cats (dict): Labels for text classification. Each key in the dictionary
            may be a string or an int, or a `(start_char, end_char, label)`
            tuple, indicating that the label is applied to only part of the
            document (usually a sentence). Unlike entity annotations, label
            annotations can overlap, i.e. a single word can be covered by
            multiple labelled spans. The TextCategorizer component expects
            true examples of a label to have the value 1.0, and negative
            examples of a label to have the value 0.0. Labels not in the
            dictionary are treated as missing - the gradient for those labels
            will be zero.
        RETURNS (GoldParse): The newly constructed object.
        """
        if words is None:
            words = [token.text for token in doc]
        if tags is None:
            tags = [None for _ in doc]
        if heads is None:
            heads = [None for token in doc]
        if deps is None:
            deps = [None for _ in doc]
        if entities is None:
            entities = [None for _ in doc]
        elif len(entities) == 0:
            entities = ['O' for _ in doc]
        elif not isinstance(entities[0], basestring):
            # Assume we have entities specified by character offset.
            entities = biluo_tags_from_offsets(doc, entities)

        self.mem = Pool()
        self.loss = 0
        self.length = len(doc)

        # These are filled by the tagger/parser/entity recogniser
        self.c.tags = <int*>self.mem.alloc(len(doc), sizeof(int))
        self.c.heads = <int*>self.mem.alloc(len(doc), sizeof(int))
        self.c.labels = <attr_t*>self.mem.alloc(len(doc), sizeof(attr_t))
        self.c.has_dep = <int*>self.mem.alloc(len(doc), sizeof(int))
        self.c.sent_start = <int*>self.mem.alloc(len(doc), sizeof(int))
        self.c.ner = <Transition*>self.mem.alloc(len(doc), sizeof(Transition))

        self.cats = {} if cats is None else dict(cats)
        self.words = [None] * len(doc)
        self.tags = [None] * len(doc)
        self.heads = [None] * len(doc)
        self.labels = [None] * len(doc)
        self.ner = [None] * len(doc)

        # This needs to be done before we align the words
        if make_projective and heads is not None and deps is not None:
            heads, deps = nonproj.projectivize(heads, deps)

        # Do many-to-one alignment for misaligned tokens.
        # If we over-segment, we'll have one gold word that covers a sequence
        # of predicted words
        # If we under-segment, we'll have one predicted word that covers a
        # sequence of gold words.
        # If we "mis-segment", we'll have a sequence of predicted words covering
        # a sequence of gold words. That's many-to-many -- we don't do that.
        cost, i2j, j2i, i2j_multi, j2i_multi = align([t.orth_ for t in doc], words)

        # Negative alignment values mean "no one-to-one counterpart".
        self.cand_to_gold = [(j if j >= 0 else None) for j in i2j]
        self.gold_to_cand = [(i if i >= 0 else None) for i in j2i]

        annot_tuples = (range(len(words)), words, tags, heads, deps, entities)
        self.orig_annot = list(zip(*annot_tuples))

        for i, gold_i in enumerate(self.cand_to_gold):
            if doc[i].text.isspace():
                self.words[i] = doc[i].text
                self.tags[i] = '_SP'
                self.heads[i] = None
                self.labels[i] = None
                self.ner[i] = 'O'
            if gold_i is None:
                if i in i2j_multi:
                    gold_j = i2j_multi[i]
                    self.words[i] = words[gold_j]
                    self.tags[i] = tags[gold_j]
                    is_last = gold_j != i2j_multi.get(i+1)
                    is_first = gold_j != i2j_multi.get(i-1)
                    # Set next word in multi-token span as head, until last
                    if not is_last:
                        self.heads[i] = i+1
                        self.labels[i] = 'subtok'
                    else:
                        # Missing heads are allowed (see the one-to-one
                        # branch below); indexing with None would raise.
                        if heads[gold_j] is None:
                            self.heads[i] = None
                        else:
                            self.heads[i] = self.gold_to_cand[heads[gold_j]]
                        self.labels[i] = deps[gold_j]
                    # Now set NER...This is annoying because if we've split
                    # got an entity word split into two, we need to adjust the
                    # BILOU tags. We can't have BB or LL etc.
                    ner_tag = entities[gold_j]
                    if ner_tag is None:
                        # No entity annotation (the default when `entities`
                        # is omitted): leave the slot as it is instead of
                        # calling string methods on None.
                        pass
                    # Case 1: O -- easy.
                    elif ner_tag == 'O':
                        self.ner[i] = 'O'
                    # Case 2: U. This has to become a B I* L sequence.
                    elif ner_tag.startswith('U-'):
                        if is_first:
                            self.ner[i] = ner_tag.replace('U-', 'B-', 1)
                        elif is_last:
                            self.ner[i] = ner_tag.replace('U-', 'L-', 1)
                        else:
                            self.ner[i] = ner_tag.replace('U-', 'I-', 1)
                    # Case 3: B. If not first, change to I.
                    elif ner_tag.startswith('B-'):
                        if is_first:
                            self.ner[i] = ner_tag
                        else:
                            self.ner[i] = ner_tag.replace('B-', 'I-', 1)
                    # Case 4: L. If not last, change to I.
                    elif ner_tag.startswith('L-'):
                        if is_last:
                            self.ner[i] = ner_tag
                        else:
                            self.ner[i] = ner_tag.replace('L-', 'I-', 1)
                    # Case 5: I. Stays correct
                    elif ner_tag.startswith('I-'):
                        self.ner[i] = ner_tag
            else:
                self.words[i] = words[gold_i]
                self.tags[i] = tags[gold_i]
                if heads[gold_i] is None:
                    self.heads[i] = None
                else:
                    self.heads[i] = self.gold_to_cand[heads[gold_i]]
                self.labels[i] = deps[gold_i]
                self.ner[i] = entities[gold_i]

        cycle = nonproj.contains_cycle(self.heads)
        if cycle is not None:
            raise ValueError(Errors.E069.format(cycle=cycle))

    def __len__(self):
        """Get the number of gold-standard tokens.

        RETURNS (int): The number of gold-standard tokens.
        """
        return self.length

    @property
    def is_projective(self):
        """Whether the provided syntactic annotations form a projective
        dependency tree.
        """
        return not nonproj.is_nonproj_tree(self.heads)

    @property
    def sent_starts(self):
        """The values of the C-level `sent_start` array, one per token."""
        return [self.c.sent_start[i] for i in range(self.length)]
|
|
|
|
|
2015-02-21 19:06:58 +03:00
|
|
|
|
2017-07-29 22:58:37 +03:00
|
|
|
def biluo_tags_from_offsets(doc, entities, missing='O'):
    """Encode labelled spans into per-token tags, using the
    Begin/In/Last/Unit/Out scheme (BILUO).

    doc (Doc): The document that the entity offsets refer to. The output tags
        will refer to the token boundaries within the document.
    entities (iterable): A sequence of `(start, end, label)` triples. `start`
        and `end` should be character-offset integers denoting the slice into
        the original string. Any iterable is accepted, including one-shot
        iterators.
    missing (unicode): The tag given to tokens that do not overlap any
        entity span. Defaults to "O".
    RETURNS (list): A list of unicode strings, describing the tags. Each tag
        string will be of the form either "-", the `missing` tag (by default
        "O"), or "{action}-{label}", where action is one of "B", "I", "L",
        "U". The string "-" is used where the entity offsets don't align with
        the tokenization in the `Doc` object. The training algorithm will
        view these as missing values. "O" denotes a non-entity token. "B"
        denotes the beginning of a multi-token entity, "I" the inside of an
        entity of three or more tokens, and "L" the end of an entity of two
        or more tokens. "U" denotes a single-token entity.

    EXAMPLE:
        >>> text = 'I like London.'
        >>> entities = [(len('I like '), len('I like London'), 'LOC')]
        >>> doc = nlp.tokenizer(text)
        >>> tags = biluo_tags_from_offsets(doc, entities)
        >>> assert tags == ['O', 'O', 'U-LOC', 'O']
    """
    # Materialise once: the spans are traversed twice below. A one-shot
    # iterator would be empty on the second pass, so every token, including
    # the entity tokens tagged in the first pass, was overwritten with
    # `missing`.
    entities = list(entities)
    starts = {token.idx: token.i for token in doc}
    ends = {token.idx + len(token): token.i for token in doc}
    biluo = ['-' for _ in doc]
    # Handle entity cases
    for start_char, end_char, label in entities:
        start_token = starts.get(start_char)
        end_token = ends.get(end_char)
        # Only interested if the tokenization is correct
        if start_token is not None and end_token is not None:
            if start_token == end_token:
                biluo[start_token] = 'U-%s' % label
            else:
                biluo[start_token] = 'B-%s' % label
                for i in range(start_token + 1, end_token):
                    biluo[i] = 'I-%s' % label
                biluo[end_token] = 'L-%s' % label
    # Now distinguish the O cases from ones where we miss the tokenization
    entity_chars = set()
    for start_char, end_char, label in entities:
        entity_chars.update(range(start_char, end_char))
    for token in doc:
        token_chars = range(token.idx, token.idx + len(token))
        if not any(i in entity_chars for i in token_chars):
            biluo[token.i] = missing
    return biluo
|
|
|
|
|
|
|
|
|
2017-11-26 18:38:01 +03:00
|
|
|
def offsets_from_biluo_tags(doc, tags):
    """Decode per-token BILUO tags into character-offset entity spans.

    doc (Doc): The document that the BILUO tags refer to.
    tags (iterable): A sequence of BILUO tags, one per token. Each tag is
        "-", "O" or "{action}-{label}", where action is one of "B", "I",
        "L", "U"; None entries are skipped.
    RETURNS (list): A sequence of `(start, end, label)` triples. `start` and
        `end` are character-offset integers denoting the slice into the
        original string.
    """
    entity_spans = (
        (label, doc[first : last + 1])
        for label, first, last in tags_to_entities(tags)
    )
    return [(span.start_char, span.end_char, label)
            for label, span in entity_spans]
|
|
|
|
|
|
|
|
|
2015-02-21 19:06:58 +03:00
|
|
|
def is_punct_label(label):
    """Return True if `label` marks punctuation.

    That is, the label is exactly "P", or equals "punct" in any letter case.
    """
    if label == 'P':
        return True
    return label.lower() == 'punct'
|