mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 02:36:32 +03:00
Let Langugae.use_params work with falsey inputs
The Language.use_params method was failing if you passed in None, which meant we had to use awkward conditionals for the parameter averaging. This solves the problem.
This commit is contained in:
parent
122cb02001
commit
ef0d0630a4
|
@ -1,5 +1,5 @@
|
||||||
from typing import Optional, Any, Dict, Callable, Iterable, Union, List, Pattern
|
from typing import Optional, Any, Dict, Callable, Iterable, Union, List, Pattern
|
||||||
from typing import Tuple, Iterator
|
from typing import Tuple, Iterator, Optional
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import random
|
import random
|
||||||
import itertools
|
import itertools
|
||||||
|
@ -1275,7 +1275,7 @@ class Language:
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def use_params(self, params: dict):
|
def use_params(self, params: Optional[dict]):
|
||||||
"""Replace weights of models in the pipeline with those provided in the
|
"""Replace weights of models in the pipeline with those provided in the
|
||||||
params dictionary. Can be used as a contextmanager, in which case,
|
params dictionary. Can be used as a contextmanager, in which case,
|
||||||
models go back to their original weights after the block.
|
models go back to their original weights after the block.
|
||||||
|
@ -1288,24 +1288,27 @@ class Language:
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/language#use_params
|
DOCS: https://spacy.io/api/language#use_params
|
||||||
"""
|
"""
|
||||||
contexts = [
|
if not params:
|
||||||
pipe.use_params(params)
|
yield
|
||||||
for name, pipe in self.pipeline
|
else:
|
||||||
if hasattr(pipe, "use_params") and hasattr(pipe, "model")
|
contexts = [
|
||||||
]
|
pipe.use_params(params)
|
||||||
# TODO: Having trouble with contextlib
|
for name, pipe in self.pipeline
|
||||||
# Workaround: these aren't actually context managers atm.
|
if hasattr(pipe, "use_params") and hasattr(pipe, "model")
|
||||||
for context in contexts:
|
]
|
||||||
try:
|
# TODO: Having trouble with contextlib
|
||||||
next(context)
|
# Workaround: these aren't actually context managers atm.
|
||||||
except StopIteration:
|
for context in contexts:
|
||||||
pass
|
try:
|
||||||
yield
|
next(context)
|
||||||
for context in contexts:
|
except StopIteration:
|
||||||
try:
|
pass
|
||||||
next(context)
|
yield
|
||||||
except StopIteration:
|
for context in contexts:
|
||||||
pass
|
try:
|
||||||
|
next(context)
|
||||||
|
except StopIteration:
|
||||||
|
pass
|
||||||
|
|
||||||
def pipe(
|
def pipe(
|
||||||
self,
|
self,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user