Merge branch 'master' into feature/el-framework

Sofie 2019-03-26 11:00:02 +01:00 committed by GitHub
commit a4a6bfa4e1
33 changed files with 666 additions and 284 deletions

.github/contributors/wannaphongcom.md (new file)

@@ -0,0 +1,106 @@
# spaCy contributor agreement
This spaCy Contributor Agreement (**"SCA"**) is based on the
[Oracle Contributor Agreement](http://www.oracle.com/technetwork/oca-405177.pdf).
The SCA applies to any contribution that you make to any product or project
managed by us (the **"project"**), and sets out the intellectual property rights
you grant to us in the contributed materials. The term **"us"** shall mean
[ExplosionAI UG (haftungsbeschränkt)](https://explosion.ai/legal). The term
**"you"** shall mean the person or entity identified below.
If you agree to be bound by these terms, fill in the information requested
below and include the filled-in version with your first pull request, under the
folder [`.github/contributors/`](/.github/contributors/). The name of the file
should be your GitHub username, with the extension `.md`. For example, the user
example_user would create the file `.github/contributors/example_user.md`.
Read this agreement carefully before signing. These terms and conditions
constitute a binding legal agreement.
## Contributor Agreement
1. The term "contribution" or "contributed materials" means any source code,
object code, patch, tool, sample, graphic, specification, manual,
documentation, or any other material posted or submitted by you to the project.
2. With respect to any worldwide copyrights, or copyright applications and
registrations, in your contribution:
* you hereby assign to us joint ownership, and to the extent that such
assignment is or becomes invalid, ineffective or unenforceable, you hereby
grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge,
royalty-free, unrestricted license to exercise all rights under those
copyrights. This includes, at our option, the right to sublicense these same
rights to third parties through multiple levels of sublicensees or other
licensing arrangements;
* you agree that each of us can do all things in relation to your
contribution as if each of us were the sole owners, and if one of us makes
a derivative work of your contribution, the one who makes the derivative
work (or has it made) will be the sole owner of that derivative work;
* you agree that you will not assert any moral rights in your contribution
against us, our licensees or transferees;
* you agree that we may register a copyright in your contribution and
exercise all ownership rights associated with it; and
* you agree that neither of us has any duty to consult with, obtain the
consent of, pay or render an accounting to the other for any use or
distribution of your contribution.
3. With respect to any patents you own, or that you can license without payment
to any third party, you hereby grant to us a perpetual, irrevocable,
non-exclusive, worldwide, no-charge, royalty-free license to:
* make, have made, use, sell, offer to sell, import, and otherwise transfer
your contribution in whole or in part, alone or in combination with or
included in any product, work or materials arising out of the project to
which your contribution was submitted, and
* at our option, to sublicense these same rights to third parties through
multiple levels of sublicensees or other licensing arrangements.
4. Except as set out above, you keep all right, title, and interest in your
contribution. The rights that you grant to us under these terms are effective
on the date you first submitted a contribution to us, even if your submission
took place before the date you sign these terms.
5. You covenant, represent, warrant and agree that:
* Each contribution that you submit is and shall be an original work of
authorship and you can legally grant the rights set out in this SCA;
* to the best of your knowledge, each contribution will not violate any
third party's copyrights, trademarks, patents, or other intellectual
property rights; and
* each contribution shall be in compliance with U.S. export control laws and
other applicable export and import laws. You agree to notify us if you
become aware of any circumstance which would make any of the foregoing
representations inaccurate in any respect. We may publicly disclose your
participation in the project, including the fact that you have signed the SCA.
6. This SCA is governed by the laws of the State of California and applicable
U.S. Federal law. Any choice of law rules will not apply.
7. Please place an “x” on one of the applicable statement below. Please do NOT
mark both statements:
* [x] I am signing on behalf of myself as an individual and no other person
or entity, including my employer, has or will have rights with respect my
contributions.
* [ ] I am signing on behalf of my employer or a legal entity and I have the
actual authority to contractually bind that entity.
## Contributor Details
| Field | Entry |
|------------------------------- | -------------------- |
| Name | Wannaphong Phatthiyaphaibun |
| Company name (if applicable) | |
| Title or role (if applicable) | |
| Date | 25-3-2019 |
| GitHub username | wannaphongcom |
| Website (optional) | |

@@ -43,7 +43,11 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None
     # nlp.create_pipe works for built-ins that are registered with spaCy
     if "textcat" not in nlp.pipe_names:
         textcat = nlp.create_pipe(
-            "textcat", config={"architecture": "simple_cnn", "exclusive_classes": True}
+            "textcat",
+            config={
+                "exclusive_classes": True,
+                "architecture": "simple_cnn",
+            }
         )
         nlp.add_pipe(textcat, last=True)
     # otherwise, get it, so we can add labels to it
@@ -56,7 +60,9 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None
     # load the IMDB dataset
     print("Loading IMDB data...")
-    (train_texts, train_cats), (dev_texts, dev_cats) = load_data(limit=n_texts)
+    (train_texts, train_cats), (dev_texts, dev_cats) = load_data()
+    train_texts = train_texts[:n_texts]
+    train_cats = train_cats[:n_texts]
     print(
         "Using {} examples ({} training, {} evaluation)".format(
             n_texts, len(train_texts), len(dev_texts)
@@ -73,10 +79,12 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None
             textcat.model.tok2vec.from_bytes(file_.read())
     print("Training the model...")
     print("{:^5}\t{:^5}\t{:^5}\t{:^5}".format("LOSS", "P", "R", "F"))
+    batch_sizes = compounding(4.0, 32.0, 1.001)
     for i in range(n_iter):
         losses = {}
         # batch up the examples using spaCy's minibatch
-        batches = minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
+        random.shuffle(train_data)
+        batches = minibatch(train_data, size=batch_sizes)
         for batch in batches:
             texts, annotations = zip(*batch)
             nlp.update(texts, annotations, sgd=optimizer, drop=0.2, losses=losses)

@@ -43,8 +43,9 @@ redirects = [
    {from = "/usage/lightning-tour", to = "/usage/spacy-101#lightning-tour"},
    {from = "/usage/linguistic-features#rule-based-matching", to = "/usage/rule-based-matching"},
    {from = "/models/comparison", to = "/models"},
-   {from = "/api/#section-cython", to = "/api/cython"},
-   {from = "/api/#cython", to = "/api/cython"},
+   {from = "/api/#section-cython", to = "/api/cython", force = true},
+   {from = "/api/#cython", to = "/api/cython", force = true},
+   {from = "/api/sentencesegmenter", to="/api/sentencizer"},
    {from = "/universe", to = "/universe/project/:id", query = {id = ":id"}, force = true},
    {from = "/universe", to = "/universe/category/:category", query = {category = ":category"}, force = true},
 ]

@@ -81,18 +81,6 @@ def _zero_init(model):
     return model


-@layerize
-def _preprocess_doc(docs, drop=0.0):
-    keys = [doc.to_array(LOWER) for doc in docs]
-    # The dtype here matches what thinc is expecting -- which differs per
-    # platform (by int definition). This should be fixed once the problem
-    # is fixed on Thinc's side.
-    lengths = numpy.array([arr.shape[0] for arr in keys], dtype=numpy.int_)
-    keys = numpy.concatenate(keys)
-    vals = numpy.zeros(keys.shape, dtype='f')
-    return (keys, vals, lengths), None
-
-
 def with_cpu(ops, model):
     """Wrap a model that should run on CPU, transferring inputs and outputs
     as necessary."""
@@ -133,20 +121,31 @@ def _to_device(ops, X):
     return ops.asarray(X)


-@layerize
-def _preprocess_doc_bigrams(docs, drop=0.0):
-    unigrams = [doc.to_array(LOWER) for doc in docs]
-    ops = Model.ops
-    bigrams = [ops.ngrams(2, doc_unis) for doc_unis in unigrams]
-    keys = [ops.xp.concatenate(feats) for feats in zip(unigrams, bigrams)]
-    keys, vals = zip(*[ops.xp.unique(k, return_counts=True) for k in keys])
-    # The dtype here matches what thinc is expecting -- which differs per
-    # platform (by int definition). This should be fixed once the problem
-    # is fixed on Thinc's side.
-    lengths = ops.asarray([arr.shape[0] for arr in keys], dtype=numpy.int_)
-    keys = ops.xp.concatenate(keys)
-    vals = ops.asarray(ops.xp.concatenate(vals), dtype="f")
-    return (keys, vals, lengths), None
+class extract_ngrams(Model):
+    def __init__(self, ngram_size, attr=LOWER):
+        Model.__init__(self)
+        self.ngram_size = ngram_size
+        self.attr = attr
+
+    def begin_update(self, docs, drop=0.0):
+        batch_keys = []
+        batch_vals = []
+        for doc in docs:
+            unigrams = doc.to_array([self.attr])
+            ngrams = [unigrams]
+            for n in range(2, self.ngram_size + 1):
+                ngrams.append(self.ops.ngrams(n, unigrams))
+            keys = self.ops.xp.concatenate(ngrams)
+            keys, vals = self.ops.xp.unique(keys, return_counts=True)
+            batch_keys.append(keys)
+            batch_vals.append(vals)
+        # The dtype here matches what thinc is expecting -- which differs per
+        # platform (by int definition). This should be fixed once the problem
+        # is fixed on Thinc's side.
+        lengths = self.ops.asarray([arr.shape[0] for arr in batch_keys], dtype=numpy.int_)
+        batch_keys = self.ops.xp.concatenate(batch_keys)
+        batch_vals = self.ops.asarray(self.ops.xp.concatenate(batch_vals), dtype="f")
+        return (batch_keys, batch_vals, lengths), None


 @describe.on_data(
@@ -486,16 +485,6 @@ def zero_init(model):
     return model


-@layerize
-def preprocess_doc(docs, drop=0.0):
-    keys = [doc.to_array([LOWER]) for doc in docs]
-    ops = Model.ops
-    lengths = ops.asarray([arr.shape[0] for arr in keys])
-    keys = ops.xp.concatenate(keys)
-    vals = ops.allocate(keys.shape[0]) + 1
-    return (keys, vals, lengths), None
-
-
 def getitem(i):
     def getitem_fwd(X, drop=0.0):
         return X[i], None
@@ -602,10 +591,8 @@ def build_text_classifier(nr_class, width=64, **cfg):
             >> zero_init(Affine(nr_class, width, drop_factor=0.0))
         )

-        linear_model = (
-            _preprocess_doc
-            >> with_cpu(Model.ops, LinearModel(nr_class))
-        )
+        linear_model = build_bow_text_classifier(
+            nr_class, ngram_size=cfg.get("ngram_size", 1), exclusive_classes=False)
         if cfg.get('exclusive_classes'):
             output_layer = Softmax(nr_class, nr_class * 2)
         else:
@@ -623,6 +610,33 @@ def build_text_classifier(nr_class, width=64, **cfg):
     return model


+def build_bow_text_classifier(nr_class, ngram_size=1, exclusive_classes=False,
+                              no_output_layer=False, **cfg):
+    with Model.define_operators({">>": chain}):
+        model = (
+            extract_ngrams(ngram_size, attr=ORTH)
+            >> with_cpu(Model.ops,
+                LinearModel(nr_class)
+            )
+        )
+        if not no_output_layer:
+            model = model >> (cpu_softmax if exclusive_classes else logistic)
+    model.nO = nr_class
+    return model
+
+
+@layerize
+def cpu_softmax(X, drop=0.):
+    ops = NumpyOps()
+    Y = ops.softmax(X)
+
+    def cpu_softmax_backward(dY, sgd=None):
+        return dY
+
+    return ops.softmax(X), cpu_softmax_backward
+
+
 def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=False, **cfg):
     """
     Build a simple CNN text classifier, given a token-to-vector model as inputs.
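As an aside (not part of the diff): the pieces added above compose into a standalone bag-of-words model. A minimal sketch with made-up hyperparameters, calling the new builder from spaCy's internal `_ml` module (the module the pipeline code imports it from later in this diff):

```python
# Sketch only; nr_class and ngram_size are arbitrary example values.
from spacy._ml import build_bow_text_classifier

# extract_ngrams (ORTH unigrams + bigrams) >> LinearModel on CPU >> cpu_softmax
model = build_bow_text_classifier(nr_class=2, ngram_size=2, exclusive_classes=True)
```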

@@ -4,7 +4,7 @@
 # fmt: off
 __title__ = "spacy"
-__version__ = "2.1.2"
+__version__ = "2.1.3"
 __summary__ = "Industrial-strength Natural Language Processing (NLP) with Python and Cython"
 __uri__ = "https://spacy.io"
 __author__ = "Explosion AI"

@@ -1,7 +1,7 @@
 # encoding: utf8
 from __future__ import unicode_literals

-from ...symbols import POS, NOUN, PRON, ADJ, ADV, INTJ, PROPN, DET, NUM, AUX
+from ...symbols import POS, NOUN, PRON, ADJ, ADV, INTJ, PROPN, DET, NUM, AUX,VERB
 from ...symbols import ADP, CCONJ, PART, PUNCT, SPACE, SCONJ

 # Source: Korakot Chaovavanich
@@ -16,6 +16,9 @@ TAG_MAP = {
     "CMTR": {POS: NOUN},
     "CFQC": {POS: NOUN},
     "CVBL": {POS: NOUN},
+    # VERB
+    "VACT":{POS:VERB},
+    "VSTA":{POS:VERB},
     # PRON
     "PRON": {POS: PRON},
     "NPRP": {POS: PRON},
@@ -78,6 +81,7 @@ TAG_MAP = {
     "EAFF": {POS: PART},
     "AITT": {POS: PART},
     "NEG": {POS: PART},
+    "EITT": {POS: PART},
     # PUNCT
     "PUNCT": {POS: PUNCT},
     "PUNC": {POS: PUNCT},

@@ -15,7 +15,7 @@ from .tokenizer import Tokenizer
 from .vocab import Vocab
 from .lemmatizer import Lemmatizer
 from .pipeline import DependencyParser, Tensorizer, Tagger, EntityRecognizer, EntityLinker
-from .pipeline import SimilarityHook, TextCategorizer, SentenceSegmenter
+from .pipeline import SimilarityHook, TextCategorizer, Sentencizer
 from .pipeline import merge_noun_chunks, merge_entities, merge_subtokens
 from .pipeline import EntityRuler
 from .compat import izip, basestring_
@@ -120,7 +120,7 @@ class Language(object):
         "entity_linker": lambda nlp, **cfg: EntityLinker(nlp.vocab, **cfg),
         "similarity": lambda nlp, **cfg: SimilarityHook(nlp.vocab, **cfg),
         "textcat": lambda nlp, **cfg: TextCategorizer(nlp.vocab, **cfg),
-        "sentencizer": lambda nlp, **cfg: SentenceSegmenter(nlp.vocab, **cfg),
+        "sentencizer": lambda nlp, **cfg: Sentencizer(**cfg),
         "merge_noun_chunks": lambda nlp, **cfg: merge_noun_chunks,
         "merge_entities": lambda nlp, **cfg: merge_entities,
         "merge_subtokens": lambda nlp, **cfg: merge_subtokens,

@@ -2,7 +2,7 @@
 from __future__ import unicode_literals

 from .pipes import Tagger, DependencyParser, EntityRecognizer, EntityLinker
-from .pipes import TextCategorizer, Tensorizer, Pipe
+from .pipes import TextCategorizer, Tensorizer, Pipe, Sentencizer
 from .entityruler import EntityRuler
 from .hooks import SentenceSegmenter, SimilarityHook
 from .functions import merge_entities, merge_noun_chunks, merge_subtokens
@@ -16,6 +16,7 @@ __all__ = [
     "Tensorizer",
     "Pipe",
     "EntityRuler",
+    "Sentencizer",
     "SentenceSegmenter",
     "SimilarityHook",
     "merge_entities",

@@ -191,7 +191,7 @@ class EntityRuler(object):
         **kwargs: Other config paramters, mostly for consistency.
         RETURNS (EntityRuler): The loaded entity ruler.

-        DOCS: https://spacy.io/api/entityruler
+        DOCS: https://spacy.io/api/entityruler#to_disk
         """
         path = ensure_path(path)
         path = path.with_suffix(".jsonl")

@@ -15,8 +15,6 @@ class SentenceSegmenter(object):
     initialization, or assign a new strategy to the .strategy attribute.
     Sentence detection strategies should be generators that take `Doc` objects
     and yield `Span` objects for each sentence.
-
-    DOCS: https://spacy.io/api/sentencesegmenter
     """

     name = "sentencizer"
@@ -35,12 +33,12 @@ class SentenceSegmenter(object):
     def split_on_punct(doc):
         start = 0
         seen_period = False
-        for i, word in enumerate(doc):
-            if seen_period and not word.is_punct:
-                yield doc[start : word.i]
-                start = word.i
+        for i, token in enumerate(doc):
+            if seen_period and not token.is_punct:
+                yield doc[start : token.i]
+                start = token.i
                 seen_period = False
-            elif word.text in [".", "!", "?"]:
+            elif token.text in [".", "!", "?"]:
                 seen_period = True
         if start < len(doc):
             yield doc[start : len(doc)]

@@ -25,6 +25,7 @@ from ..attrs import POS, ID
 from ..parts_of_speech import X
 from .._ml import Tok2Vec, build_tagger_model
 from .._ml import build_text_classifier, build_simple_cnn_text_classifier
+from .._ml import build_bow_text_classifier
 from .._ml import link_vectors_to_models, zero_init, flatten
 from .._ml import masked_language_model, create_default_optimizer
 from ..errors import Errors, TempErrors
@@ -876,6 +877,8 @@ class TextCategorizer(Pipe):
         if cfg.get("architecture") == "simple_cnn":
             tok2vec = Tok2Vec(token_vector_width, embed_size, **cfg)
             return build_simple_cnn_text_classifier(tok2vec, nr_class, **cfg)
+        elif cfg.get("architecture") == "bow":
+            return build_bow_text_classifier(nr_class, **cfg)
         else:
             return build_text_classifier(nr_class, **cfg)
@@ -1107,4 +1110,90 @@ class EntityLinker(Pipe):
         pass


-__all__ = ["Tagger", "DependencyParser", "EntityRecognizer", "Tensorizer", "TextCategorizer", "EntityLinker"]
+class Sentencizer(object):
+    """Segment the Doc into sentences using a rule-based strategy.
+
+    DOCS: https://spacy.io/api/sentencizer
+    """
+
+    name = "sentencizer"
+    default_punct_chars = [".", "!", "?"]
+
+    def __init__(self, punct_chars=None, **kwargs):
+        """Initialize the sentencizer.
+
+        punct_chars (list): Punctuation characters to split on. Will be
+            serialized with the nlp object.
+        RETURNS (Sentencizer): The sentencizer component.
+
+        DOCS: https://spacy.io/api/sentencizer#init
+        """
+        self.punct_chars = punct_chars or self.default_punct_chars
+
+    def __call__(self, doc):
+        """Apply the sentencizer to a Doc and set Token.is_sent_start.
+
+        doc (Doc): The document to process.
+        RETURNS (Doc): The processed Doc.
+
+        DOCS: https://spacy.io/api/sentencizer#call
+        """
+        start = 0
+        seen_period = False
+        for i, token in enumerate(doc):
+            is_in_punct_chars = token.text in self.punct_chars
+            token.is_sent_start = i == 0
+            if seen_period and not token.is_punct and not is_in_punct_chars:
+                doc[start].is_sent_start = True
+                start = token.i
+                seen_period = False
+            elif is_in_punct_chars:
+                seen_period = True
+        if start < len(doc):
+            doc[start].is_sent_start = True
+        return doc
+
+    def to_bytes(self, **kwargs):
+        """Serialize the sentencizer to a bytestring.
+
+        RETURNS (bytes): The serialized object.
+
+        DOCS: https://spacy.io/api/sentencizer#to_bytes
+        """
+        return srsly.msgpack_dumps({"punct_chars": self.punct_chars})
+
+    def from_bytes(self, bytes_data, **kwargs):
+        """Load the sentencizer from a bytestring.
+
+        bytes_data (bytes): The data to load.
+        returns (Sentencizer): The loaded object.
+
+        DOCS: https://spacy.io/api/sentencizer#from_bytes
+        """
+        cfg = srsly.msgpack_loads(bytes_data)
+        self.punct_chars = cfg.get("punct_chars", self.default_punct_chars)
+        return self
+
+    def to_disk(self, path, exclude=tuple(), **kwargs):
+        """Serialize the sentencizer to disk.
+
+        DOCS: https://spacy.io/api/sentencizer#to_disk
+        """
+        path = util.ensure_path(path)
+        path = path.with_suffix(".json")
+        srsly.write_json(path, {"punct_chars": self.punct_chars})
+
+    def from_disk(self, path, exclude=tuple(), **kwargs):
+        """Load the sentencizer from disk.
+
+        DOCS: https://spacy.io/api/sentencizer#from_disk
+        """
+        path = util.ensure_path(path)
+        path = path.with_suffix(".json")
+        cfg = srsly.read_json(path)
+        self.punct_chars = cfg.get("punct_chars", self.default_punct_chars)
+        return self
+
+
+__all__ = ["Tagger", "DependencyParser", "EntityRecognizer", "Tensorizer", "TextCategorizer", "EntityLinker", "Sentencizer"]

@@ -574,11 +574,12 @@ cdef class Parser:
        cfg.setdefault('min_action_freq', 30)
        actions = self.moves.get_actions(gold_parses=get_gold_tuples(),
                                         min_freq=cfg.get('min_action_freq', 30))
-       previous_labels = dict(self.moves.labels)
+       for action, labels in self.moves.labels.items():
+           actions.setdefault(action, {})
+           for label, freq in labels.items():
+               if label not in actions[action]:
+                   actions[action][label] = freq
        self.moves.initialize_actions(actions)
-       for action, label_freqs in previous_labels.items():
-           for label in label_freqs:
-               self.moves.add_action(action, label)
        cfg.setdefault('token_vector_width', 96)
        if self.model is True:
            self.model, cfg = self.Model(self.moves.n_moves, **cfg)

@@ -8,7 +8,8 @@ from spacy.attrs import NORM
 from spacy.gold import GoldParse
 from spacy.vocab import Vocab
 from spacy.tokens import Doc
-from spacy.pipeline import DependencyParser
+from spacy.pipeline import DependencyParser, EntityRecognizer
+from spacy.util import fix_random_seed


 @pytest.fixture
@@ -19,18 +20,6 @@ def vocab():
 @pytest.fixture
 def parser(vocab):
     parser = DependencyParser(vocab)
-    parser.cfg["token_vector_width"] = 8
-    parser.cfg["hidden_width"] = 30
-    parser.cfg["hist_size"] = 0
-    parser.add_label("left")
-    parser.begin_training([], **parser.cfg)
-    sgd = Adam(NumpyOps(), 0.001)
-    for i in range(10):
-        losses = {}
-        doc = Doc(vocab, words=["a", "b", "c", "d"])
-        gold = GoldParse(doc, heads=[1, 1, 3, 3], deps=["left", "ROOT", "left", "ROOT"])
-        parser.update([doc], [gold], sgd=sgd, losses=losses)
     return parser


@@ -38,25 +27,23 @@ def test_init_parser(parser):
     pass


-# TODO: This is flakey, because it depends on what the parser first learns.
-# TODO: This now seems to be implicated in segfaults. Not sure what's up!
-@pytest.mark.skip
+def _train_parser(parser):
+    fix_random_seed(1)
+    parser.add_label("left")
+    parser.begin_training([], **parser.cfg)
+    sgd = Adam(NumpyOps(), 0.001)
+    for i in range(5):
+        losses = {}
+        doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
+        gold = GoldParse(doc, heads=[1, 1, 3, 3], deps=["left", "ROOT", "left", "ROOT"])
+        parser.update([doc], [gold], sgd=sgd, losses=losses)
+    return parser
+
+
 def test_add_label(parser):
-    doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
-    doc = parser(doc)
-    assert doc[0].head.i == 1
-    assert doc[0].dep_ == "left"
-    assert doc[1].head.i == 1
-    assert doc[2].head.i == 3
-    assert doc[2].head.i == 3
+    parser = _train_parser(parser)
     parser.add_label("right")
+    doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
+    doc = parser(doc)
+    assert doc[0].head.i == 1
+    assert doc[0].dep_ == "left"
+    assert doc[1].head.i == 1
+    assert doc[2].head.i == 3
+    assert doc[2].head.i == 3
     sgd = Adam(NumpyOps(), 0.001)
     for i in range(10):
         losses = {}
@@ -69,3 +56,15 @@ def test_add_label(parser):
     doc = parser(doc)
     assert doc[0].dep_ == "right"
     assert doc[2].dep_ == "left"
+
+
+def test_add_label_deserializes_correctly():
+    ner1 = EntityRecognizer(Vocab())
+    ner1.add_label("C")
+    ner1.add_label("B")
+    ner1.add_label("A")
+    ner1.begin_training([])
+    ner2 = EntityRecognizer(Vocab()).from_bytes(ner1.to_bytes())
+    assert ner1.moves.n_moves == ner2.moves.n_moves
+    for i in range(ner1.moves.n_moves):
+        assert ner1.moves.get_class_name(i) == ner2.moves.get_class_name(i)

@@ -0,0 +1,87 @@
# coding: utf8
from __future__ import unicode_literals
import pytest
from spacy.pipeline import Sentencizer
from spacy.tokens import Doc
def test_sentencizer(en_vocab):
doc = Doc(en_vocab, words=["Hello", "!", "This", "is", "a", "test", "."])
sentencizer = Sentencizer()
doc = sentencizer(doc)
assert doc.is_sentenced
sent_starts = [t.is_sent_start for t in doc]
assert sent_starts == [True, False, True, False, False, False, False]
assert len(list(doc.sents)) == 2
@pytest.mark.parametrize(
"words,sent_starts,n_sents",
[
# The expected result here is that the duplicate punctuation gets merged
# onto the same sentence and no one-token sentence is created for them.
(
["Hello", "!", ".", "Test", ".", ".", "ok"],
[True, False, False, True, False, False, True],
3,
),
# We also want to make sure ¡ and ¿ aren't treated as sentence end
# markers, even though they're punctuation
(
["¡", "Buen", "día", "!", "Hola", ",", "¿", "qué", "tal", "?"],
[True, False, False, False, True, False, False, False, False, False],
2,
),
# The Token.is_punct check ensures that quotes are handled as well
(
['"', "Nice", "!", '"', "I", "am", "happy", "."],
[True, False, False, False, True, False, False, False],
2,
),
],
)
def test_sentencizer_complex(en_vocab, words, sent_starts, n_sents):
doc = Doc(en_vocab, words=words)
sentencizer = Sentencizer()
doc = sentencizer(doc)
assert doc.is_sentenced
assert [t.is_sent_start for t in doc] == sent_starts
assert len(list(doc.sents)) == n_sents
@pytest.mark.parametrize(
"punct_chars,words,sent_starts,n_sents",
[
(
["~", "?"],
["Hello", "world", "~", "A", ".", "B", "."],
[True, False, False, True, False, False, False],
2,
),
# Even thought it's not common, the punct_chars should be able to
# handle any tokens
(
[".", "ö"],
["Hello", ".", "Test", "ö", "Ok", "."],
[True, False, True, False, True, False],
3,
),
],
)
def test_sentencizer_custom_punct(en_vocab, punct_chars, words, sent_starts, n_sents):
doc = Doc(en_vocab, words=words)
sentencizer = Sentencizer(punct_chars=punct_chars)
doc = sentencizer(doc)
assert doc.is_sentenced
assert [t.is_sent_start for t in doc] == sent_starts
assert len(list(doc.sents)) == n_sents
def test_sentencizer_serialize_bytes(en_vocab):
punct_chars = [".", "~", "+"]
sentencizer = Sentencizer(punct_chars=punct_chars)
assert sentencizer.punct_chars == punct_chars
bytes_data = sentencizer.to_bytes()
new_sentencizer = Sentencizer().from_bytes(bytes_data)
assert new_sentencizer.punct_chars == punct_chars

@@ -0,0 +1,22 @@
# coding: utf8
from __future__ import unicode_literals
import pytest
from spacy.lang.en import English
from spacy.tokens import Doc
def test_issue3468():
"""Test that sentence boundaries are set correctly so Doc.is_sentenced can
be restored after serialization."""
nlp = English()
nlp.add_pipe(nlp.create_pipe("sentencizer"))
doc = nlp("Hello world")
assert doc[0].is_sent_start
assert doc.is_sentenced
assert len(list(doc.sents)) == 1
doc_bytes = doc.to_bytes()
new_doc = Doc(nlp.vocab).from_bytes(doc_bytes)
assert new_doc[0].is_sent_start
assert new_doc.is_sentenced
assert len(list(new_doc.sents)) == 1

@@ -230,7 +230,7 @@ cdef class Doc:
        defined as having at least one of the following:

        a) An entry "sents" in doc.user_hooks";
-       b) sent.is_parsed is set to True;
+       b) Doc.is_parsed is set to True;
        c) At least one token other than the first where sent_start is not None.
        """
        if "sents" in self.user_hooks:

@@ -441,6 +441,7 @@
     property sent_start:
         def __get__(self):
+            """Deprecated: use Token.is_sent_start instead."""
            # Raising a deprecation warning here causes errors for autocomplete
            # Handle broken backwards compatibility case: doc[0].sent_start
            # was False.

@@ -1,78 +0,0 @@
---
title: SentenceSegmenter
tag: class
source: spacy/pipeline/hooks.py
---
A simple spaCy hook, to allow custom sentence boundary detection logic that
doesn't require the dependency parse. By default, sentence segmentation is
performed by the [`DependencyParser`](/api/dependencyparser), so the
`SentenceSegmenter` lets you implement a simpler, rule-based strategy that
doesn't require a statistical model to be loaded. The component is also
available via the string name `"sentencizer"`. After initialization, it is
typically added to the processing pipeline using
[`nlp.add_pipe`](/api/language#add_pipe).
## SentenceSegmenter.\_\_init\_\_ {#init tag="method"}
Initialize the sentence segmenter. To change the sentence boundary detection
strategy, pass a generator function `strategy` on initialization, or assign a
new strategy to the `.strategy` attribute. Sentence detection strategies should
be generators that take `Doc` objects and yield `Span` objects for each
sentence.
> #### Example
>
> ```python
> # Construction via create_pipe
> sentencizer = nlp.create_pipe("sentencizer")
>
> # Construction from class
> from spacy.pipeline import SentenceSegmenter
> sentencizer = SentenceSegmenter(nlp.vocab)
> ```
| Name | Type | Description |
| ----------- | ------------------- | ----------------------------------------------------------- |
| `vocab` | `Vocab` | The shared vocabulary. |
| `strategy` | unicode / callable | The segmentation strategy to use. Defaults to `"on_punct"`. |
| **RETURNS** | `SentenceSegmenter` | The newly constructed object. |
## SentenceSegmenter.\_\_call\_\_ {#call tag="method"}
Apply the sentence segmenter on a `Doc`. Typically, this happens automatically
after the component has been added to the pipeline using
[`nlp.add_pipe`](/api/language#add_pipe).
> #### Example
>
> ```python
> from spacy.lang.en import English
>
> nlp = English()
> sentencizer = nlp.create_pipe("sentencizer")
> nlp.add_pipe(sentencizer)
> doc = nlp(u"This is a sentence. This is another sentence.")
> assert list(doc.sents) == 2
> ```
| Name | Type | Description |
| ----------- | ----- | ------------------------------------------------------------ |
| `doc` | `Doc` | The `Doc` object to process, e.g. the `Doc` in the pipeline. |
| **RETURNS** | `Doc` | The modified `Doc` with added sentence boundaries. |
## SentenceSegmenter.split_on_punct {#split_on_punct tag="staticmethod"}
Split the `Doc` on punctuation characters `.`, `!` and `?`. This is the default
strategy used by the `SentenceSegmenter.`
| Name | Type | Description |
| ---------- | ------ | ------------------------------ |
| `doc` | `Doc` | The `Doc` object to process. |
| **YIELDS** | `Span` | The sentences in the document. |
## Attributes {#attributes}
| Name | Type | Description |
| ---------- | -------- | ------------------------------------------------------------------- |
| `strategy` | callable | The segmentation strategy. Can be overwritten after initialization. |

@@ -0,0 +1,136 @@
---
title: Sentencizer
tag: class
source: spacy/pipeline/pipes.pyx
---
A simple pipeline component, to allow custom sentence boundary detection logic
that doesn't require the dependency parse. By default, sentence segmentation is
performed by the [`DependencyParser`](/api/dependencyparser), so the
`Sentencizer` lets you implement a simpler, rule-based strategy that doesn't
require a statistical model to be loaded. The component is also available via
the string name `"sentencizer"`. After initialization, it is typically added to
the processing pipeline using [`nlp.add_pipe`](/api/language#add_pipe).
<Infobox title="Important note" variant="warning">
Compared to the previous `SentenceSegmenter` class, the `Sentencizer` component
doesn't add a hook to `doc.user_hooks["sents"]`. Instead, it iterates over the
tokens in the `Doc` and sets the `Token.is_sent_start` property. The
`SentenceSegmenter` is still available if you import it directly:
```python
from spacy.pipeline import SentenceSegmenter
```
</Infobox>
## Sentencizer.\_\_init\_\_ {#init tag="method"}
Initialize the sentencizer.
> #### Example
>
> ```python
> # Construction via create_pipe
> sentencizer = nlp.create_pipe("sentencizer")
>
> # Construction from class
> from spacy.pipeline import Sentencizer
> sentencizer = Sentencizer()
> ```
| Name | Type | Description |
| ------------- | ------------- | ------------------------------------------------------------------------------------------------------ |
| `punct_chars` | list | Optional custom list of punctuation characters that mark sentence ends. Defaults to `[".", "!", "?"].` |
| **RETURNS** | `Sentencizer` | The newly constructed object. |
## Sentencizer.\_\_call\_\_ {#call tag="method"}
Apply the sentencizer on a `Doc`. Typically, this happens automatically after
the component has been added to the pipeline using
[`nlp.add_pipe`](/api/language#add_pipe).
> #### Example
>
> ```python
> from spacy.lang.en import English
>
> nlp = English()
> sentencizer = nlp.create_pipe("sentencizer")
> nlp.add_pipe(sentencizer)
> doc = nlp(u"This is a sentence. This is another sentence.")
> assert list(doc.sents) == 2
> ```
| Name | Type | Description |
| ----------- | ----- | ------------------------------------------------------------ |
| `doc` | `Doc` | The `Doc` object to process, e.g. the `Doc` in the pipeline. |
| **RETURNS** | `Doc` | The modified `Doc` with added sentence boundaries. |
## Sentencizer.to_disk {#to_disk tag="method"}
Save the sentencizer settings (punctuation characters) to a directory. Will create
a file `sentencizer.json`. This also happens automatically when you save an
`nlp` object with a sentencizer added to its pipeline.
> #### Example
>
> ```python
> sentencizer = Sentencizer(punct_chars=[".", "?", "!", "。"])
> sentencizer.to_disk("/path/to/sentencizer.jsonl")
> ```
| Name | Type | Description |
| ------ | ---------------- | ---------------------------------------------------------------------------------------------------------------- |
| `path` | unicode / `Path` | A path to a file, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
## Sentencizer.from_disk {#from_disk tag="method"}
Load the sentencizer settings from a file. Expects a JSON file. This also
happens automatically when you load an `nlp` object or model with a sentencizer
added to its pipeline.
> #### Example
>
> ```python
> sentencizer = Sentencizer()
> sentencizer.from_disk("/path/to/sentencizer.json")
> ```
| Name | Type | Description |
| ----------- | ---------------- | -------------------------------------------------------------------------- |
| `path` | unicode / `Path` | A path to a JSON file. Paths may be either strings or `Path`-like objects. |
| **RETURNS** | `Sentencizer` | The modified `Sentencizer` object. |
## Sentencizer.to_bytes {#to_bytes tag="method"}
Serialize the sentencizer settings to a bytestring.
> #### Example
>
> ```python
> sentencizer = Sentencizer(punct_chars=[".", "?", "!", "。"])
> sentencizer_bytes = sentencizer.to_bytes()
> ```
| Name | Type | Description |
| ----------- | ----- | -------------------- |
| **RETURNS** | bytes | The serialized data. |
## Sentencizer.from_bytes {#from_bytes tag="method"}
Load the pipe from a bytestring. Modifies the object in place and returns it.
> #### Example
>
> ```python
> sentencizer_bytes = sentencizer.to_bytes()
> sentencizer = Sentencizer()
> sentencizer.from_bytes(sentencizer_bytes)
> ```
| Name | Type | Description |
| ------------ | ------------- | ---------------------------------- |
| `bytes_data` | bytes | The bytestring to load. |
| **RETURNS** | `Sentencizer` | The modified `Sentencizer` object. |

@@ -56,10 +56,11 @@ of problems. To handle a wider variety of problems, the `TextCategorizer` object
 allows configuration of its model architecture, using the `architecture` keyword
 argument.

 | Name | Description |
 | --- | --- |
-| `"ensemble"` | **Default:** Stacked ensemble of a unigram bag-of-words model and a neural network model. The neural network uses a CNN with mean pooling and attention. |
-| `"simple_cnn"` | A neural network model where token vectors are calculated using a CNN. The vectors are mean pooled and used as features in a feed-forward network. |
+| `"ensemble"` | **Default:** Stacked ensemble of a bag-of-words model and a neural network model. The neural network uses a CNN with mean pooling and attention. The "ngram_size" and "attr" arguments can be used to configure the feature extraction for the bag-of-words model. |
+| `"simple_cnn"` | A neural network model where token vectors are calculated using a CNN. The vectors are mean pooled and used as features in a feed-forward network. This architecture is usually less accurate than the ensemble, but runs faster. |
+| `"bow"` | An ngram "bag-of-words" model. This architecture should run much faster than the others, but may not be as accurate, especially if texts are short. The features extracted can be controlled using the keyword arguments `ngram_size` and `attr`. For instance, `ngram_size=3` and `attr="lower"` would give lower-cased unigram, trigram and bigram features. 2, 3 or 4 are usually good choices of ngram size. |

 ## TextCategorizer.\_\_call\_\_ {#call tag="method"}
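To make the new `"bow"` row concrete, here is an illustrative snippet (not taken from the docs) showing how that architecture would be selected when creating the pipe; the labels are invented for the example:

```python
# Illustrative sketch; assumes spaCy v2.1.3 as versioned in this commit.
import spacy

nlp = spacy.blank("en")
textcat = nlp.create_pipe(
    "textcat",
    config={"architecture": "bow", "ngram_size": 2, "exclusive_classes": True},
)
textcat.add_label("POSITIVE")
textcat.add_label("NEGATIVE")
nlp.add_pipe(textcat, last=True)
```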

Binary file not shown (added image, 1.6 MiB).

@@ -25,21 +25,21 @@ an **annotated document**. It also orchestrates training and serialization.

 ### Processing pipeline {#architecture-pipeline}

 | Name | Description |
 | --- | --- |
 | [`Language`](/api/language) | A text-processing pipeline. Usually you'll load this once per process as `nlp` and pass the instance around your application. |
 | [`Tokenizer`](/api/tokenizer) | Segment text, and create `Doc` objects with the discovered segment boundaries. |
 | [`Lemmatizer`](/api/lemmatizer) | Determine the base forms of words. |
 | `Morphology` | Assign linguistic features like lemmas, noun case, verb tense etc. based on the word and its part-of-speech tag. |
 | [`Tagger`](/api/tagger) | Annotate part-of-speech tags on `Doc` objects. |
 | [`DependencyParser`](/api/dependencyparser) | Annotate syntactic dependencies on `Doc` objects. |
 | [`EntityRecognizer`](/api/entityrecognizer) | Annotate named entities, e.g. persons or products, on `Doc` objects. |
 | [`TextCategorizer`](/api/textcategorizer) | Assign categories or labels to `Doc` objects. |
 | [`Matcher`](/api/matcher) | Match sequences of tokens, based on pattern rules, similar to regular expressions. |
 | [`PhraseMatcher`](/api/phrasematcher) | Match sequences of tokens based on phrases. |
 | [`EntityRuler`](/api/entityruler) | Add entity spans to the `Doc` using token-based rules or exact phrase matches. |
-| [`SentenceSegmenter`](/api/sentencesegmenter) | Implement custom sentence boundary detection logic that doesn't require the dependency parse. |
+| [`Sentencizer`](/api/sentencizer) | Implement custom sentence boundary detection logic that doesn't require the dependency parse. |
 | [Other functions](/api/pipeline-functions) | Automatically apply something to the `Doc`, e.g. to merge spans of tokens. |

 ### Other classes {#architecture-other}

@@ -8,7 +8,7 @@ menu:
   - ['Changelog', 'changelog']
 ---

-spaCy is compatible with **64-bit CPython 2.7+/3.5+** and runs on
+spaCy is compatible with **64-bit CPython 2.7 / 3.5+** and runs on
 **Unix/Linux**, **macOS/OS X** and **Windows**. The latest spaCy releases are
 available over [pip](https://pypi.python.org/pypi/spacy) and
 [conda](https://anaconda.org/conda-forge/spacy).

@@ -1149,9 +1149,14 @@ but it also means you'll need a **statistical model** and accurate predictions.
 If your texts are closer to general-purpose news or web text, this should work
 well out-of-the-box. For social media or conversational text that doesn't follow
 the same rules, your application may benefit from a custom rule-based
-implementation. You can either plug a rule-based component into your
-[processing pipeline](/usage/processing-pipelines) or use the
-`SentenceSegmenter` component with a custom strategy.
+implementation. You can either use the built-in
+[`Sentencizer`](/api/sentencizer) or plug an entirely custom rule-based function
+into your [processing pipeline](/usage/processing-pipelines).
+
+spaCy's dependency parser respects already set boundaries, so you can preprocess
+your `Doc` using custom rules _before_ it's parsed. Depending on your text, this
+may also improve accuracy, since the parser is constrained to predict parses
+consistent with the sentence boundaries.

 ### Default: Using the dependency parse {#sbd-parser model="parser"}
@@ -1168,13 +1173,35 @@ for sent in doc.sents:
     print(sent.text)
 ```

-### Setting boundaries manually {#sbd-manual}
+### Rule-based pipeline component {#sbd-component}

-spaCy's dependency parser respects already set boundaries, so you can preprocess
-your `Doc` using custom rules _before_ it's parsed. This can be done by adding a
-[custom pipeline component](/usage/processing-pipelines). Depending on your
-text, this may also improve accuracy, since the parser is constrained to predict
-parses consistent with the sentence boundaries.
+The [`Sentencizer`](/api/sentencizer) component is a
+[pipeline component](/usage/processing-pipelines) that splits sentences on
+punctuation like `.`, `!` or `?`. You can plug it into your pipeline if you only
+need sentence boundaries without the dependency parse.
+
+```python
+### {executable="true"}
+import spacy
+from spacy.lang.en import English
+
+nlp = English()  # just the language with no model
+sentencizer = nlp.create_pipe("sentencizer")
+nlp.add_pipe(sentencizer)
+doc = nlp(u"This is a sentence. This is another sentence.")
+for sent in doc.sents:
+    print(sent.text)
+```
+
+### Custom rule-based strategy {id="sbd-custom"}
+
+If you want to implement your own strategy that differs from the default
+rule-based approach of splitting on sentences, you can also create a
+[custom pipeline component](/usage/processing-pipelines#custom-components) that
+takes a `Doc` object and sets the `Token.is_sent_start` attribute on each
+individual token. If set to `False`, the token is explicitly marked as _not_ the
+start of a sentence. If set to `None` (default), it's treated as a missing value
+and can still be overwritten by the parser.

 <Infobox title="Important note" variant="warning">
@@ -1187,9 +1214,11 @@ adding it to the pipeline using [`nlp.add_pipe`](/api/language#add_pipe).

 Here's an example of a component that implements a pre-processing rule for
 splitting on `'...'` tokens. The component is added before the parser, which is
-then used to further segment the text. This approach can be useful if you want
-to implement **additional** rules specific to your data, while still being able
-to take advantage of dependency-based sentence segmentation.
+then used to further segment the text. That's possible, because `is_sent_start`
+is only set to `True` for some of the tokens; all others still specify `None`
+for unset sentence boundaries. This approach can be useful if you want to
+implement **additional** rules specific to your data, while still being able to
+take advantage of dependency-based sentence segmentation.

 ```python
 ### {executable="true"}
@@ -1212,62 +1241,6 @@ doc = nlp(text)
 print("After:", [sent.text for sent in doc.sents])
 ```

-### Rule-based pipeline component {#sbd-component}
-
-The `sentencizer` component is a
-[pipeline component](/usage/processing-pipelines) that splits sentences on
-punctuation like `.`, `!` or `?`. You can plug it into your pipeline if you only
-need sentence boundaries without the dependency parse. Note that `Doc.sents`
-will **raise an error** if no sentence boundaries are set.
-
-```python
-### {executable="true"}
-import spacy
-from spacy.lang.en import English
-
-nlp = English()  # just the language with no model
-sentencizer = nlp.create_pipe("sentencizer")
-nlp.add_pipe(sentencizer)
-doc = nlp(u"This is a sentence. This is another sentence.")
-for sent in doc.sents:
-    print(sent.text)
-```
-
-### Custom rule-based strategy {#sbd-custom}
-
-If you want to implement your own strategy that differs from the default
-rule-based approach of splitting on sentences, you can also instantiate the
-`SentenceSegmenter` directly and pass in your own strategy. The strategy should
-be a function that takes a `Doc` object and yields a `Span` for each sentence.
-Here's an example of a custom segmentation strategy for splitting on newlines
-only:
-
-```python
-### {executable="true"}
-from spacy.lang.en import English
-from spacy.pipeline import SentenceSegmenter
-
-def split_on_newlines(doc):
-    start = 0
-    seen_newline = False
-    for word in doc:
-        if seen_newline and not word.is_space:
-            yield doc[start:word.i]
-            start = word.i
-            seen_newline = False
-        elif word.text == '\\n':
-            seen_newline = True
-    if start < len(doc):
-        yield doc[start:len(doc)]
-
-nlp = English()  # Just the language with no model
-sentencizer = SentenceSegmenter(nlp.vocab, strategy=split_on_newlines)
-nlp.add_pipe(sentencizer)
-doc = nlp(u"This is a sentence\\n\\nThis is another sentence\\nAnd more")
-for sent in doc.sents:
-    print([token.text for token in sent])
-```

 ## Rule-based matching {#rule-based-matching hidden="true"}

 <div id="rule-based-matching">

@@ -138,7 +138,7 @@ require them in the pipeline settings in your model's `meta.json`.
 | `ner` | [`EntityRecognizer`](/api/entityrecognizer) | Assign named entities. |
 | `textcat` | [`TextCategorizer`](/api/textcategorizer) | Assign text categories. |
 | `entity_ruler` | [`EntityRuler`](/api/entityruler) | Assign named entities based on pattern rules. |
-| `sentencizer` | [`SentenceSegmenter`](/api/sentencesegmenter) | Add rule-based sentence segmentation without the dependency parse. |
+| `sentencizer` | [`Sentencizer`](/api/sentencizer) | Add rule-based sentence segmentation without the dependency parse. |
 | `merge_noun_chunks` | [`merge_noun_chunks`](/api/pipeline-functions#merge_noun_chunks) | Merge all noun chunks into a single token. Should be added after the tagger and parser. |
 | `merge_entities` | [`merge_entities`](/api/pipeline-functions#merge_entities) | Merge all entities into a single token. Should be added after the entity recognizer. |
 | `merge_subtokens` | [`merge_subtokens`](/api/pipeline-functions#merge_subtokens) | Merge subtokens predicted by the parser into single tokens. Should be added after the parser. |

@@ -404,7 +404,7 @@ class BadHTMLMerger(object):
         for match_id, start, end in matches:
             spans.append(doc[start:end])
         with doc.retokenize() as retokenizer:
-            for span in spans:
+            for span in hashtags:
                 retokenizer.merge(span)
                 for token in span:
                     token._.bad_html = True  # Mark token as bad HTML

@@ -95,6 +95,21 @@ systems, or to pre-process text for **deep learning**.
 publishing spaCy and other software is called
 [Explosion AI](https://explosion.ai).

+<Infobox title="Download the spaCy Cheat Sheet!">
+
+[![spaCy Cheatsheet](../images/cheatsheet.jpg)](http://datacamp-community-prod.s3.amazonaws.com/29aa28bf-570a-4965-8f54-d6a541ae4e06)
+
+For the launch of our
+["Advanced NLP with spaCy"](https://www.datacamp.com/courses/advanced-nlp-with-spacy)
+course on DataCamp we created the first official spaCy cheat sheet! A handy
+two-page reference to the most important concepts and features, from loading
+models and accessing linguistic annotations, to custom pipeline components and
+rule-based matching.
+
+<p><Button to="http://datacamp-community-prod.s3.amazonaws.com/29aa28bf-570a-4965-8f54-d6a541ae4e06" variant="primary">Download</Button></p>
+
+</Infobox>
+
 ## Features {#features}

 In the documentation, you'll come across mentions of spaCy's features and

@@ -13,6 +13,8 @@ design changes introduced in [v2.0](/usage/v2). As well as smaller models,
 faster runtime, and many bug fixes, v2.1 also introduces experimental support
 for some exciting new NLP innovations. For the full changelog, see the
 [release notes on GitHub](https://github.com/explosion/spaCy/releases/tag/v2.1.0).
+For more details and a behind-the-scenes look at the new release,
+[see our blog post](https://explosion.ai/blog/spacy-v2-1).

 ### BERT/ULMFit/Elmo-style pre-training {#pretraining tag="experimental"}
@@ -195,7 +197,7 @@ the existing pages and added some new content:
 - **Universe:** [Videos](/universe/category/videos) and
   [Podcasts](/universe/category/podcasts)
 - **API:** [`EntityRuler`](/api/entityruler)
-- **API:** [`SentenceSegmenter`](/api/sentencesegmenter)
+- **API:** [`Sentencizer`](/api/sentencizer)
 - **API:** [Pipeline functions](/api/pipeline-functions)

 ## Backwards incompatibilities {#incompat}
@@ -212,9 +214,8 @@ if all of your models are up to date, you can run the
 - Due to difficulties linking our new
   [`blis`](https://github.com/explosion/cython-blis) for faster
-  platform-independent matrix multiplication, this nightly release currently
-  **doesn't work on Python 2.7 on Windows**. We expect this to be corrected in
-  the future.
+  platform-independent matrix multiplication, this release currently **doesn't
+  work on Python 2.7 on Windows**. We expect this to be corrected in the future.

 - While the [`Matcher`](/api/matcher) API is fully backwards compatible, its
   algorithm has changed to fix a number of bugs and performance issues. This
@@ -250,9 +251,14 @@ if all of your models are up to date, you can run the
   + data = nlp.tokenizer.to_bytes(exclude=["vocab"])
   ```

+- The .pos value for several common English words has changed, due to
+  corrections to long-standing mistakes in the English tag map (see
+  [issue #593](https://github.com/explosion/spaCy/issues/593) and
+  [issue #3311](https://github.com/explosion/spaCy/issues/3311) for details).
+
 - For better compatibility with the Universal Dependencies data, the lemmatizer
   now preserves capitalization, e.g. for proper nouns. See
-  [this issue](https://github.com/explosion/spaCy/issues/3256) for details.
+  [issue #3256](https://github.com/explosion/spaCy/issues/3256) for details.

 - The built-in rule-based sentence boundary detector is now only called
   `"sentencizer"` – the name `"sbd"` is deprecated.

@@ -117,6 +117,7 @@
     { "code": "sk", "name": "Slovak" },
     { "code": "sl", "name": "Slovenian" },
     { "code": "sq", "name": "Albanian" },
+    { "code": "et", "name": "Estonian" },
     {
         "code": "th",
         "name": "Thai",

@@ -79,7 +79,7 @@
         { "text": "Matcher", "url": "/api/matcher" },
         { "text": "PhraseMatcher", "url": "/api/phrasematcher" },
         { "text": "EntityRuler", "url": "/api/entityruler" },
-        { "text": "SentenceSegmenter", "url": "/api/sentencesegmenter" },
+        { "text": "Sentencizer", "url": "/api/sentencizer" },
         { "text": "Other Functions", "url": "/api/pipeline-functions" }
       ]
     },

@@ -28,8 +28,8 @@
     },
     "spacyVersion": "2.1",
     "binderUrl": "ines/spacy-io-binder",
-    "binderBranch": "nightly",
-    "binderVersion": "2.1.0a9",
+    "binderBranch": "live",
+    "binderVersion": "2.1.3",
     "sections": [
       { "id": "usage", "title": "Usage Documentation", "theme": "blue" },
       { "id": "models", "title": "Models Documentation", "theme": "blue" },

Binary file not shown (modified image, 30 KiB before, 24 KiB after).

@@ -31,24 +31,21 @@ import spacy
 nlp = spacy.load("en_core_web_sm")

 # Process whole documents
-text = (u"When Sebastian Thrun started working on self-driving cars at "
-        u"Google in 2007, few people outside of the company took him "
-        u"seriously. “I can tell you very senior CEOs of major American "
-        u"car companies would shake my hand and turn away because I wasn’t "
-        u"worth talking to,” said Thrun, now the co-founder and CEO of "
-        u"online higher education startup Udacity, in an interview with "
-        u"Recode earlier this week.")
+text = ("When Sebastian Thrun started working on self-driving cars at "
+        "Google in 2007, few people outside of the company took him "
+        "seriously. “I can tell you very senior CEOs of major American "
+        "car companies would shake my hand and turn away because I wasn’t "
+        "worth talking to,” said Thrun, in an interview with Recode earlier "
+        "this week.")
 doc = nlp(text)

+# Analyze syntax
+print("Noun phrases:", [chunk.text for chunk in doc.noun_chunks])
+print("Verbs:", [token.lemma_ for token in doc if token.pos_ == "VERB"])
+
 # Find named entities, phrases and concepts
 for entity in doc.ents:
     print(entity.text, entity.label_)
-
-# Determine semantic similarities
-doc1 = nlp(u"my fries were super gross")
-doc2 = nlp(u"such disgusting fries")
-similarity = doc1.similarity(doc2)
-print(doc1.text, doc2.text, similarity)
 `
@@ -218,7 +215,7 @@ const Landing = ({ data }) => {
                     <H2>Benchmarks</H2>
                     <p>
                         In 2015, independent researchers from Emory University and Yahoo! Labs
-                        showed that spaCy offered the
+                        showed that spaCy offered the{' '}
                         <strong>fastest syntactic parser in the world</strong> and that its accuracy
                         was <strong>within 1% of the best</strong> available (
                         <Link to="https://aclweb.org/anthology/P/P15/P15-1038.pdf">