Mirror of https://github.com/explosion/spaCy.git
Synced 2025-01-27 09:44:36 +03:00

Commit 6768dfd6e7: Merge branch 'develop' into spacy.io
@@ -97,6 +97,7 @@ def with_cpu(ops, model):
     """Wrap a model that should run on CPU, transferring inputs and outputs
     as necessary."""
     model.to_cpu()
 
     def with_cpu_forward(inputs, drop=0.):
         cpu_outputs, backprop = model.begin_update(_to_cpu(inputs), drop=drop)
         gpu_outputs = _to_device(ops, cpu_outputs)
@@ -16,7 +16,7 @@ TAG_MAP = {
     ":": {POS: PUNCT},
     "$": {POS: SYM, "Other": {"SymType": "currency"}},
     "#": {POS: SYM, "Other": {"SymType": "numbersign"}},
-    "AFX": {POS: ADJ, "Hyph": "yes"},
+    "AFX": {POS: X, "Hyph": "yes"},
     "CC": {POS: CCONJ, "ConjType": "coor"},
     "CD": {POS: NUM, "NumType": "card"},
     "DT": {POS: DET},
@@ -34,10 +34,10 @@ TAG_MAP = {
     "NNP": {POS: PROPN, "NounType": "prop", "Number": "sing"},
     "NNPS": {POS: PROPN, "NounType": "prop", "Number": "plur"},
     "NNS": {POS: NOUN, "Number": "plur"},
-    "PDT": {POS: ADJ, "AdjType": "pdt", "PronType": "prn"},
+    "PDT": {POS: DET, "AdjType": "pdt", "PronType": "prn"},
     "POS": {POS: PART, "Poss": "yes"},
     "PRP": {POS: PRON, "PronType": "prs"},
-    "PRP$": {POS: ADJ, "PronType": "prs", "Poss": "yes"},
+    "PRP$": {POS: DET, "PronType": "prs", "Poss": "yes"},
     "RB": {POS: ADV, "Degree": "pos"},
     "RBR": {POS: ADV, "Degree": "comp"},
     "RBS": {POS: ADV, "Degree": "sup"},
@@ -58,9 +58,9 @@ TAG_MAP = {
         "Number": "sing",
         "Person": 3,
     },
-    "WDT": {POS: ADJ, "PronType": "int|rel"},
-    "WP": {POS: NOUN, "PronType": "int|rel"},
-    "WP$": {POS: ADJ, "Poss": "yes", "PronType": "int|rel"},
+    "WDT": {POS: DET, "PronType": "int|rel"},
+    "WP": {POS: PRON, "PronType": "int|rel"},
+    "WP$": {POS: DET, "Poss": "yes", "PronType": "int|rel"},
     "WRB": {POS: ADV, "PronType": "int|rel"},
     "ADD": {POS: X},
     "NFP": {POS: PUNCT},
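
These tag-map corrections align the fine-grained PTB tags with their Universal Dependencies coarse tags (AFX to X; PDT, PRP$, WDT and WP$ to DET; WP to PRON). A minimal sketch of what the corrected mapping implies for tagged text; it assumes a pretrained English model such as `en_core_web_sm` is installed, and the exact tags shown depend on the tagger's predictions:

```python
import spacy

# Assumes en_core_web_sm is installed; used here only for illustration.
nlp = spacy.load("en_core_web_sm")
doc = nlp(u"Whose idea was that?")

for token in doc:
    # With the corrected TAG_MAP, a token tagged WP$ now surfaces as DET
    # (not ADJ), and a token tagged WP surfaces as PRON (not NOUN).
    print(token.text, token.tag_, token.pos_)
```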
@@ -106,6 +106,7 @@ class Language(object):
 
     DOCS: https://spacy.io/api/language
     """
 
     Defaults = BaseDefaults
     lang = None
 
@@ -344,13 +345,15 @@ class Language(object):
             raise ValueError(Errors.E001.format(name=name, opts=self.pipe_names))
         return self.pipeline.pop(self.pipe_names.index(name))
 
-    def __call__(self, text, disable=[]):
+    def __call__(self, text, disable=[], component_cfg=None):
         """Apply the pipeline to some text. The text can span multiple sentences,
         and can contain arbitrary whitespace. Alignment into the original string
         is preserved.
 
         text (unicode): The text to be processed.
         disable (list): Names of the pipeline components to disable.
+        component_cfg (dict): An optional dictionary with extra keyword arguments
+            for specific components.
         RETURNS (Doc): A container for accessing the annotations.
 
         EXAMPLE:
@@ -363,12 +366,14 @@ class Language(object):
                 Errors.E088.format(length=len(text), max_length=self.max_length)
             )
         doc = self.make_doc(text)
+        if component_cfg is None:
+            component_cfg = {}
         for name, proc in self.pipeline:
             if name in disable:
                 continue
             if not hasattr(proc, "__call__"):
                 raise ValueError(Errors.E003.format(component=type(proc), name=name))
-            doc = proc(doc)
+            doc = proc(doc, **component_cfg.get(name, {}))
             if doc is None:
                 raise ValueError(Errors.E005.format(name=name))
         return doc
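
The new `component_cfg` argument routes extra keyword arguments to individual components when the pipeline is applied. A minimal sketch, assuming a hypothetical custom component that accepts a `threshold` keyword:

```python
from spacy.lang.en import English

def my_component(doc, threshold=0.5):
    # Hypothetical component: it receives extra keyword arguments only when
    # the caller supplies them via component_cfg.
    print("threshold:", threshold)
    return doc

nlp = English()  # tokenizer-only pipeline
nlp.add_pipe(my_component, name="my_component")

# Keyword arguments for specific components, keyed by component name.
doc = nlp(u"A short example text.",
          component_cfg={"my_component": {"threshold": 0.75}})
```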
@@ -396,7 +401,7 @@ class Language(object):
     def make_doc(self, text):
         return self.tokenizer(text)
 
-    def update(self, docs, golds, drop=0.0, sgd=None, losses=None):
+    def update(self, docs, golds, drop=0.0, sgd=None, losses=None, component_cfg=None):
         """Update the models in the pipeline.
 
         docs (iterable): A batch of `Doc` objects.
@@ -443,11 +448,15 @@ class Language(object):
 
         pipes = list(self.pipeline)
         random.shuffle(pipes)
+        if component_cfg is None:
+            component_cfg = {}
         for name, proc in pipes:
             if not hasattr(proc, "update"):
                 continue
             grads = {}
-            proc.update(docs, golds, drop=drop, sgd=get_grads, losses=losses)
+            kwargs = component_cfg.get(name, {})
+            kwargs.setdefault("drop", drop)
+            proc.update(docs, golds, sgd=get_grads, losses=losses, **kwargs)
             for key, (W, dW) in grads.items():
                 sgd(W, dW, key=key)
 
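
In `update()`, the per-component dict is merged with the top-level arguments via `setdefault`, so a component-specific value wins over the global one. A sketch with a blank pipeline and a text classifier; the label and training example are made up for illustration:

```python
import spacy

nlp = spacy.blank("en")
textcat = nlp.create_pipe("textcat")
textcat.add_label("POSITIVE")
nlp.add_pipe(textcat)
optimizer = nlp.begin_training()

text = u"This is really great"
annotations = {"cats": {"POSITIVE": 1.0}}

# drop=0.2 is the default for all components, but the textcat uses 0.4,
# because component_cfg values take precedence (kwargs.setdefault("drop", drop)).
nlp.update(
    [text],
    [annotations],
    sgd=optimizer,
    drop=0.2,
    component_cfg={"textcat": {"drop": 0.4}},
)
```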
@@ -517,11 +526,12 @@ class Language(object):
         for doc, gold in docs_golds:
             yield doc, gold
 
-    def begin_training(self, get_gold_tuples=None, sgd=None, **cfg):
+    def begin_training(self, get_gold_tuples=None, sgd=None, component_cfg=None, **cfg):
         """Allocate models, pre-process training data and acquire a trainer and
         optimizer. Used as a contextmanager.
 
         get_gold_tuples (function): Function returning gold data
+        component_cfg (dict): Config parameters for specific components.
         **cfg: Config parameters.
         RETURNS: An optimizer
         """
@@ -543,10 +553,17 @@ class Language(object):
         if sgd is None:
             sgd = create_default_optimizer(Model.ops)
         self._optimizer = sgd
+        if component_cfg is None:
+            component_cfg = {}
         for name, proc in self.pipeline:
             if hasattr(proc, "begin_training"):
+                kwargs = component_cfg.get(name, {})
+                kwargs.update(cfg)
                 proc.begin_training(
-                    get_gold_tuples, pipeline=self.pipeline, sgd=self._optimizer, **cfg
+                    get_gold_tuples,
+                    pipeline=self.pipeline,
+                    sgd=self._optimizer,
+                    **kwargs
                 )
         return self._optimizer
 
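
In `begin_training()`, each component's kwargs start from `component_cfg[name]` and are then updated with the shared `**cfg`, so the shared settings are layered on top of the per-component ones. A sketch, assuming the textcat's `exclusive_classes` setting as the per-component option being passed through (setting name assumed for illustration):

```python
import spacy

nlp = spacy.blank("en")
textcat = nlp.create_pipe("textcat")
textcat.add_label("POSITIVE")
nlp.add_pipe(textcat)

# The textcat-specific setting is only passed to the textcat's begin_training();
# anything in **cfg would be sent to every component as well.
optimizer = nlp.begin_training(
    component_cfg={"textcat": {"exclusive_classes": True}}
)
```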
@@ -574,20 +591,27 @@ class Language(object):
                 proc._rehearsal_model = deepcopy(proc.model)
         return self._optimizer
 
-    def evaluate(self, docs_golds, verbose=False, batch_size=256):
-        scorer = Scorer()
+    def evaluate(
+        self, docs_golds, verbose=False, batch_size=256, scorer=None, component_cfg=None
+    ):
+        if scorer is None:
+            scorer = Scorer()
         docs, golds = zip(*docs_golds)
         docs = list(docs)
         golds = list(golds)
         for name, pipe in self.pipeline:
+            kwargs = component_cfg.get(name, {})
+            kwargs.setdefault("batch_size", batch_size)
             if not hasattr(pipe, "pipe"):
-                docs = (pipe(doc) for doc in docs)
+                docs = (pipe(doc, **kwargs) for doc in docs)
             else:
-                docs = pipe.pipe(docs, batch_size=batch_size)
+                docs = pipe.pipe(docs, **kwargs)
         for doc, gold in zip(docs, golds):
             if verbose:
                 print(doc)
-            scorer.score(doc, gold, verbose=verbose)
+            kwargs = component_cfg.get("scorer", {})
+            kwargs.setdefault("verbose", verbose)
+            scorer.score(doc, gold, **kwargs)
         return scorer
 
     @contextmanager
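
`evaluate()` can now take a caller-supplied `Scorer` and per-component settings. Note that in this snapshot `component_cfg` is looked up directly, so the sketch below always passes one explicitly. It assumes `en_core_web_sm` is installed and uses a made-up gold annotation:

```python
import spacy
from spacy.gold import GoldParse
from spacy.scorer import Scorer

nlp = spacy.load("en_core_web_sm")

doc = nlp.make_doc(u"Apple is looking at buying a startup.")
gold = GoldParse(doc, entities=[(0, 5, "ORG")])  # offsets made up for illustration

scorer = nlp.evaluate(
    [(doc, gold)],
    scorer=Scorer(),  # bring your own scorer, or leave it to the default
    component_cfg={"ner": {"batch_size": 8}},  # per-component override
)
print(scorer.ents_f)
```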
@@ -630,6 +654,7 @@ class Language(object):
         batch_size=1000,
         disable=[],
         cleanup=False,
+        component_cfg=None,
     ):
         """Process texts as a stream, and yield `Doc` objects in order.
 
|
@ -643,6 +668,8 @@ class Language(object):
|
||||||
disable (list): Names of the pipeline components to disable.
|
disable (list): Names of the pipeline components to disable.
|
||||||
cleanup (bool): If True, unneeded strings are freed,
|
cleanup (bool): If True, unneeded strings are freed,
|
||||||
to control memory use. Experimental.
|
to control memory use. Experimental.
|
||||||
|
component_cfg (dict): An optional dictionary with extra keyword arguments
|
||||||
|
for specific components.
|
||||||
YIELDS (Doc): Documents in the order of the original text.
|
YIELDS (Doc): Documents in the order of the original text.
|
||||||
|
|
||||||
EXAMPLE:
|
EXAMPLE:
|
||||||
|
@@ -655,20 +682,30 @@ class Language(object):
             texts = (tc[0] for tc in text_context1)
             contexts = (tc[1] for tc in text_context2)
             docs = self.pipe(
-                texts, n_threads=n_threads, batch_size=batch_size, disable=disable
+                texts,
+                n_threads=n_threads,
+                batch_size=batch_size,
+                disable=disable,
+                component_cfg=component_cfg,
             )
             for doc, context in izip(docs, contexts):
                 yield (doc, context)
             return
         docs = (self.make_doc(text) for text in texts)
+        if component_cfg is None:
+            component_cfg = {}
         for name, proc in self.pipeline:
             if name in disable:
                 continue
+            kwargs = component_cfg.get(name, {})
+            # Allow component_cfg to overwrite the top-level kwargs.
+            kwargs.setdefault("batch_size", batch_size)
+            kwargs.setdefault("n_threads", n_threads)
             if hasattr(proc, "pipe"):
-                docs = proc.pipe(docs, n_threads=n_threads, batch_size=batch_size)
+                docs = proc.pipe(docs, **kwargs)
             else:
                 # Apply the function, but yield the doc
-                docs = _pipe(proc, docs)
+                docs = _pipe(proc, docs, kwargs)
         # Track weakrefs of "recent" documents, so that we can see when they
         # expire from memory. When they do, we know we don't need old strings.
         # This way, we avoid maintaining an unbounded growth in string entries
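
For streaming, `component_cfg` plays the same role in `pipe()`: the top-level `batch_size` and `n_threads` only act as defaults that a per-component dict can override. A short sketch, again assuming `en_core_web_sm` is installed:

```python
import spacy

nlp = spacy.load("en_core_web_sm")
texts = [u"First document.", u"Second document.", u"Third document."]

# The parser batches 16 docs at a time; every other component keeps the
# top-level batch_size of 50.
for doc in nlp.pipe(
    texts,
    batch_size=50,
    component_cfg={"parser": {"batch_size": 16}},
):
    print(doc.text, [t.pos_ for t in doc])
```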
@@ -861,7 +898,7 @@ class DisabledPipes(list):
         self[:] = []
 
 
-def _pipe(func, docs):
+def _pipe(func, docs, kwargs):
     for doc in docs:
-        doc = func(doc)
+        doc = func(doc, **kwargs)
         yield doc
@@ -110,7 +110,8 @@ cdef class Morphology:
             analysis.lemma = self.lemmatize(analysis.tag.pos, token.lex.orth,
                                             self.tag_map.get(tag_str, {}))
             self._cache.set(tag_id, token.lex.orth, analysis)
-        token.lemma = analysis.lemma
+        if token.lemma == 0:
+            token.lemma = analysis.lemma
         token.pos = analysis.tag.pos
         token.tag = analysis.tag.name
         token.morph = analysis.tag.morph
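
The `if token.lemma == 0` guard means an already-assigned lemma is no longer overwritten when a tag is assigned afterwards. A minimal sketch of the intended behaviour on a tokenizer-only pipeline (the custom lemma is arbitrary):

```python
from spacy.lang.en import English

nlp = English()
doc = nlp(u"survived")
token = doc[0]

token.lemma_ = u"my-custom-lemma"  # set a lemma first
token.tag_ = u"VBD"                # assigning the tag no longer resets it
assert token.lemma_ == u"my-custom-lemma"
```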
@@ -2,6 +2,7 @@
 from __future__ import unicode_literals
 
 import pytest
+import numpy
 from spacy.tokens import Doc
 from spacy.displacy import render
 from spacy.gold import iob_to_biluo
@@ -12,12 +13,14 @@ from spacy.lang.en import English
 from ..util import add_vecs_to_vocab, get_doc
 
 
-@pytest.mark.xfail(
-    reason="The dot is now properly split off, but the prefix/suffix rules are not applied again afterwards."
-    "This means that the quote will still be attached to the remaining token."
-)
+@pytest.mark.xfail
 def test_issue2070():
-    """Test that checks that a dot followed by a quote is handled appropriately."""
+    """Test that checks that a dot followed by a quote is handled
+    appropriately.
+    """
+    # Problem: The dot is now properly split off, but the prefix/suffix rules
+    # are not applied again afterwards. This means that the quote will still be
+    # attached to the remaining token.
     nlp = English()
     doc = nlp('First sentence."A quoted sentence" he said ...')
     assert len(doc) == 11
@@ -37,6 +40,26 @@ def test_issue2179():
     assert nlp2.get_pipe("ner").labels == ("CITIZENSHIP",)
 
 
+def test_issue2203(en_vocab):
+    """Test that lemmas are set correctly in doc.from_array."""
+    words = ["I", "'ll", "survive"]
+    tags = ["PRP", "MD", "VB"]
+    lemmas = ["-PRON-", "will", "survive"]
+    tag_ids = [en_vocab.strings.add(tag) for tag in tags]
+    lemma_ids = [en_vocab.strings.add(lemma) for lemma in lemmas]
+    doc = Doc(en_vocab, words=words)
+    # Work around lemma corruption problem and set lemmas after tags
+    doc.from_array("TAG", numpy.array(tag_ids, dtype="uint64"))
+    doc.from_array("LEMMA", numpy.array(lemma_ids, dtype="uint64"))
+    assert [t.tag_ for t in doc] == tags
+    assert [t.lemma_ for t in doc] == lemmas
+    # We need to serialize both tag and lemma, since this is what causes the bug
+    doc_array = doc.to_array(["TAG", "LEMMA"])
+    new_doc = Doc(doc.vocab, words=words).from_array(["TAG", "LEMMA"], doc_array)
+    assert [t.tag_ for t in new_doc] == tags
+    assert [t.lemma_ for t in new_doc] == lemmas
+
+
 def test_issue2219(en_vocab):
     vectors = [("a", [1, 2, 3]), ("letter", [4, 5, 6])]
     add_vecs_to_vocab(en_vocab, vectors)
@@ -763,17 +763,18 @@ cdef class Doc:
             attr_ids[i] = attr_id
         if len(array.shape) == 1:
             array = array.reshape((array.size, 1))
+        # Do TAG first. This lets subsequent loop override stuff like POS, LEMMA
+        if TAG in attrs:
+            col = attrs.index(TAG)
+            for i in range(length):
+                if array[i, col] != 0:
+                    self.vocab.morphology.assign_tag(&tokens[i], array[i, col])
         # Now load the data
         for i in range(self.length):
             token = &self.c[i]
             for j in range(n_attrs):
-                Token.set_struct_attr(token, attr_ids[j], array[i, j])
-        # Auxiliary loading logic
-        for col, attr_id in enumerate(attrs):
-            if attr_id == TAG:
-                for i in range(length):
-                    if array[i, col] != 0:
-                        self.vocab.morphology.assign_tag(&tokens[i], array[i, col])
+                if attr_ids[j] != TAG:
+                    Token.set_struct_attr(token, attr_ids[j], array[i, j])
         # Set flags
         self.is_parsed = bool(self.is_parsed or HEAD in attrs or DEP in attrs)
         self.is_tagged = bool(self.is_tagged or TAG in attrs or POS in attrs)
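
Because the TAG column is now applied before everything else and then skipped in the main loop, an array containing both TAG and LEMMA round-trips without the tag assignment clobbering the lemmas. A sketch mirroring the new regression test above:

```python
from spacy.lang.en import English
from spacy.tokens import Doc

vocab = English().vocab
words = [u"I", u"will", u"survive"]

doc = Doc(vocab, words=words)
doc[0].tag_, doc[0].lemma_ = u"PRP", u"-PRON-"
doc[1].tag_, doc[1].lemma_ = u"MD", u"will"
doc[2].tag_, doc[2].lemma_ = u"VB", u"survive"

# Serialize TAG and LEMMA together and restore them in one call.
array = doc.to_array(["TAG", "LEMMA"])
new_doc = Doc(vocab, words=words).from_array(["TAG", "LEMMA"], array)

assert [t.tag_ for t in new_doc] == [u"PRP", u"MD", u"VB"]
assert [t.lemma_ for t in new_doc] == [u"-PRON-", u"will", u"survive"]
```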
@@ -91,13 +91,14 @@ multiprocessing.
 > assert doc.is_parsed
 > ```
 
 | Name | Type | Description |
-| ------------ | ----- | --------------------------------------------------------------------------------- |
+| -------------------------------------------- | ----- | --------------------------------------------------------------------------------- |
 | `texts` | - | A sequence of unicode objects. |
 | `as_tuples` | bool | If set to `True`, inputs should be a sequence of `(text, context)` tuples. Output will then be a sequence of `(doc, context)` tuples. Defaults to `False`. |
 | `batch_size` | int | The number of texts to buffer. |
 | `disable` | list | Names of pipeline components to [disable](/usage/processing-pipelines#disabling). |
+| `component_cfg` <Tag variant="new">2.1</Tag> | dict | Config parameters for specific pipeline components, keyed by component name. |
 | **YIELDS** | `Doc` | Documents in the order of the original text. |
 
 ## Language.update {#update tag="method"}
 
@@ -112,13 +113,14 @@ Update the models in the pipeline.
 > nlp.update([doc], [gold], drop=0.5, sgd=optimizer)
 > ```
 
 | Name | Type | Description |
-| ----------- | -------- | ----------------------------------------------------------------------------------------------- |
+| -------------------------------------------- | -------- | ----------------------------------------------------------------------------------------------- |
 | `docs` | iterable | A batch of `Doc` objects or unicode. If unicode, a `Doc` object will be created from the text. |
 | `golds` | iterable | A batch of `GoldParse` objects or dictionaries. Dictionaries will be used to create [`GoldParse`](/api/goldparse) objects. For the available keys and their usage, see [`GoldParse.__init__`](/api/goldparse#init). |
 | `drop` | float | The dropout rate. |
 | `sgd` | callable | An optimizer. |
+| `component_cfg` <Tag variant="new">2.1</Tag> | dict | Config parameters for specific pipeline components, keyed by component name. |
 | **RETURNS** | dict | Results from the update. |
 
 ## Language.begin_training {#begin_training tag="method"}
 
@@ -130,11 +132,12 @@ Allocate models, pre-process training data and acquire an optimizer.
 > optimizer = nlp.begin_training(gold_tuples)
 > ```
 
 | Name | Type | Description |
-| ------------- | -------- | ---------------------------- |
+| -------------------------------------------- | -------- | ----------------------------------------------------------------------------- |
 | `gold_tuples` | iterable | Gold-standard training data. |
-| `**cfg` | - | Config parameters. |
+| `component_cfg` <Tag variant="new">2.1</Tag> | dict | Config parameters for specific pipeline components, keyed by component name. |
+| `**cfg` | - | Config parameters (sent to all components). |
 | **RETURNS** | callable | An optimizer. |
 
 ## Language.use_params {#use_params tag="contextmanager, method"}
 
@@ -283,7 +283,7 @@ from pathlib import Path
 nlp = spacy.load("en_core_web_sm")
 sentences = [u"This is an example.", u"This is another one."]
 for sent in sentences:
-    doc = nlp(sentence)
+    doc = nlp(sent)
     svg = displacy.render(doc, style="dep")
     file_name = '-'.join([w.text for w in doc if not w.is_punct]) + ".svg"
     output_path = Path("/images/" + file_name)
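
The docs snippet presumably continues by writing each rendered SVG to the computed path; a sketch of that final step (not shown in this diff):

```python
# Inside the loop above: write the rendered SVG to the computed output path.
with output_path.open("w", encoding="utf-8") as output_file:
    output_file.write(svg)
```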