Let Langugae.use_params work with falsey inputs

The Language.use_params method was failing if you passed in None, which
meant we had to use awkward conditionals for the parameter averaging.
This solves the problem.
This commit is contained in:
Matthew Honnibal 2020-09-03 12:51:04 +02:00
parent 122cb02001
commit ef0d0630a4

View File

@ -1,5 +1,5 @@
from typing import Optional, Any, Dict, Callable, Iterable, Union, List, Pattern from typing import Optional, Any, Dict, Callable, Iterable, Union, List, Pattern
from typing import Tuple, Iterator from typing import Tuple, Iterator, Optional
from dataclasses import dataclass from dataclasses import dataclass
import random import random
import itertools import itertools
@ -1275,7 +1275,7 @@ class Language:
return results return results
@contextmanager @contextmanager
def use_params(self, params: dict): def use_params(self, params: Optional[dict]):
"""Replace weights of models in the pipeline with those provided in the """Replace weights of models in the pipeline with those provided in the
params dictionary. Can be used as a contextmanager, in which case, params dictionary. Can be used as a contextmanager, in which case,
models go back to their original weights after the block. models go back to their original weights after the block.
@ -1288,24 +1288,27 @@ class Language:
DOCS: https://spacy.io/api/language#use_params DOCS: https://spacy.io/api/language#use_params
""" """
contexts = [ if not params:
pipe.use_params(params) yield
for name, pipe in self.pipeline else:
if hasattr(pipe, "use_params") and hasattr(pipe, "model") contexts = [
] pipe.use_params(params)
# TODO: Having trouble with contextlib for name, pipe in self.pipeline
# Workaround: these aren't actually context managers atm. if hasattr(pipe, "use_params") and hasattr(pipe, "model")
for context in contexts: ]
try: # TODO: Having trouble with contextlib
next(context) # Workaround: these aren't actually context managers atm.
except StopIteration: for context in contexts:
pass try:
yield next(context)
for context in contexts: except StopIteration:
try: pass
next(context) yield
except StopIteration: for context in contexts:
pass try:
next(context)
except StopIteration:
pass
def pipe( def pipe(
self, self,