spaCy/spacy/gold.pyx

520 lines
19 KiB
Cython
Raw Normal View History

2017-03-15 17:29:42 +03:00
# cython: profile=True
# coding: utf8
from __future__ import unicode_literals, print_function
import re
import ujson
import random
2017-05-26 01:15:09 +03:00
import cytoolz
import itertools
import numpy
from . import _align
from .syntax import nonproj
from .tokens import Doc
2017-10-27 22:07:59 +03:00
from . import util
from .util import minibatch, itershuffle
2015-02-21 19:06:58 +03:00
def tags_to_entities(tags):
    """Convert a sequence of BILUO tags into token-level entity spans.

    tags (list): One tag per token. `None`, "" and "-" are treated as missing
        annotation and skipped. Tags starting with "O" reset any open entity.
        "U-", "B-", "I-" and "L-" prefixed tags follow the BILUO scheme; the
        label is everything after the two-character prefix.
    RETURNS (list): `(label, start, end)` tuples, where `start` and `end` are
        inclusive token indices.
    """
    entities = []
    start = None
    for i, tag in enumerate(tags):
        # Missing values carry no span information. The empty string is
        # accepted alongside None and '-' so that it does not fall through
        # to the "unknown tag" error below.
        if not tag or tag == '-':
            continue
        if tag.startswith('O'):
            # TODO: We shouldn't be getting these malformed inputs. Fix this.
            start = None
        elif tag.startswith('I'):
            assert start is not None, tags[:i]
        elif tag.startswith('U'):
            entities.append((tag[2:], i, i))
        elif tag.startswith('B'):
            start = i
        elif tag.startswith('L'):
            entities.append((tag[2:], start, i))
            start = None
        else:
            raise ValueError("Unexpected entity tag: %r" % (tag,))
    return entities
def merge_sents(sents):
    """Concatenate per-sentence annotations into one merged sentence.

    sents (iterable): `((ids, words, tags, heads, labels, ner), brackets)`
        pairs, one per sentence. Each bracket is a dict with the keys
        'first', 'last' and 'label'.
    RETURNS (list): A one-element list holding `(merged_annotations,
        merged_brackets)`. Token ids, heads and bracket boundaries are shifted
        by the number of tokens in the preceding sentences.
    """
    all_ids, all_words, all_tags = [], [], []
    all_heads, all_labels, all_ner = [], [], []
    all_brackets = []
    offset = 0
    for annotations, brackets in sents:
        ids, words, tags, heads, labels, ner = annotations
        all_ids += [id_ + offset for id_ in ids]
        all_words += words
        all_tags += tags
        all_heads += [head + offset for head in heads]
        all_labels += labels
        all_ner += ner
        for bracket in brackets:
            shifted = (bracket['first'] + offset, bracket['last'] + offset,
                       bracket['label'])
            all_brackets.append(shifted)
        offset += len(ids)
    merged = [all_ids, all_words, all_tags, all_heads, all_labels, all_ner]
    return [(merged, all_brackets)]
# Pre-compiled pattern matching any single non-word character.
punct_re = re.compile(r'\W')
def _demote_isolated(multi, one_to_one):
    """Move multi-alignment entries with no matching neighbour into the
    one-to-one alignment. Both arguments are modified in place."""
    # Iterate over a snapshot, because entries are removed while looping.
    for src, tgt in list(multi.items()):
        has_neighbour = (multi.get(src + 1) == tgt
                         or multi.get(src - 1) == tgt)
        if not has_neighbour:
            one_to_one[src] = tgt
            multi.pop(src)


