Fix pickle for ngram suggester (#12486)

This commit is contained in:
Adriane Boyd 2023-03-31 13:43:51 +02:00 committed by GitHub
parent 140d53649d
commit 69e20ce03d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 50 additions and 28 deletions

View File

@ -1,5 +1,6 @@
from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast, Union
from dataclasses import dataclass
from functools import partial
from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
from thinc.api import Optimizer
from thinc.types import Ragged, Ints2d, Floats2d
@ -82,13 +83,9 @@ class Suggester(Protocol):
...
@registry.misc("spacy.ngram_suggester.v1")
def build_ngram_suggester(sizes: List[int]) -> Suggester:
"""Suggest all spans of the given lengths. Spans are returned as a ragged
array of integers. The array has two columns, indicating the start and end
position."""
def ngram_suggester(docs: Iterable[Doc], *, ops: Optional[Ops] = None) -> Ragged:
def ngram_suggester(
docs: Iterable[Doc], sizes: List[int], *, ops: Optional[Ops] = None
) -> Ragged:
if ops is None:
ops = get_current_ops()
spans = []
@ -114,7 +111,14 @@ def build_ngram_suggester(sizes: List[int]) -> Suggester:
assert output.dataXd.ndim == 2
return output
return ngram_suggester
@registry.misc("spacy.ngram_suggester.v1")
def build_ngram_suggester(sizes: List[int]) -> Suggester:
"""Suggest all spans of the given lengths. Spans are returned as a ragged
array of integers. The array has two columns, indicating the start and end
position."""
return partial(ngram_suggester, sizes=sizes)
@registry.misc("spacy.ngram_range_suggester.v1")

View File

@ -1,7 +1,7 @@
import pytest
import numpy
from numpy.testing import assert_array_equal, assert_almost_equal
from thinc.api import get_current_ops, Ragged
from thinc.api import get_current_ops, NumpyOps, Ragged
from spacy import util
from spacy.lang.en import English
@ -577,3 +577,21 @@ def test_set_candidates(name):
assert len(docs[0].spans["candidates"]) == 9
assert docs[0].spans["candidates"][0].text == "Just"
assert docs[0].spans["candidates"][4].text == "Just a"
@pytest.mark.parametrize("name", SPANCAT_COMPONENTS)
@pytest.mark.parametrize("n_process", [1, 2])
def test_spancat_multiprocessing(name, n_process):
if isinstance(get_current_ops, NumpyOps) or n_process < 2:
nlp = Language()
spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
train_examples = make_examples(nlp)
nlp.initialize(get_examples=lambda: train_examples)
texts = [
"Just a sentence.",
"I like London and Berlin",
"I like Berlin",
"I eat ham.",
]
docs = list(nlp.pipe(texts, n_process=n_process))
assert len(docs) == len(texts)