Mirror of https://github.com/explosion/spaCy.git
	Add option for GPU ID to pretrain
parent 1dce86c555
commit 6c8785a238
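The diff below is against spacy/cli/pretrain.py. It adds a gpu_id argument to the pretrain command, defaulting to -1 for CPU; swaps the prefer_gpu import for require_gpu, so an explicitly requested GPU must actually be usable; disables parameter averaging in _save_model via a quick if True: placeholder; and drives the optimizer's learning rate from a cyclic triangular schedule, stepped once per minibatch.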
--- a/spacy/cli/pretrain.py
+++ b/spacy/cli/pretrain.py
@@ -10,10 +10,11 @@ from collections import Counter
 from pathlib import Path
 from thinc.v2v import Affine, Maxout
 from thinc.misc import LayerNorm as LN
-from thinc.neural.util import prefer_gpu
+from thinc.neural.util import require_gpu
 from wasabi import Printer
 import srsly
 from thinc.neural.util import to_categorical
+from thinc.rates import cyclic_triangular_rate
 
 from ..errors import Errors
 from ..tokens import Doc
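On the import swap: in thinc 7.x, prefer_gpu is best-effort (it returns a boolean and silently falls back to CPU), while require_gpu fails loudly when no GPU is usable, which is the behaviour you want once the user has asked for a specific device. A minimal sketch of the distinction, assuming those thinc 7.x semantics:

    from thinc.neural.util import prefer_gpu, require_gpu

    # Best-effort: returns True if CuPy/CUDA is usable, False otherwise,
    # and execution simply continues on CPU when it is not.
    has_gpu = prefer_gpu()

    # Hard requirement: raises an error when the GPU is not accessible,
    # so a missing or misconfigured device aborts instead of silently
    # training on CPU. Guarded here so the sketch also runs on CPU boxes.
    if has_gpu:
        require_gpu(gpu_id=0)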
@@ -80,6 +81,13 @@ from .train import _load_pretrained_tok2vec
         "es",
         int,
     ),
+    gpu_id=(
+        "Index of GPU to use, e.g. 0. -1 for CPU.",
+        "option",
+        "gpu",
+        int,
+    ),
+
 )
 def pretrain(
     texts_loc,
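The new tuple is a plac annotation, the (help, kind, abbreviation, type) format spaCy's CLI is built on. A standalone sketch of how such an annotation surfaces as a command-line option (the script name and exact flag spelling are illustrative, not spaCy's):

    import plac

    @plac.annotations(
        gpu_id=("Index of GPU to use, e.g. 0. -1 for CPU.", "option", "gpu", int),
    )
    def main(gpu_id=-1):
        # plac turns the keyword argument into an option, using the third
        # tuple element as its abbreviated flag.
        print("Selected device:", "CPU" if gpu_id == -1 else "GPU %d" % gpu_id)

    if __name__ == "__main__":
        plac.call(main)  # e.g. `python demo.py -gpu 0`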
@@ -104,6 +112,7 @@ def pretrain(
     n_save_every=None,
     init_tok2vec=None,
     epoch_start=None,
+    gpu_id=-1,
 ):
     """
     Pre-train the 'token-to-vector' (tok2vec) layer of pipeline components,
@@ -126,10 +135,9 @@ def pretrain(
             config[key] = str(config[key])
     msg = Printer()
     util.fix_random_seed(seed)
-    has_gpu = prefer_gpu(gpu_id=1)
-    msg.info("Using GPU" if has_gpu else "Not using GPU")
+    if gpu_id != -1:
+        has_gpu = require_gpu(gpu_id=gpu_id)
+    msg.info("Using GPU {}".format(gpu_id) if has_gpu else "Not using GPU")
 
     output_dir = Path(output_dir)
     if not output_dir.exists():
         output_dir.mkdir()
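Besides making the device configurable (the old call hard-coded gpu_id=1), one quirk is worth flagging in the hunk above: msg.info stays outside the new if block but still evaluates has_gpu, which is now only bound when gpu_id != -1, so the CPU path would raise a NameError. A hedged sketch of a variant that binds the flag on both paths (a suggestion, not the committed code):

    from thinc.neural.util import require_gpu
    from wasabi import Printer

    def report_device(gpu_id=-1):
        msg = Printer()
        has_gpu = False  # bound on both paths, so the CPU branch is safe
        if gpu_id != -1:
            has_gpu = require_gpu(gpu_id=gpu_id)
        msg.info("Using GPU {}".format(gpu_id) if has_gpu else "Not using GPU")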
@@ -206,7 +214,8 @@ def pretrain(
 
     def _save_model(epoch, is_temp=False):
         is_temp_str = ".temp" if is_temp else ""
-        with model.use_params(optimizer.averages):
+        #with model.use_params(optimizer.averages):
+        if True:
             with (output_dir / ("model%d%s.bin" % (epoch, is_temp_str))).open(
                 "wb"
             ) as file_:
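Commenting out a with statement would dedent its body, so the hunk above keeps the block intact with an if True: placeholder while parameter averaging is disabled during saves. On Python 3.7+ a no-op context manager does the same job more cleanly; a sketch of that alternative (illustrative, not what the commit does, and note spaCy 2.x still supported Python 2.7, where contextlib.nullcontext does not exist):

    import contextlib

    def save_weights(model, optimizer, use_averages=False):
        # Swap in a no-op context manager instead of editing indentation.
        ctx = (model.use_params(optimizer.averages)
               if use_averages else contextlib.nullcontext())
        with ctx:
            pass  # ... serialize the weights here ...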
@@ -221,6 +230,10 @@ def pretrain(
                 file_.write(srsly.json_dumps(log) + "\n")
 
     skip_counter = 0
+    min_lr = optimizer.alpha / 3
+    max_lr = optimizer.alpha * 2
+    period = 10000
+    learn_rates = cyclic_triangular_rate(min_lr, max_lr, period)
     for epoch in range(epoch_start, n_iter + epoch_start):
         for batch_id, batch in enumerate(
             util.minibatch_by_words(((text, None) for text in texts), size=batch_size)
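The schedule brackets the optimizer's base learning rate, cycling between alpha / 3 and alpha * 2 with a half-cycle of 10,000 steps. thinc.rates.cyclic_triangular_rate returns a generator; a pure-Python sketch of a triangular schedule in that spirit (after Smith's cyclical learning rates; assumed to approximate thinc's implementation, not copied from it):

    import itertools
    import math

    def cyclic_triangular(min_lr, max_lr, period):
        """Yield rates ramping linearly min_lr -> max_lr -> min_lr, forever."""
        for it in itertools.count(1):
            cycle = math.floor(1 + it / (2 * period))
            x = abs(it / period - 2 * cycle + 1)
            yield min_lr + (max_lr - min_lr) * max(0.0, 1.0 - x)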
@@ -232,6 +245,7 @@ def pretrain(
                 min_length=min_length,
             )
             skip_counter += count
+            optimizer.alpha = next(learn_rates)
             loss = make_update(
                 model, docs, optimizer, objective=loss_func, drop=dropout
             )
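Finally, the rate generator is advanced once per minibatch: in thinc 7.x the optimizer exposes its learning rate as the alpha attribute, so assigning to it right before make_update puts the new rate into effect for that batch. A self-contained sketch of the pattern (the optimizer class and schedule below are stand-ins, not spaCy's or thinc's):

    class StubOptimizer:
        """Stand-in for thinc's optimizer; only `alpha` matters here."""
        def __init__(self, alpha):
            self.alpha = alpha

    def toy_rates(low, high):  # stand-in for cyclic_triangular_rate(...)
        while True:
            yield low
            yield high

    optimizer = StubOptimizer(alpha=0.001)
    learn_rates = toy_rates(optimizer.alpha / 3, optimizer.alpha * 2)
    for batch_id in range(4):
        optimizer.alpha = next(learn_rates)  # set the LR before each update
        # ... make_update(model, docs, optimizer, ...) would run here ...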