def align(cand_words, gold_words):
    """Align candidate (predicted) tokens with gold-standard tokens.

    cand_words (list): Token strings from the candidate tokenization.
    gold_words (list): Token strings from the gold-standard tokenization.
    RETURNS (tuple): `(cost, i2j, j2i, i2j_multi, j2i_multi)`. When the two
        sequences are equal, the cost is 0, both alignments are the same
        `numpy.arange` array and both multi-alignments are empty dicts.
        Otherwise spaces are removed from every token and the values are
        computed by `_align.align` and `_align.multi_align`.
    """
    if cand_words == gold_words:
        identity = numpy.arange(len(cand_words))
        return 0, identity, identity, {}, {}
    cand_words = [word.replace(' ', '') for word in cand_words]
    gold_words = [word.replace(' ', '') for word in gold_words]
    cost, i2j, j2i, matrix = _align.align(cand_words, gold_words)
    cand_lengths = [len(word) for word in cand_words]
    gold_lengths = [len(word) for word in gold_words]
    i2j_multi, j2i_multi = _align.multi_align(i2j, j2i, cand_lengths,
                                              gold_lengths)
    # A multi-alignment entry whose neighbours map somewhere else is really
    # a plain one-to-one alignment, so record it as such.
    _demote_isolated(i2j_multi, i2j)
    _demote_isolated(j2i_multi, j2i)
    return cost, i2j, j2i, i2j_multi, j2i_multi
2015-05-27 20:13:11 +03:00
class GoldCorpus(object):
    """An annotated corpus, using the JSON file format. Manages
    annotations for tagging, dependency parsing and NER."""
    def __init__(self, train_path, dev_path, gold_preproc=True, limit=None):
        """Create a GoldCorpus.

        train_path (unicode or Path): File or directory of training data.
        dev_path (unicode or Path): File or directory of development data.
        gold_preproc (bool): Accepted for backwards compatibility; not used
            by the constructor.
        limit (int): Stop reading a data set once the running total of
            `len(item[1])` over the yielded items reaches this value. A
            falsy value disables the limit.
        RETURNS (GoldCorpus): The newly created object.
        """
        self.train_path = util.ensure_path(train_path)
        self.dev_path = util.ensure_path(dev_path)
        self.limit = limit
        self.train_locs = self.walk_corpus(self.train_path)
        self.dev_locs = self.walk_corpus(self.dev_path)

    def _iter_tuples(self, locs):
        """Yield the items of every file in `locs`, honouring `self.limit`."""
        seen = 0
        for loc in locs:
            for item in read_json_file(loc):
                yield item
                seen += len(item[1])
                if self.limit and seen >= self.limit:
                    # Stop the whole generator. A `break` here would only
                    # leave the current file and keep reading the next one.
                    return

    @property
    def train_tuples(self):
        """Generator over the training items, subject to `self.limit`."""
        return self._iter_tuples(self.train_locs)

    @property
    def dev_tuples(self):
        """Generator over the development items, subject to `self.limit`."""
        return self._iter_tuples(self.dev_locs)

    def count_train(self):
        """Count the words in the training data.

        RETURNS (int): The sum of the word-list lengths of the sentences read
            from `train_tuples`.
        """
        n = 0
        i = 0
        for raw_text, paragraph_tuples in self.train_tuples:
            n += sum(len(s[0][1]) for s in paragraph_tuples)
            if self.limit and i >= self.limit:
                break
            i += len(paragraph_tuples)
        return n

    def train_docs(self, nlp, gold_preproc=False,
                   projectivize=False, max_length=None,
                   noise_level=0.0):
        """Yield shuffled `(doc, gold)` pairs for the training data.

        nlp: The pipeline object; passed on to `iter_gold_docs`.
        gold_preproc (bool): Passed on to `iter_gold_docs`.
        projectivize (bool): Run the training items through
            `nonproj.preprocess_training_data` first.
        max_length (int): Passed on to `iter_gold_docs`.
        noise_level (float): Passed on to `iter_gold_docs`.
        YIELDS (tuple): `(doc, gold)` pairs, shuffled by `itershuffle`.
        """
        # Always bind train_tuples: previously it was only assigned when
        # `projectivize` was set, so the default call raised
        # UnboundLocalError.
        train_tuples = self.train_tuples
        if projectivize:
            train_tuples = nonproj.preprocess_training_data(
                train_tuples, label_freq_cutoff=30)
        random.shuffle(self.train_locs)
        gold_docs = self.iter_gold_docs(nlp, train_tuples, gold_preproc,
                                        max_length=max_length,
                                        noise_level=noise_level)
        yield from itershuffle(gold_docs, bufsize=100)

    def dev_docs(self, nlp, gold_preproc=False):
        """Yield `(doc, gold)` pairs for the development data, in order."""
        gold_docs = self.iter_gold_docs(nlp, self.dev_tuples, gold_preproc)
        yield from gold_docs

    @classmethod
    def iter_gold_docs(cls, nlp, tuples, gold_preproc, max_length=None,
                       noise_level=0.0):
        """Turn `(raw_text, paragraph_tuples)` items into `(doc, gold)` pairs.

        With `gold_preproc` the raw text is discarded and one doc is built per
        sentence; otherwise the sentences are merged into one. Docs whose
        length is not below `max_length` are skipped when `max_length` is set.
        """
        for raw_text, paragraph_tuples in tuples:
            if gold_preproc:
                raw_text = None
            else:
                paragraph_tuples = merge_sents(paragraph_tuples)
            docs = cls._make_docs(nlp, raw_text, paragraph_tuples,
                                  gold_preproc, noise_level=noise_level)
            golds = cls._make_golds(docs, paragraph_tuples)
            for doc, gold in zip(docs, golds):
                if (not max_length) or len(doc) < max_length:
                    yield doc, gold

    @classmethod
    def _make_docs(cls, nlp, raw_text, paragraph_tuples, gold_preproc,
                   noise_level=0.0):
        """Build the docs for one paragraph: a single doc from the (noised)
        raw text when it is available, otherwise one doc per sentence from
        the (noised) word lists."""
        if raw_text is not None:
            raw_text = add_noise(raw_text, noise_level)
            return [nlp.make_doc(raw_text)]
        return [Doc(nlp.vocab,
                    words=add_noise(sent_tuples[1], noise_level))
                for (sent_tuples, brackets) in paragraph_tuples]

    @classmethod
    def _make_golds(cls, docs, paragraph_tuples):
        """Build one GoldParse per doc from the matching sentence tuples."""
        assert len(docs) == len(paragraph_tuples)
        if len(docs) == 1:
            return [GoldParse.from_annot_tuples(docs[0],
                                                paragraph_tuples[0][0])]
        return [GoldParse.from_annot_tuples(doc, sent_tuples)
                for doc, (sent_tuples, brackets)
                in zip(docs, paragraph_tuples)]

    @staticmethod
    def walk_corpus(path):
        """Collect the `.json` files below `path`.

        path (Path): A file or directory. A non-directory is returned as-is
            in a one-element list.
        RETURNS (list): Paths of all `.json` files found under the directory.
            Entries whose name starts with '.' are skipped.
        """
        if not path.is_dir():
            return [path]
        # `paths` doubles as a work queue: directories append their children
        # while the list is being iterated.
        paths = [path]
        locs = []
        seen = set()
        for path in paths:
            if str(path) in seen:
                continue
            seen.add(str(path))
            if path.parts[-1].startswith('.'):
                continue
            elif path.is_dir():
                paths.extend(path.iterdir())
            elif path.parts[-1].endswith('.json'):
                locs.append(path)
        return locs
