Fix pickle for ngram suggester (#12486)

This commit is contained in:
Adriane Boyd 2023-03-31 13:43:51 +02:00 committed by GitHub
parent 140d53649d
commit 69e20ce03d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 50 additions and 28 deletions

View File

@ -1,5 +1,6 @@
from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast, Union from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast, Union
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial
from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
from thinc.api import Optimizer from thinc.api import Optimizer
from thinc.types import Ragged, Ints2d, Floats2d from thinc.types import Ragged, Ints2d, Floats2d
@ -82,39 +83,42 @@ class Suggester(Protocol):
... ...
def ngram_suggester(
docs: Iterable[Doc], sizes: List[int], *, ops: Optional[Ops] = None
) -> Ragged:
if ops is None:
ops = get_current_ops()
spans = []
lengths = []
for doc in docs:
starts = ops.xp.arange(len(doc), dtype="i")
starts = starts.reshape((-1, 1))
length = 0
for size in sizes:
if size <= len(doc):
starts_size = starts[: len(doc) - (size - 1)]
spans.append(ops.xp.hstack((starts_size, starts_size + size)))
length += spans[-1].shape[0]
if spans:
assert spans[-1].ndim == 2, spans[-1].shape
lengths.append(length)
lengths_array = ops.asarray1i(lengths)
if len(spans) > 0:
output = Ragged(ops.xp.vstack(spans), lengths_array)
else:
output = Ragged(ops.xp.zeros((0, 0), dtype="i"), lengths_array)
assert output.dataXd.ndim == 2
return output
@registry.misc("spacy.ngram_suggester.v1") @registry.misc("spacy.ngram_suggester.v1")
def build_ngram_suggester(sizes: List[int]) -> Suggester: def build_ngram_suggester(sizes: List[int]) -> Suggester:
"""Suggest all spans of the given lengths. Spans are returned as a ragged """Suggest all spans of the given lengths. Spans are returned as a ragged
array of integers. The array has two columns, indicating the start and end array of integers. The array has two columns, indicating the start and end
position.""" position."""
def ngram_suggester(docs: Iterable[Doc], *, ops: Optional[Ops] = None) -> Ragged: return partial(ngram_suggester, sizes=sizes)
if ops is None:
ops = get_current_ops()
spans = []
lengths = []
for doc in docs:
starts = ops.xp.arange(len(doc), dtype="i")
starts = starts.reshape((-1, 1))
length = 0
for size in sizes:
if size <= len(doc):
starts_size = starts[: len(doc) - (size - 1)]
spans.append(ops.xp.hstack((starts_size, starts_size + size)))
length += spans[-1].shape[0]
if spans:
assert spans[-1].ndim == 2, spans[-1].shape
lengths.append(length)
lengths_array = ops.asarray1i(lengths)
if len(spans) > 0:
output = Ragged(ops.xp.vstack(spans), lengths_array)
else:
output = Ragged(ops.xp.zeros((0, 0), dtype="i"), lengths_array)
assert output.dataXd.ndim == 2
return output
return ngram_suggester
@registry.misc("spacy.ngram_range_suggester.v1") @registry.misc("spacy.ngram_range_suggester.v1")

View File

@ -1,7 +1,7 @@
import pytest import pytest
import numpy import numpy
from numpy.testing import assert_array_equal, assert_almost_equal from numpy.testing import assert_array_equal, assert_almost_equal
from thinc.api import get_current_ops, Ragged from thinc.api import get_current_ops, NumpyOps, Ragged
from spacy import util from spacy import util
from spacy.lang.en import English from spacy.lang.en import English
@ -577,3 +577,21 @@ def test_set_candidates(name):
assert len(docs[0].spans["candidates"]) == 9 assert len(docs[0].spans["candidates"]) == 9
assert docs[0].spans["candidates"][0].text == "Just" assert docs[0].spans["candidates"][0].text == "Just"
assert docs[0].spans["candidates"][4].text == "Just a" assert docs[0].spans["candidates"][4].text == "Just a"
@pytest.mark.parametrize("name", SPANCAT_COMPONENTS)
@pytest.mark.parametrize("n_process", [1, 2])
def test_spancat_multiprocessing(name, n_process):
if isinstance(get_current_ops, NumpyOps) or n_process < 2:
nlp = Language()
spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
train_examples = make_examples(nlp)
nlp.initialize(get_examples=lambda: train_examples)
texts = [
"Just a sentence.",
"I like London and Berlin",
"I like Berlin",
"I eat ham.",
]
docs = list(nlp.pipe(texts, n_process=n_process))
assert len(docs) == len(texts)