Train BERT Document Classifier
This example shows how to train a BERT neural network for document classification.
A Bidirectional Encoder Representations from Transformers (BERT) model is a transformer neural network that can be fine-tuned for natural language processing tasks such as document classification and sentiment analysis. The network uses attention layers to analyze text in context and capture long-range dependencies between words.
This example fine-tunes a pretrained BERT-Base neural network to predict the category of factory reports using text descriptions.
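The attention layers mentioned above can be illustrated with a minimal sketch of standard scaled dot-product attention in base MATLAB. This illustrates the general mechanism only and is not the BERT-Base implementation inside bertDocumentClassifier; the sizes and the random matrices Q, K, and V are hypothetical stand-ins for projected token embeddings.

```matlab
% Hypothetical sizes: 5 tokens, 8-dimensional query/key/value vectors.
rng(0)
numTokens = 5;
d = 8;
Q = randn(numTokens,d);   % queries, one row per token (illustrative only)
K = randn(numTokens,d);   % keys
V = randn(numTokens,d);   % values

% Attention scores: scaled similarity of every token with every other token.
scores = (Q*K') / sqrt(d);

% Row-wise softmax. Subtracting the row maximum keeps exp() numerically stable.
scores = scores - max(scores,[],2);
weights = exp(scores) ./ sum(exp(scores),2);

% Each output row is a weighted mix of all value rows, so every token
% can draw on context from anywhere in the sequence.
out = weights * V;
disp(size(out))
```

Each row of weights sums to 1 and describes how strongly one token attends to every token in the sequence, which is what lets the network relate words that are far apart.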
Load Training Data
Read the training data from the factoryReports CSV file. The file contains factory reports, including a text description and a categorical label for each report.
filename = "factoryReports.csv";
data = readtable(filename,TextType="string");
head(data)
                                 Description                                       Category          Urgency          Resolution          Cost
    _____________________________________________________________________    ____________________    ________    ____________________    _____

    "Items are occasionally getting stuck in the scanner spools."            "Mechanical Failure"    "Medium"    "Readjust Machine"         45
    "Loud rattling and banging sounds are coming from assembler pistons."    "Mechanical Failure"    "Medium"    "Readjust Machine"         35
    "There are cuts to the power when starting the plant."                   "Electronic Failure"    "High"      "Full Replacement"      16200
    "Fried capacitors in the assembler."                                     "Electronic Failure"    "High"      "Replace Components"      352
    "Mixer tripped the fuses."                                               "Electronic Failure"    "Low"       "Add to Watch List"        55
    "Burst pipe in the constructing agent is spraying coolant."              "Leak"                  "High"      "Replace Components"      371
    "A fuse is blown in the mixer."                                          "Electronic Failure"    "Low"       "Replace Components"      441
    "Things continue to tumble off of the belt."                             "Mechanical Failure"    "Low"       "Readjust Machine"         38
Convert the labels in the Category column of the table to categorical values and view the distribution of the classes in the data using a histogram.
data.Category = categorical(data.Category);

figure
histogram(data.Category)
xlabel("Class")
ylabel("Frequency")
title("Class Distribution")
View the number of classes.
classNames = categories(data.Category);
numClasses = numel(classNames)
numClasses = 4
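Partition the data into a training set and a test set. Specify a holdout fraction of 0.1 so that 10% of the reports are reserved for testing. Because the labels are passed as the grouping variable, cvpartition stratifies the split so that each class appears in both sets in roughly the same proportion.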
cvp = cvpartition(data.Category,Holdout=0.1);
dataTrain = data(cvp.training,:);
dataTest = data(cvp.test,:);
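To check the effect of a stratified holdout split, a sketch like the following compares the per-class counts in the training and test sets. It uses hypothetical toy labels in place of data.Category and assumes the same cvpartition syntax as above.

```matlab
% Hypothetical toy labels standing in for data.Category (not the real data).
labels = categorical([repmat("Leak",20,1); ...
    repmat("Mechanical Failure",60,1); ...
    repmat("Electronic Failure",20,1)]);

% Passing the labels as the grouping variable makes cvpartition stratify
% the holdout split by class.
rng(0)
cvp = cvpartition(labels,Holdout=0.1);

% Counts per category (in the order returned by categories(labels)).
countsTrain = countcats(labels(cvp.training))
countsTest = countcats(labels(cvp.test))
```

With stratification, each class should contribute roughly 10% of its observations to the test set, so small classes are not left out of evaluation by chance.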
Extract the text data and labels from the tables.
textDataTrain = dataTrain.Description;
textDataTest = dataTest.Description;
TTrain = dataTrain.Category;
TTest = dataTest.Category;
Load Pretrained BERT Document Classifier
Load a pretrained BERT-Base document classifier using the bertDocumentClassifier function. If the Text Analytics Toolbox™ Model for BERT-Base Network support package is not installed, then the function provides a link to the required support package in the Add-On Explorer. To install the support package, click the link, and then click Install.
mdl = bertDocumentClassifier(ClassNames=classNames)
mdl =
  bertDocumentClassifier with properties:

       Network: [1×1 dlnetwork]
     Tokenizer: [1×1 bertTokenizer]
    ClassNames: ["Electronic Failure"    "Leak"    "Mechanical Failure"    "Software Failure"]
Specify Training Options
Specify the training options. Choosing among training options requires empirical analysis. To explore different training option configurations by running experiments, you can use the Experiment Manager app.
Train using the Adam optimizer.
Train for eight epochs.
For fine-tuning, use a learning rate lower than the Adam default of 0.001. Train using a learning rate of 0.0001.
Shuffle the data every epoch.
Monitor the training progress in a plot and monitor the accuracy metric.
Disable the verbose output.
options = trainingOptions("adam", ...
    MaxEpochs=8, ...
    InitialLearnRate=1e-4, ...
    Shuffle="every-epoch", ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false);
Train Neural Network
Train the neural network using the trainBERTDocumentClassifier function. By default, the trainBERTDocumentClassifier function uses a GPU if one is available and the CPU otherwise. Training on a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information about supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). To specify the execution environment, use the ExecutionEnvironment training option.
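For example, to force training on the CPU even when a GPU is available, a sketch of the same options with the ExecutionEnvironment training option set to "cpu" might look like this (the default value, "auto", uses a GPU when one is available).

```matlab
% Same training options as above, plus ExecutionEnvironment="cpu" to
% run training on the CPU regardless of GPU availability.
options = trainingOptions("adam", ...
    MaxEpochs=8, ...
    InitialLearnRate=1e-4, ...
    Shuffle="every-epoch", ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false, ...
    ExecutionEnvironment="cpu");
```

Pass these options to trainBERTDocumentClassifier in the same way as before. Expect CPU training of a BERT-Base model to be considerably slower than training on a supported GPU.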
mdl = trainBERTDocumentClassifier(textDataTrain,TTrain,mdl,options);
Test Neural Network
Make predictions using the test data.
YTest = classify(mdl,textDataTest);
Visualize the predictions in a confusion matrix.
figure
confusionchart(TTest,YTest)
Calculate the classification accuracy of the test predictions.
accuracy = mean(TTest == YTest)
accuracy = 0.9375
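A single accuracy value can hide weak classes when the class distribution is uneven. The following minimal sketch computes per-class recall from a confusion matrix. The labels TTrue and YPred are hypothetical toy values standing in for TTest and YTest, and the sketch assumes the confusionmat function from Statistics and Machine Learning Toolbox.

```matlab
% Hypothetical toy labels standing in for TTest (true) and YTest (predicted).
TTrue = categorical(["Leak"; "Leak"; "Mechanical Failure"; ...
    "Electronic Failure"; "Electronic Failure"; "Mechanical Failure"]);
YPred = categorical(["Leak"; "Mechanical Failure"; "Mechanical Failure"; ...
    "Electronic Failure"; "Electronic Failure"; "Mechanical Failure"]);

% Overall accuracy, computed as in the example: fraction of matching labels.
accuracy = mean(TTrue == YPred)

% Rows of C are true classes and columns are predicted classes, both in
% the order given by the second output.
[C,order] = confusionmat(TTrue,YPred);

% Per-class recall: correct predictions for each class divided by the
% number of true observations of that class (the row sums).
recall = diag(C) ./ sum(C,2);
table(order,recall)
```

With the real test data, replace TTrue and YPred with TTest and YTest.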
Make Predictions Using New Data
Classify the event type of new factory reports. Create a string array containing the new factory reports.
strNew = [
    "Coolant is pooling underneath sorter."
    "Sorter blows fuses at start up."
    "There are some very loud rattling sounds coming from the assembler."];
labelsNew = classify(mdl,strNew)
labelsNew = 3×1 categorical
Leak
Electronic Failure
Mechanical Failure
See Also
classify | trainBERTDocumentClassifier | dlnetwork (Deep Learning Toolbox) | trainingOptions (Deep Learning Toolbox)