Mirror of https://github.com/explosion/spaCy.git (synced 2025-01-23 15:54:13 +03:00)
Fix pickle for ngram suggester (#12486)
Commit 69e20ce03d (parent 140d53649d)
spacy/pipeline/spancat.py

@@ -1,5 +1,6 @@
 from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast, Union
 from dataclasses import dataclass
+from functools import partial
 from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
 from thinc.api import Optimizer
 from thinc.types import Ragged, Ints2d, Floats2d
@@ -82,39 +83,42 @@ class Suggester(Protocol):
     ...
 
 
+def ngram_suggester(
+    docs: Iterable[Doc], sizes: List[int], *, ops: Optional[Ops] = None
+) -> Ragged:
+    if ops is None:
+        ops = get_current_ops()
+    spans = []
+    lengths = []
+    for doc in docs:
+        starts = ops.xp.arange(len(doc), dtype="i")
+        starts = starts.reshape((-1, 1))
+        length = 0
+        for size in sizes:
+            if size <= len(doc):
+                starts_size = starts[: len(doc) - (size - 1)]
+                spans.append(ops.xp.hstack((starts_size, starts_size + size)))
+                length += spans[-1].shape[0]
+            if spans:
+                assert spans[-1].ndim == 2, spans[-1].shape
+        lengths.append(length)
+    lengths_array = ops.asarray1i(lengths)
+    if len(spans) > 0:
+        output = Ragged(ops.xp.vstack(spans), lengths_array)
+    else:
+        output = Ragged(ops.xp.zeros((0, 0), dtype="i"), lengths_array)
+
+    assert output.dataXd.ndim == 2
+    return output
+
+
 @registry.misc("spacy.ngram_suggester.v1")
 def build_ngram_suggester(sizes: List[int]) -> Suggester:
     """Suggest all spans of the given lengths. Spans are returned as a ragged
     array of integers. The array has two columns, indicating the start and end
     position."""
 
-    def ngram_suggester(docs: Iterable[Doc], *, ops: Optional[Ops] = None) -> Ragged:
-        if ops is None:
-            ops = get_current_ops()
-        spans = []
-        lengths = []
-        for doc in docs:
-            starts = ops.xp.arange(len(doc), dtype="i")
-            starts = starts.reshape((-1, 1))
-            length = 0
-            for size in sizes:
-                if size <= len(doc):
-                    starts_size = starts[: len(doc) - (size - 1)]
-                    spans.append(ops.xp.hstack((starts_size, starts_size + size)))
-                    length += spans[-1].shape[0]
-                if spans:
-                    assert spans[-1].ndim == 2, spans[-1].shape
-            lengths.append(length)
-        lengths_array = ops.asarray1i(lengths)
-        if len(spans) > 0:
-            output = Ragged(ops.xp.vstack(spans), lengths_array)
-        else:
-            output = Ragged(ops.xp.zeros((0, 0), dtype="i"), lengths_array)
-
-        assert output.dataXd.ndim == 2
-        return output
-
-    return ngram_suggester
+    return partial(ngram_suggester, sizes=sizes)
 
 
 @registry.misc("spacy.ngram_range_suggester.v1")
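Why this fixes pickling: pickle serializes a plain Python function by reference to its module-level name, so the closure that build_ngram_suggester used to return could not be pickled, whereas a functools.partial wrapping the now module-level ngram_suggester can. The standalone sketch below is not part of the commit; make_closure_suggester, make_partial_suggester, and suggester are made-up names used only to contrast the two patterns.

    # Hypothetical, simplified illustration (not spaCy code): pickle fails on a
    # nested function but succeeds on a partial over a module-level function.
    import pickle
    from functools import partial


    def make_closure_suggester(sizes):
        # Old pattern: the returned nested function is a local object.
        def suggester(docs):
            return [(start, start + size) for doc in docs
                    for size in sizes for start in range(len(doc) - size + 1)]

        return suggester


    def suggester(docs, sizes):
        # New pattern: module-level function, configured via partial().
        return [(start, start + size) for doc in docs
                for size in sizes for start in range(len(doc) - size + 1)]


    def make_partial_suggester(sizes):
        return partial(suggester, sizes=sizes)


    if __name__ == "__main__":
        try:
            pickle.dumps(make_closure_suggester([1, 2]))
        except (pickle.PicklingError, AttributeError) as err:
            print("closure cannot be pickled:", err)
        restored = pickle.loads(pickle.dumps(make_partial_suggester([1, 2])))
        print(restored(["a tiny doc".split()]))  # the partial round-trips fine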
spacy/tests/pipeline/test_spancat.py

@@ -1,7 +1,7 @@
 import pytest
 import numpy
 from numpy.testing import assert_array_equal, assert_almost_equal
-from thinc.api import get_current_ops, Ragged
+from thinc.api import get_current_ops, NumpyOps, Ragged
 
 from spacy import util
 from spacy.lang.en import English
@@ -577,3 +577,21 @@ def test_set_candidates(name):
     assert len(docs[0].spans["candidates"]) == 9
     assert docs[0].spans["candidates"][0].text == "Just"
     assert docs[0].spans["candidates"][4].text == "Just a"
+
+
+@pytest.mark.parametrize("name", SPANCAT_COMPONENTS)
+@pytest.mark.parametrize("n_process", [1, 2])
+def test_spancat_multiprocessing(name, n_process):
+    if isinstance(get_current_ops, NumpyOps) or n_process < 2:
+        nlp = Language()
+        spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
+        train_examples = make_examples(nlp)
+        nlp.initialize(get_examples=lambda: train_examples)
+        texts = [
+            "Just a sentence.",
+            "I like London and Berlin",
+            "I like Berlin",
+            "I eat ham.",
+        ]
+        docs = list(nlp.pipe(texts, n_process=n_process))
+        assert len(docs) == len(texts)
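The new test drives the fixed code path through nlp.pipe(..., n_process=2), which requires the suggester to be picklable when work is farmed out to worker processes. A quicker manual check along the same lines is sketched below; it is not part of the commit and assumes build_ngram_suggester is importable from spacy.pipeline.spancat.

    # Hedged sketch: the registry factory now returns a functools.partial, which
    # survives a pickle round trip together with its configuration.
    import pickle

    from spacy.pipeline.spancat import build_ngram_suggester

    suggester = build_ngram_suggester(sizes=[1, 2, 3])
    restored = pickle.loads(pickle.dumps(suggester))
    assert restored.keywords["sizes"] == [1, 2, 3]  # partial keeps its sizes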