2017-06-05 04:16:57 +03:00
def add_noise(orig, noise_level):
    """Randomly corrupt a text or a list of words.

    orig (unicode or list): The raw text, or a list of word strings.
    noise_level (float): Threshold compared against `random.random()`. The
        input is returned unchanged whenever the random draw is greater than
        or equal to it.
    RETURNS (unicode or list): The input itself, or a corrupted copy. For a
        list, words that end up empty are dropped; for a string, every
        character is passed through `_corrupt` and the results are joined.
    """
    untouched = random.random() >= noise_level
    if untouched:
        return orig
    if type(orig) == list:
        noisy_words = (_corrupt(word, noise_level) for word in orig)
        return [word for word in noisy_words if word]
    return ''.join(_corrupt(char, noise_level) for char in orig)
def _corrupt(c, noise_level):
if random.random() >= noise_level:
return c
elif c == ' ':
return '\n'
elif c == '\n':
return ' '
elif c in ['.', "'", "!", "?"]:
return ''
else:
return c.lower()
def read_json_file(loc, docs_filter=None, limit=None):
    """Read annotated paragraphs from a JSON file, or from every entry of a
    directory.

    loc (unicode or Path): File or directory to read.
    docs_filter (callable): Optional predicate; documents for which it
        returns a falsy value are skipped.
    limit (int): Optional maximum number of documents to use per file.
    YIELDS (list): `[raw_text, sents]` per paragraph that has at least one
        sentence. `raw_text` is the paragraph's 'raw' value or None. Each
        entry of `sents` is `[[ids, words, tags, heads, labels, ner],
        brackets]`. The 'head' values in the file are offsets relative to the
        token; they are converted to indices within the sentence.
    """
    loc = util.ensure_path(loc)
    if loc.is_dir():
        # iterdir() already yields paths that include `loc`, so they must not
        # be joined onto `loc` again (that doubles a relative prefix). The
        # filter is forwarded so it also applies to nested files.
        for child in loc.iterdir():
            yield from read_json_file(child, docs_filter=docs_filter,
                                      limit=limit)
        return
    with loc.open('r', encoding='utf8') as file_:
        docs = ujson.load(file_)
    if limit is not None:
        docs = docs[:limit]
    for doc in docs:
        if docs_filter is not None and not docs_filter(doc):
            continue
        for paragraph in doc['paragraphs']:
            sents = []
            for sent in paragraph['sentences']:
                ids, words, tags = [], [], []
                heads, labels, ner = [], [], []
                for i, token in enumerate(sent['tokens']):
                    words.append(token['orth'])
                    ids.append(i)
                    tags.append(token.get('tag', '-'))
                    heads.append(token.get('head', 0) + i)
                    dep = token.get('dep', '')
                    # Ensure ROOT label is case-insensitive
                    labels.append('ROOT' if dep.lower() == 'root' else dep)
                    ner.append(token.get('ner', '-'))
                sents.append([
                    [ids, words, tags, heads, labels, ner],
                    sent.get('brackets', [])])
            if sents:
                yield [paragraph.get('raw', None), sents]
