	Fixing ngram bug (#3953)
* minimal failing example for Issue #3661
* referenced Issue #3611 instead of Issue #3661
* cleanup
This commit is contained in:

parent 123929b58b
commit ed774cb953
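For context, an illustrative sketch (plain Python, not spaCy internals) of the edge case this commit guards against: a bag-of-ngrams featurizer produces zero ngrams for a doc shorter than the ngram size, so extraction code that assumes at least one ngram per doc breaks on such input.

    def ngrams(tokens, n):
        # For len(tokens) < n the range below is empty, so no features
        # are produced -- the empty case the fixed code must handle.
        return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]

    print(ngrams(["inoff"], 2))  # [] for the one-token doc used in the test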
spacy/tests/regression/test_issue3611.py (new file, 51 lines)
@@ -0,0 +1,51 @@
# coding: utf8
from __future__ import unicode_literals

import pytest
import spacy
from spacy.util import minibatch, compounding


def test_issue3611():
    """ Test whether adding n-grams in the textcat works even when n > token length of some docs """
    unique_classes = ["offensive", "inoffensive"]
    x_train = ["This is an offensive text",
               "This is the second offensive text",
               "inoff"]
    y_train = ["offensive", "offensive", "inoffensive"]

    # preparing the data
    pos_cats = list()
    for train_instance in y_train:
        pos_cats.append({label: label == train_instance for label in unique_classes})
    train_data = list(zip(x_train, [{'cats': cats} for cats in pos_cats]))

    # set up the spacy model with a text categorizer component
    nlp = spacy.blank('en')

    textcat = nlp.create_pipe(
        "textcat",
        config={
            "exclusive_classes": True,
            "architecture": "bow",
            "ngram_size": 2
        }
    )

    for label in unique_classes:
        textcat.add_label(label)
    nlp.add_pipe(textcat, last=True)

    # training the network
    other_pipes = [pipe for pipe in nlp.pipe_names if pipe != 'textcat']
    with nlp.disable_pipes(*other_pipes):
        optimizer = nlp.begin_training()
        for i in range(3):
            losses = {}
            batches = minibatch(train_data, size=compounding(4.0, 32.0, 1.001))

            for batch in batches:
                texts, annotations = zip(*batch)
                nlp.update(docs=texts, golds=annotations, sgd=optimizer, drop=0.1, losses=losses)
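A quick way to verify the fix (a sketch, assuming the spaCy v2 API used above, not part of the commit): appended at the end of test_issue3611(), the trained pipeline should now score the one-token doc that is shorter than ngram_size without raising. Exact scores depend on the random initialization.

    # Sketch: predict on the short doc inside test_issue3611(),
    # after the training loop above.
    doc = nlp("inoff")  # 1 token, fewer than ngram_size=2
    assert set(doc.cats) == {"offensive", "inoffensive"}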