diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py index 5fbb64a29..5fe29c25f 100644 --- a/spacy/ml/models/coref.py +++ b/spacy/ml/models/coref.py @@ -83,7 +83,7 @@ def build_span_predictor( with Model.define_operators({">>": chain, "&": tuplify}): span_predictor = PyTorchWrapper( - SpanPredictor(dim, dist_emb_size), + SpanPredictor(dim, hidden_size, dist_emb_size), convert_inputs=convert_span_predictor_inputs ) # TODO use proper parameter for prefix @@ -511,11 +511,7 @@ class RoughScorer(torch.nn.Module): class SpanPredictor(torch.nn.Module): -<<<<<<< HEAD - def __init__(self, input_size: int, distance_emb_size: int): -======= def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int, device): ->>>>>>> eec00ce60d83f500e18f2da7d9feafa7143440f2 super().__init__() # input size = single token size # 64 = probably distance emb size