2015-05-06 17:27:31 +03:00
2017-05-26 19:32:55 +03:00
def iob_to_biluo(tags):
    """Convert a sequence of IOB tags into BILUO tags.

    tags (iterable): The IOB tags. The input is copied, not modified.
    RETURNS (list): The converted tags, built by alternately consuming runs
        of 'O' tags and single entities from the front of the copy.
    """
    remaining = list(tags)
    biluo = []
    while remaining:
        biluo += _consume_os(remaining)
        biluo += _consume_ent(remaining)
    return biluo
def _consume_os(tags):
while tags and tags[0] == 'O':
yield tags.pop(0)
def _consume_ent(tags):
if not tags:
return []
target = tags.pop(0).replace('B', 'I')
length = 1
while tags and tags[0] == target:
length += 1
tags.pop(0)
label = target[2:]
if length == 1:
return ['U-' + label]
else:
start = 'B-' + label
end = 'L-' + label
middle = ['I-%s' % label for _ in range(1, length - 1)]
return [start] + middle + [end]
2015-03-09 08:46:22 +03:00
cdef class GoldParse:
    """Collection for training annotations."""
    @classmethod
    def from_annot_tuples(cls, doc, annot_tuples, make_projective=False):
        """Create a GoldParse from a six-element annotation tuple.

        doc (Doc): The document the annotations refer to.
        annot_tuples (tuple): `(ids, words, tags, heads, deps, entities)`.
            The first element is ignored.
        make_projective (bool): Passed through to the constructor.
        RETURNS (GoldParse): The newly constructed object.
        """
        _, words, tags, heads, deps, entities = annot_tuples
        return cls(doc, words=words, tags=tags, heads=heads, deps=deps,
                   entities=entities, make_projective=make_projective)

    def __init__(self, doc, annot_tuples=None, words=None, tags=None,
                 heads=None, deps=None, entities=None, make_projective=False,
                 cats=None):
        """Create a GoldParse.

        doc (Doc): The document the annotations refer to.
        annot_tuples: Not read; the value is rebuilt from the other
            arguments.
        words (iterable): A sequence of unicode word strings.
        tags (iterable): A sequence of strings, representing tag annotations.
        heads (iterable): A sequence of integers, representing syntactic
            head offsets.
        deps (iterable): A sequence of strings, representing the syntactic
            relation types.
        entities (iterable): A sequence of named entity annotations, either as
            BILUO tag strings, or as `(start_char, end_char, label)` tuples,
            representing the entity positions.
        make_projective (bool): Replace the heads with the result of
            `nonproj.projectivize` after alignment.
        cats (dict): Labels for text classification. Each key in the dictionary
            may be a string or an int, or a `(start_char, end_char, label)`
            tuple, indicating that the label is applied to only part of the
            document (usually a sentence). Unlike entity annotations, label
            annotations can overlap, i.e. a single word can be covered by
            multiple labelled spans. The TextCategorizer component expects
            true examples of a label to have the value 1.0, and negative
            examples of a label to have the value 0.0. Labels not in the
            dictionary are treated as missing - the gradient for those labels
            will be zero.
        RETURNS (GoldParse): The newly constructed object.
        """
        if words is None:
            words = [token.text for token in doc]
        if tags is None:
            tags = [None for _ in doc]
        if heads is None:
            heads = [None for token in doc]
        if deps is None:
            deps = [None for _ in doc]
        if entities is None:
            entities = [None for _ in doc]
        elif len(entities) == 0:
            entities = ['O' for _ in doc]
        elif not isinstance(entities[0], basestring):
            # Assume we have entities specified by character offset.
            entities = biluo_tags_from_offsets(doc, entities)

        self.mem = Pool()
        self.loss = 0
        self.length = len(doc)

        # These are filled by the tagger/parser/entity recogniser
        self.c.tags = <int*>self.mem.alloc(len(doc), sizeof(int))
        self.c.heads = <int*>self.mem.alloc(len(doc), sizeof(int))
        self.c.labels = <attr_t*>self.mem.alloc(len(doc), sizeof(attr_t))
        self.c.has_dep = <int*>self.mem.alloc(len(doc), sizeof(int))
        self.c.sent_start = <int*>self.mem.alloc(len(doc), sizeof(int))
        self.c.ner = <Transition*>self.mem.alloc(len(doc), sizeof(Transition))

        self.cats = {} if cats is None else dict(cats)
        self.words = [None] * len(doc)
        self.tags = [None] * len(doc)
        self.heads = [None] * len(doc)
        self.labels = [None] * len(doc)
        self.ner = [None] * len(doc)

        # Do many-to-one alignment for misaligned tokens.
        # If we over-segment, we'll have one gold word that covers a sequence
        # of predicted words
        # If we under-segment, we'll have one predicted word that covers a
        # sequence of gold words.
        # If we "mis-segment", we'll have a sequence of predicted words covering
        # a sequence of gold words. That's many-to-many -- we don't do that.
        cost, i2j, j2i, i2j_multi, j2i_multi = align([t.orth_ for t in doc], words)

        self.cand_to_gold = [(j if j >= 0 else None) for j in i2j]
        self.gold_to_cand = [(i if i >= 0 else None) for i in j2i]

        annot_tuples = (range(len(words)), words, tags, heads, deps, entities)
        self.orig_annot = list(zip(*annot_tuples))

        for i, gold_i in enumerate(self.cand_to_gold):
            if doc[i].text.isspace():
                self.words[i] = doc[i].text
                self.tags[i] = '_SP'
                self.heads[i] = None
                self.labels[i] = None
                self.ner[i] = 'O'
            if gold_i is None:
                if i in i2j_multi:
                    multi_j = i2j_multi[i]
                    self.words[i] = words[multi_j]
                    self.tags[i] = tags[multi_j]
                    # Set next word in multi-token span as head, until last
                    if multi_j == i2j_multi.get(i+1):
                        self.heads[i] = i+1
                        self.labels[i] = 'subtok'
                    else:
                        # A missing gold head cannot be used as an index;
                        # keep it missing, as in the one-to-one case below.
                        gold_head = heads[multi_j]
                        if gold_head is None:
                            self.heads[i] = None
                        else:
                            self.heads[i] = self.gold_to_cand[gold_head]
                        self.labels[i] = deps[multi_j]
                    # TODO: Set NER!
            else:
                self.words[i] = words[gold_i]
                self.tags[i] = tags[gold_i]
                if heads[gold_i] is None:
                    self.heads[i] = None
                else:
                    self.heads[i] = self.gold_to_cand[heads[gold_i]]
                self.labels[i] = deps[gold_i]
                self.ner[i] = entities[gold_i]

        cycle = nonproj.contains_cycle(self.heads)
        if cycle is not None:
            raise Exception("Cycle found: %s" % cycle)

        if make_projective:
            proj_heads, _ = nonproj.projectivize(self.heads, self.labels)
            self.heads = proj_heads

    def __len__(self):
        """Get the number of gold-standard tokens.

        RETURNS (int): The number of gold-standard tokens.
        """
        return self.length

    @property
    def is_projective(self):
        """Whether the provided syntactic annotations form a projective
        dependency tree.
        """
        return not nonproj.is_nonproj_tree(self.heads)

    @property
    def sent_starts(self):
        """The per-token values of the `sent_start` array, as a list."""
        return [self.c.sent_start[i] for i in range(self.length)]
