mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-19 04:32:32 +03:00
merge SpanPredictor attributes
This commit is contained in:
parent
a872c69ffb
commit
1c5dabcb47
|
@ -83,7 +83,7 @@ def build_span_predictor(
|
||||||
|
|
||||||
with Model.define_operators({">>": chain, "&": tuplify}):
|
with Model.define_operators({">>": chain, "&": tuplify}):
|
||||||
span_predictor = PyTorchWrapper(
|
span_predictor = PyTorchWrapper(
|
||||||
SpanPredictor(dim, dist_emb_size),
|
SpanPredictor(dim, hidden_size, dist_emb_size),
|
||||||
convert_inputs=convert_span_predictor_inputs
|
convert_inputs=convert_span_predictor_inputs
|
||||||
)
|
)
|
||||||
# TODO use proper parameter for prefix
|
# TODO use proper parameter for prefix
|
||||||
|
@ -511,11 +511,7 @@ class RoughScorer(torch.nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class SpanPredictor(torch.nn.Module):
|
class SpanPredictor(torch.nn.Module):
|
||||||
<<<<<<< HEAD
|
|
||||||
def __init__(self, input_size: int, distance_emb_size: int):
|
|
||||||
=======
|
|
||||||
def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int, device):
|
def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int, device):
|
||||||
>>>>>>> eec00ce60d83f500e18f2da7d9feafa7143440f2
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# input size = single token size
|
# input size = single token size
|
||||||
# 64 = probably distance emb size
|
# 64 = probably distance emb size
|
||||||
|
|
Loading…
Reference in New Issue
Block a user