mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-19 04:32:32 +03:00
merge SpanPredictor attributes
This commit is contained in:
parent
a872c69ffb
commit
1c5dabcb47
|
@ -83,7 +83,7 @@ def build_span_predictor(
|
|||
|
||||
with Model.define_operators({">>": chain, "&": tuplify}):
|
||||
span_predictor = PyTorchWrapper(
|
||||
SpanPredictor(dim, dist_emb_size),
|
||||
SpanPredictor(dim, hidden_size, dist_emb_size),
|
||||
convert_inputs=convert_span_predictor_inputs
|
||||
)
|
||||
# TODO use proper parameter for prefix
|
||||
|
@ -511,11 +511,7 @@ class RoughScorer(torch.nn.Module):
|
|||
|
||||
|
||||
class SpanPredictor(torch.nn.Module):
|
||||
<<<<<<< HEAD
|
||||
def __init__(self, input_size: int, distance_emb_size: int):
|
||||
=======
|
||||
def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int, device):
|
||||
>>>>>>> eec00ce60d83f500e18f2da7d9feafa7143440f2
|
||||
super().__init__()
|
||||
# input size = single token size
|
||||
# 64 = probably distance emb size
|
||||
|
|
Loading…
Reference in New Issue
Block a user