Merge branch 'develop' into spacy.io

Ines Montani 2019-03-11 01:38:54 +01:00
commit 6768dfd6e7
8 changed files with 121 additions and 55 deletions

View File

@@ -97,6 +97,7 @@ def with_cpu(ops, model):
     """Wrap a model that should run on CPU, transferring inputs and outputs
     as necessary."""
+    model.to_cpu()
 
     def with_cpu_forward(inputs, drop=0.):
         cpu_outputs, backprop = model.begin_update(_to_cpu(inputs), drop=drop)
         gpu_outputs = _to_device(ops, cpu_outputs)
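For context, the wrapper above depends on two transfer helpers. A minimal sketch of that pattern, not part of this diff: the real `_to_cpu`/`_to_device` live alongside `with_cpu` in spaCy's `_ml` module and may differ in detail.

```python
# Sketch of the device-transfer pattern with_cpu relies on: move inputs to
# CPU on the way in, move outputs back to the wrapping ops' device on the
# way out. Details here are assumptions, not spaCy's exact code.
import numpy

def _to_cpu(X):
    if isinstance(X, numpy.ndarray):
        return X
    elif isinstance(X, tuple):
        return tuple(_to_cpu(x) for x in X)
    elif isinstance(X, list):
        return [_to_cpu(x) for x in X]
    elif hasattr(X, "get"):  # e.g. a cupy array living on the GPU
        return X.get()
    return X

def _to_device(ops, X):
    if isinstance(X, tuple):
        return tuple(_to_device(ops, x) for x in X)
    elif isinstance(X, list):
        return [_to_device(ops, x) for x in X]
    return ops.asarray(X)
```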

View File

@@ -16,7 +16,7 @@ TAG_MAP = {
     ":": {POS: PUNCT},
     "$": {POS: SYM, "Other": {"SymType": "currency"}},
     "#": {POS: SYM, "Other": {"SymType": "numbersign"}},
-    "AFX": {POS: ADJ, "Hyph": "yes"},
+    "AFX": {POS: X, "Hyph": "yes"},
     "CC": {POS: CCONJ, "ConjType": "coor"},
     "CD": {POS: NUM, "NumType": "card"},
     "DT": {POS: DET},
@@ -34,10 +34,10 @@ TAG_MAP = {
     "NNP": {POS: PROPN, "NounType": "prop", "Number": "sing"},
     "NNPS": {POS: PROPN, "NounType": "prop", "Number": "plur"},
     "NNS": {POS: NOUN, "Number": "plur"},
-    "PDT": {POS: ADJ, "AdjType": "pdt", "PronType": "prn"},
+    "PDT": {POS: DET, "AdjType": "pdt", "PronType": "prn"},
     "POS": {POS: PART, "Poss": "yes"},
     "PRP": {POS: PRON, "PronType": "prs"},
-    "PRP$": {POS: ADJ, "PronType": "prs", "Poss": "yes"},
+    "PRP$": {POS: DET, "PronType": "prs", "Poss": "yes"},
     "RB": {POS: ADV, "Degree": "pos"},
     "RBR": {POS: ADV, "Degree": "comp"},
     "RBS": {POS: ADV, "Degree": "sup"},
@@ -58,9 +58,9 @@ TAG_MAP = {
         "Number": "sing",
         "Person": 3,
     },
-    "WDT": {POS: ADJ, "PronType": "int|rel"},
-    "WP": {POS: NOUN, "PronType": "int|rel"},
-    "WP$": {POS: ADJ, "Poss": "yes", "PronType": "int|rel"},
+    "WDT": {POS: DET, "PronType": "int|rel"},
+    "WP": {POS: PRON, "PronType": "int|rel"},
+    "WP$": {POS: DET, "Poss": "yes", "PronType": "int|rel"},
     "WRB": {POS: ADV, "PronType": "int|rel"},
     "ADD": {POS: X},
     "NFP": {POS: PUNCT},

View File

@@ -106,6 +106,7 @@ class Language(object):
 
+    DOCS: https://spacy.io/api/language
     """
 
     Defaults = BaseDefaults
     lang = None
@@ -344,13 +345,15 @@ class Language(object):
             raise ValueError(Errors.E001.format(name=name, opts=self.pipe_names))
         return self.pipeline.pop(self.pipe_names.index(name))
 
-    def __call__(self, text, disable=[]):
+    def __call__(self, text, disable=[], component_cfg=None):
         """Apply the pipeline to some text. The text can span multiple sentences,
         and can contain arbitrary whitespace. Alignment into the original string
         is preserved.
 
         text (unicode): The text to be processed.
         disable (list): Names of the pipeline components to disable.
+        component_cfg (dict): An optional dictionary with extra keyword arguments
+            for specific components.
         RETURNS (Doc): A container for accessing the annotations.
 
         EXAMPLE:
@@ -363,12 +366,14 @@ class Language(object):
                 Errors.E088.format(length=len(text), max_length=self.max_length)
             )
         doc = self.make_doc(text)
+        if component_cfg is None:
+            component_cfg = {}
         for name, proc in self.pipeline:
             if name in disable:
                 continue
             if not hasattr(proc, "__call__"):
                 raise ValueError(Errors.E003.format(component=type(proc), name=name))
-            doc = proc(doc)
+            doc = proc(doc, **component_cfg.get(name, {}))
             if doc is None:
                 raise ValueError(Errors.E005.format(name=name))
         return doc
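A usage sketch for the new argument. The component and its "threshold" keyword are hypothetical; any pipeline component that accepts extra call-time keyword arguments is addressed the same way, by its pipe name:

```python
# Sketch: route call-time settings to one component via component_cfg.
# "my_component" and its "threshold" kwarg are invented for illustration.
from spacy.lang.en import English

def my_component(doc, threshold=0.5):
    # ... use `threshold` to decide what to annotate on the doc ...
    return doc

nlp = English()
nlp.add_pipe(my_component, name="my_component")
doc = nlp("This is a sentence.", component_cfg={"my_component": {"threshold": 0.75}})
```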
@@ -396,7 +401,7 @@ class Language(object):
     def make_doc(self, text):
         return self.tokenizer(text)
 
-    def update(self, docs, golds, drop=0.0, sgd=None, losses=None):
+    def update(self, docs, golds, drop=0.0, sgd=None, losses=None, component_cfg=None):
         """Update the models in the pipeline.
 
         docs (iterable): A batch of `Doc` objects.
@@ -443,11 +448,15 @@ class Language(object):
 
         pipes = list(self.pipeline)
         random.shuffle(pipes)
+        if component_cfg is None:
+            component_cfg = {}
         for name, proc in pipes:
             if not hasattr(proc, "update"):
                 continue
             grads = {}
-            proc.update(docs, golds, drop=drop, sgd=get_grads, losses=losses)
+            kwargs = component_cfg.get(name, {})
+            kwargs.setdefault("drop", drop)
+            proc.update(docs, golds, sgd=get_grads, losses=losses, **kwargs)
             for key, (W, dW) in grads.items():
                 sgd(W, dW, key=key)
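A usage sketch: because the merged kwargs use `setdefault`, a per-component setting overrides the top-level default for that component only. The dropout rates are arbitrary, and `nlp`, `train_docs` and `train_golds` are assumed to exist:

```python
# Sketch: the top-level drop of 0.35 applies to every trainable pipe
# except the tagger, which overrides it via component_cfg.
# Assumes `nlp` has trainable pipes and a prepared training batch.
optimizer = nlp.begin_training()
losses = {}
nlp.update(
    train_docs,
    train_golds,
    drop=0.35,
    sgd=optimizer,
    losses=losses,
    component_cfg={"tagger": {"drop": 0.2}},
)
```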
@@ -517,11 +526,12 @@ class Language(object):
         for doc, gold in docs_golds:
             yield doc, gold
 
-    def begin_training(self, get_gold_tuples=None, sgd=None, **cfg):
+    def begin_training(self, get_gold_tuples=None, sgd=None, component_cfg=None, **cfg):
         """Allocate models, pre-process training data and acquire a trainer and
         optimizer. Used as a contextmanager.
 
         get_gold_tuples (function): Function returning gold data
+        component_cfg (dict): Config parameters for specific components.
         **cfg: Config parameters.
         RETURNS: An optimizer
         """
@@ -543,10 +553,17 @@ class Language(object):
         if sgd is None:
             sgd = create_default_optimizer(Model.ops)
         self._optimizer = sgd
+        if component_cfg is None:
+            component_cfg = {}
         for name, proc in self.pipeline:
             if hasattr(proc, "begin_training"):
+                kwargs = component_cfg.get(name, {})
+                kwargs.update(cfg)
                 proc.begin_training(
-                    get_gold_tuples, pipeline=self.pipeline, sgd=self._optimizer, **cfg
+                    get_gold_tuples,
+                    pipeline=self.pipeline,
+                    sgd=self._optimizer,
+                    **kwargs
                 )
         return self._optimizer
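A usage sketch. Note the merge order in the hunk above: `kwargs.update(cfg)` runs after the per-component dict is read, so shared `**cfg` values win over `component_cfg` on key collisions. The parser setting shown is illustrative, not a recommendation:

```python
# Sketch: pass component-specific settings at initialization. The
# "min_action_freq" value is illustrative; any kwarg the parser's
# begin_training accepts could go here.
optimizer = nlp.begin_training(
    get_gold_tuples,
    component_cfg={"parser": {"min_action_freq": 30}},
)
```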
@@ -574,20 +591,27 @@ class Language(object):
                 proc._rehearsal_model = deepcopy(proc.model)
         return self._optimizer
 
-    def evaluate(self, docs_golds, verbose=False, batch_size=256):
-        scorer = Scorer()
+    def evaluate(
+        self, docs_golds, verbose=False, batch_size=256, scorer=None, component_cfg=None
+    ):
+        if scorer is None:
+            scorer = Scorer()
         docs, golds = zip(*docs_golds)
         docs = list(docs)
         golds = list(golds)
         for name, pipe in self.pipeline:
+            kwargs = component_cfg.get(name, {})
+            kwargs.setdefault("batch_size", batch_size)
             if not hasattr(pipe, "pipe"):
-                docs = (pipe(doc) for doc in docs)
+                docs = (pipe(doc, **kwargs) for doc in docs)
             else:
-                docs = pipe.pipe(docs, batch_size=batch_size)
+                docs = pipe.pipe(docs, **kwargs)
         for doc, gold in zip(docs, golds):
             if verbose:
                 print(doc)
-            scorer.score(doc, gold, verbose=verbose)
+            kwargs = component_cfg.get("scorer", {})
+            kwargs.setdefault("verbose", verbose)
+            scorer.score(doc, gold, **kwargs)
         return scorer
 
     @contextmanager
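A usage sketch. Since the loop above reads `component_cfg.get(...)` without a visible `None` guard in this hunk, the sketch passes the dict explicitly; the parser batch size is illustrative, and `docs_golds` is a sequence of (Doc, GoldParse) pairs as in the method itself:

```python
# Sketch: evaluate with an explicit scorer and a per-component batch size.
# Passing component_cfg as a dict (even if empty) matches what the hunk
# above expects.
from spacy.scorer import Scorer

scorer = nlp.evaluate(
    docs_golds,
    scorer=Scorer(),
    component_cfg={"parser": {"batch_size": 128}},
)
print(scorer.scores)
```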
@@ -630,6 +654,7 @@ class Language(object):
             batch_size=1000,
             disable=[],
             cleanup=False,
+            component_cfg=None,
         ):
         """Process texts as a stream, and yield `Doc` objects in order.
"""Process texts as a stream, and yield `Doc` objects in order.
@@ -643,6 +668,8 @@ class Language(object):
         disable (list): Names of the pipeline components to disable.
         cleanup (bool): If True, unneeded strings are freed,
             to control memory use. Experimental.
+        component_cfg (dict): An optional dictionary with extra keyword arguments
+            for specific components.
         YIELDS (Doc): Documents in the order of the original text.
 
         EXAMPLE:
@@ -655,20 +682,30 @@ class Language(object):
             texts = (tc[0] for tc in text_context1)
             contexts = (tc[1] for tc in text_context2)
             docs = self.pipe(
-                texts, n_threads=n_threads, batch_size=batch_size, disable=disable
+                texts,
+                n_threads=n_threads,
+                batch_size=batch_size,
+                disable=disable,
+                component_cfg=component_cfg,
             )
             for doc, context in izip(docs, contexts):
                 yield (doc, context)
             return
         docs = (self.make_doc(text) for text in texts)
+        if component_cfg is None:
+            component_cfg = {}
         for name, proc in self.pipeline:
             if name in disable:
                 continue
+            kwargs = component_cfg.get(name, {})
+            # Allow component_cfg to overwrite the top-level kwargs.
+            kwargs.setdefault("batch_size", batch_size)
+            kwargs.setdefault("n_threads", n_threads)
             if hasattr(proc, "pipe"):
-                docs = proc.pipe(docs, n_threads=n_threads, batch_size=batch_size)
+                docs = proc.pipe(docs, **kwargs)
             else:
                 # Apply the function, but yield the doc
-                docs = _pipe(proc, docs)
+                docs = _pipe(proc, docs, kwargs)
             # Track weakrefs of "recent" documents, so that we can see when they
             # expire from memory. When they do, we know we don't need old strings.
             # This way, we avoid maintaining an unbounded growth in string entries
@@ -861,7 +898,7 @@ class DisabledPipes(list):
         self[:] = []
 
 
-def _pipe(func, docs):
+def _pipe(func, docs, kwargs):
     for doc in docs:
-        doc = func(doc)
+        doc = func(doc, **kwargs)
         yield doc
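A usage sketch for the streaming path: because of `setdefault`, a `component_cfg` entry beats the top-level `batch_size` for that one component. Pipe names and sizes here are illustrative, and `nlp` is assumed to have a parser:

```python
# Sketch: stream texts, giving the parser a smaller batch size than the
# rest of the pipeline (sizes are illustrative).
texts = ["First document.", "Second document.", "Third document."]
for doc in nlp.pipe(
    texts,
    batch_size=1000,
    component_cfg={"parser": {"batch_size": 32}},
):
    print(doc.is_parsed)
```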

View File

@@ -110,7 +110,8 @@ cdef class Morphology:
             analysis.lemma = self.lemmatize(analysis.tag.pos, token.lex.orth,
                 self.tag_map.get(tag_str, {}))
             self._cache.set(tag_id, token.lex.orth, analysis)
-        token.lemma = analysis.lemma
+        if token.lemma == 0:
+            token.lemma = analysis.lemma
         token.pos = analysis.tag.pos
         token.tag = analysis.tag.name
         token.morph = analysis.tag.morph
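The guard means tag assignment now only fills in a lemma when none is set, instead of overwriting one that is already there. A minimal sketch of the observable behavior, assuming a blank English pipeline (setting `tag_` routes through the morphology's tag assignment):

```python
# Sketch: a lemma set before the tag survives tag assignment, because
# assign_tag now only fills an empty lemma slot.
from spacy.lang.en import English

nlp = English()
doc = nlp.make_doc("running")
doc[0].lemma_ = "jog"   # deliberately not the lemmatizer's answer
doc[0].tag_ = "VBG"     # triggers morphology.assign_tag under the hood
assert doc[0].lemma_ == "jog"  # not clobbered by the tag's default lemma
```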

View File

@@ -2,6 +2,7 @@
 from __future__ import unicode_literals
 
 import pytest
+import numpy
 from spacy.tokens import Doc
 from spacy.displacy import render
 from spacy.gold import iob_to_biluo
@@ -12,12 +13,14 @@ from spacy.lang.en import English
 from ..util import add_vecs_to_vocab, get_doc
 
 
-@pytest.mark.xfail(
-    reason="The dot is now properly split off, but the prefix/suffix rules are not applied again afterwards."
-    "This means that the quote will still be attached to the remaining token."
-)
+@pytest.mark.xfail
 def test_issue2070():
-    """Test that checks that a dot followed by a quote is handled appropriately."""
+    """Test that checks that a dot followed by a quote is handled
+    appropriately.
+    """
+    # Problem: The dot is now properly split off, but the prefix/suffix rules
+    # are not applied again afterwards. This means that the quote will still be
+    # attached to the remaining token.
     nlp = English()
     doc = nlp('First sentence."A quoted sentence" he said ...')
     assert len(doc) == 11
@@ -37,6 +40,26 @@ def test_issue2179():
     assert nlp2.get_pipe("ner").labels == ("CITIZENSHIP",)
 
 
+def test_issue2203(en_vocab):
+    """Test that lemmas are set correctly in doc.from_array."""
+    words = ["I", "'ll", "survive"]
+    tags = ["PRP", "MD", "VB"]
+    lemmas = ["-PRON-", "will", "survive"]
+    tag_ids = [en_vocab.strings.add(tag) for tag in tags]
+    lemma_ids = [en_vocab.strings.add(lemma) for lemma in lemmas]
+    doc = Doc(en_vocab, words=words)
+    # Work around lemma corruption problem and set lemmas after tags
+    doc.from_array("TAG", numpy.array(tag_ids, dtype="uint64"))
+    doc.from_array("LEMMA", numpy.array(lemma_ids, dtype="uint64"))
+    assert [t.tag_ for t in doc] == tags
+    assert [t.lemma_ for t in doc] == lemmas
+    # We need to serialize both tag and lemma, since this is what causes the bug
+    doc_array = doc.to_array(["TAG", "LEMMA"])
+    new_doc = Doc(doc.vocab, words=words).from_array(["TAG", "LEMMA"], doc_array)
+    assert [t.tag_ for t in new_doc] == tags
+    assert [t.lemma_ for t in new_doc] == lemmas
+
+
 def test_issue2219(en_vocab):
     vectors = [("a", [1, 2, 3]), ("letter", [4, 5, 6])]
     add_vecs_to_vocab(en_vocab, vectors)

View File

@@ -763,17 +763,18 @@ cdef class Doc:
             attr_ids[i] = attr_id
         if len(array.shape) == 1:
             array = array.reshape((array.size, 1))
+        # Do TAG first. This lets subsequent loop override stuff like POS, LEMMA
+        if TAG in attrs:
+            col = attrs.index(TAG)
+            for i in range(length):
+                if array[i, col] != 0:
+                    self.vocab.morphology.assign_tag(&tokens[i], array[i, col])
         # Now load the data
         for i in range(self.length):
             token = &self.c[i]
             for j in range(n_attrs):
-                Token.set_struct_attr(token, attr_ids[j], array[i, j])
-        # Auxiliary loading logic
-        for col, attr_id in enumerate(attrs):
-            if attr_id == TAG:
-                for i in range(length):
-                    if array[i, col] != 0:
-                        self.vocab.morphology.assign_tag(&tokens[i], array[i, col])
+                if attr_ids[j] != TAG:
+                    Token.set_struct_attr(token, attr_ids[j], array[i, j])
         # Set flags
         self.is_parsed = bool(self.is_parsed or HEAD in attrs or DEP in attrs)
         self.is_tagged = bool(self.is_tagged or TAG in attrs or POS in attrs)

View File

@@ -91,13 +91,14 @@ multiprocessing.
 > assert doc.is_parsed
 > ```
 
-| Name         | Type  | Description |
-| ------------ | ----- | ----------- |
-| `texts`      | -     | A sequence of unicode objects. |
-| `as_tuples`  | bool  | If set to `True`, inputs should be a sequence of `(text, context)` tuples. Output will then be a sequence of `(doc, context)` tuples. Defaults to `False`. |
-| `batch_size` | int   | The number of texts to buffer. |
-| `disable`    | list  | Names of pipeline components to [disable](/usage/processing-pipelines#disabling). |
-| **YIELDS**   | `Doc` | Documents in the order of the original text. |
+| Name                                         | Type  | Description |
+| -------------------------------------------- | ----- | ----------- |
+| `texts`                                      | -     | A sequence of unicode objects. |
+| `as_tuples`                                  | bool  | If set to `True`, inputs should be a sequence of `(text, context)` tuples. Output will then be a sequence of `(doc, context)` tuples. Defaults to `False`. |
+| `batch_size`                                 | int   | The number of texts to buffer. |
+| `disable`                                    | list  | Names of pipeline components to [disable](/usage/processing-pipelines#disabling). |
+| `component_cfg` <Tag variant="new">2.1</Tag> | dict  | Config parameters for specific pipeline components, keyed by component name. |
+| **YIELDS**                                   | `Doc` | Documents in the order of the original text. |
 
 ## Language.update {#update tag="method"}
 
@@ -112,13 +113,14 @@ Update the models in the pipeline.
 > nlp.update([doc], [gold], drop=0.5, sgd=optimizer)
 > ```
 
-| Name        | Type     | Description |
-| ----------- | -------- | ----------- |
-| `docs`      | iterable | A batch of `Doc` objects or unicode. If unicode, a `Doc` object will be created from the text. |
-| `golds`     | iterable | A batch of `GoldParse` objects or dictionaries. Dictionaries will be used to create [`GoldParse`](/api/goldparse) objects. For the available keys and their usage, see [`GoldParse.__init__`](/api/goldparse#init). |
-| `drop`      | float    | The dropout rate. |
-| `sgd`       | callable | An optimizer. |
-| **RETURNS** | dict     | Results from the update. |
+| Name                                         | Type     | Description |
+| -------------------------------------------- | -------- | ----------- |
+| `docs`                                       | iterable | A batch of `Doc` objects or unicode. If unicode, a `Doc` object will be created from the text. |
+| `golds`                                      | iterable | A batch of `GoldParse` objects or dictionaries. Dictionaries will be used to create [`GoldParse`](/api/goldparse) objects. For the available keys and their usage, see [`GoldParse.__init__`](/api/goldparse#init). |
+| `drop`                                       | float    | The dropout rate. |
+| `sgd`                                        | callable | An optimizer. |
+| `component_cfg` <Tag variant="new">2.1</Tag> | dict     | Config parameters for specific pipeline components, keyed by component name. |
+| **RETURNS**                                  | dict     | Results from the update. |
 
 ## Language.begin_training {#begin_training tag="method"}
 
@@ -130,11 +132,12 @@ Allocate models, pre-process training data and acquire an optimizer.
 > optimizer = nlp.begin_training(gold_tuples)
 > ```
 
-| Name          | Type     | Description                  |
-| ------------- | -------- | ---------------------------- |
-| `gold_tuples` | iterable | Gold-standard training data. |
-| `**cfg`       | -        | Config parameters.           |
-| **RETURNS**   | callable | An optimizer.                |
+| Name                                         | Type     | Description |
+| -------------------------------------------- | -------- | ----------- |
+| `gold_tuples`                                | iterable | Gold-standard training data. |
+| `component_cfg` <Tag variant="new">2.1</Tag> | dict     | Config parameters for specific pipeline components, keyed by component name. |
+| `**cfg`                                      | -        | Config parameters (sent to all components). |
+| **RETURNS**                                  | callable | An optimizer. |
 
 ## Language.use_params {#use_params tag="contextmanager, method"}

View File

@@ -283,7 +283,7 @@ from pathlib import Path
 nlp = spacy.load("en_core_web_sm")
 sentences = [u"This is an example.", u"This is another one."]
 for sent in sentences:
-    doc = nlp(sentence)
+    doc = nlp(sent)
     svg = displacy.render(doc, style="dep")
     file_name = '-'.join([w.text for w in doc if not w.is_punct]) + ".svg"
     output_path = Path("/images/" + file_name)