2015-02-21 19:06:58 +03:00
def biluo_tags_from_offsets(doc, entities, missing='O'):
    """Encode labelled spans into per-token tags, using the
    Begin/In/Last/Unit/Out scheme (BILUO).

    doc (Doc): The document that the entity offsets refer to. The output tags
        will refer to the token boundaries within the document.
    entities (iterable): A sequence of `(start, end, label)` triples. `start`
        and `end` should be character-offset integers denoting the slice into
        the original string.
    missing (unicode): The tag given to tokens that share no character with
        any entity span. Defaults to "O".
    RETURNS (list): A list of unicode strings, describing the tags. Each tag
        string will be of the form either "-", `missing` or
        "{action}-{label}", where action is one of "B", "I", "L", "U". The
        string "-" is used where the entity offsets don't align with the
        tokenization in the `Doc` object. The training algorithm will view
        these as missing values. "O" denotes a non-entity token. "B" denotes
        the beginning of a multi-token entity, "I" the inside of an entity of
        three or more tokens, and "L" the end of an entity of two or more
        tokens. "U" denotes a single-token entity.

    EXAMPLE:
        >>> text = 'I like London.'
        >>> entities = [(len('I like '), len('I like London'), 'LOC')]
        >>> doc = nlp.tokenizer(text)
        >>> tags = biluo_tags_from_offsets(doc, entities)
        >>> assert tags == ['O', 'O', 'U-LOC', 'O']
    """
    # The spans are walked twice below, so a one-shot iterable has to be
    # materialised first; otherwise the second pass would see no entities
    # and every tag would be overwritten with `missing`.
    entities = list(entities)
    starts = {token.idx: token.i for token in doc}
    ends = {token.idx+len(token): token.i for token in doc}
    biluo = ['-' for _ in doc]
    # Handle entity cases
    for start_char, end_char, label in entities:
        start_token = starts.get(start_char)
        end_token = ends.get(end_char)
        # Only interested if the tokenization is correct
        if start_token is not None and end_token is not None:
            if start_token == end_token:
                biluo[start_token] = 'U-%s' % label
            else:
                biluo[start_token] = 'B-%s' % label
                for i in range(start_token+1, end_token):
                    biluo[i] = 'I-%s' % label
                biluo[end_token] = 'L-%s' % label
    # Now distinguish the O cases from ones where we miss the tokenization
    entity_chars = set()
    for start_char, end_char, label in entities:
        entity_chars.update(range(start_char, end_char))
    for token in doc:
        token_chars = range(token.idx, token.idx+len(token))
        if not any(i in entity_chars for i in token_chars):
            biluo[token.i] = missing
    return biluo
def offsets_from_biluo_tags(doc, tags):
    """Encode per-token tags following the BILUO scheme into entity offsets.

    doc (Doc): The document that the BILUO tags refer to.
    tags (iterable): A sequence of BILUO tags with each tag describing one
        token. Each tag string will be of the form of either "", "O" or
        "{action}-{label}", where action is one of "B", "I", "L", "U".
    RETURNS (list): A sequence of `(start, end, label)` triples. `start` and
        `end` will be character-offset integers denoting the slice into the
        original string.
    """
    def char_span(label, first, last):
        # Token indices from tags_to_entities are inclusive on both ends.
        span = doc[first : last + 1]
        return (span.start_char, span.end_char, label)

    return [char_span(label, first, last)
            for label, first, last in tags_to_entities(tags)]
2015-02-21 19:06:58 +03:00
def is_punct_label(label):
    """Check whether a dependency label marks punctuation.

    label (unicode): The label to check.
    RETURNS (bool): True for exactly 'P', or for 'punct' in any casing.
    """
    if label == 'P':
        return True
    return label.lower() == 'punct'