Training Transformer Networks in Scikit-Learn?!

  • Hui Wen Goh

Have you ever wanted to use handy scikit-learn functionalities with your neural networks, but could not because TensorFlow models are not compatible with the scikit-learn API? Here we introduce KerasWrapperModel and KerasWrapperSequential, two one-line wrappers for TensorFlow/Keras models that enable you to use TensorFlow models within scikit-learn workflows that include features like Pipeline, GridSearch and more.

Transformers are extremely popular for modeling text nowadays, with Transformers like GPT-4, ChatGPT, BARD, PaLM and FLAN excelling at conversational AI and others like T5 and BERT excelling at text classification. Scikit-learn offers a broadly useful suite of features for classifier models, but these are hard to use with Transformers. That changes with KerasWrapperModel and KerasWrapperSequential, which only require changing one line of code to make your existing TensorFlow/Keras model compatible with scikit-learn’s rich ecosystem!

Example of replacing existing keras.Sequential code with KerasWrapperSequential

To demonstrate KerasWrapperModel, we will train a classifier (fine-tuning a pretrained Bert model) to classify positive vs. negative text reviews of products via the steps below:

  • Tokenize our text data to be suitable for a pretrained Bert Transformer model from HuggingFace.
  • Define a (pretrained) Transformer network in Keras code and then swap out one line (keras.Model -> KerasWrapperModel) to make this model compatible with scikit-learn.
  • Conduct a grid search to find the best parameters for the model using sklearn’s GridSearchCV
  • Using the optimal parameters found, train the classifier to classify the reviews
  • Use cleanlab’s CleanLearning to train a more robust version of the same model, which again just involves one extra line of code. This allows you to automatically remove label issues in the dataset, and then retrain the same classifier to get improved predictions.

To run the code demonstrated in this article yourself, check out the full notebook.

Load and Preprocess the Data

Let’s take a look at the Amazon Reviews text dataset we are using for this demonstration. Each example in our dataset is a magazine review obtained from Amazon that has been classified into one of two categories: positive (1) or negative (0).

You can access and download the train set of this dataset here and the test set here.

Here is a sample label and review from the train set:

Example Label: 1
Example Text: Excellent product! I love reading through the magazine and learning about the cool new products out there and the cool programs!

Before we train our classifier, we need to transform our text data into a format suitable as an input for a neural network.

Here, we tokenize the text to be suitable for a pretrained BERT transformer.
MODEL_NAME = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

train_input = tokenizer(

test_input = tokenizer(

The pretrained BERT tokenizer outputs a dictionary containing input IDs, token type IDs and an attention mask for each example, each carrying a separate piece of information about the text. The input IDs map each token to an ID defined by the pretrained BERT vocabulary. The token type IDs distinguish the segments of a two-part input, as in question answering tasks, and are not relevant for our current classification task. The attention mask indicates which positions hold real tokens and which hold padding.
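To make the three fields concrete, here is a minimal sketch of the kind of dictionary a tokenizer produces for one short sentence padded to a fixed length. The vocabulary and the `toy_encode` helper are hypothetical stand-ins for illustration and do not reproduce the real WordPiece tokenizer; only the special-token IDs (0 for [PAD], 100 for [UNK], 101 for [CLS], 102 for [SEP]) follow bert-base-uncased.

```python
# Toy stand-in for a BERT-style tokenizer (hypothetical vocabulary, whitespace splitting).
PAD_ID, UNK_ID, CLS_ID, SEP_ID = 0, 100, 101, 102  # special-token IDs of bert-base-uncased
TOY_VOCAB = {"excellent": 1000, "product": 1001, "love": 1002, "it": 1003}  # made-up IDs


def toy_encode(text, max_len):
    """Return input_ids, token_type_ids and attention_mask, each of length max_len."""
    words = text.lower().replace("!", " ").split()
    ids = [CLS_ID] + [TOY_VOCAB.get(w, UNK_ID) for w in words]
    ids = ids[: max_len - 1] + [SEP_ID]       # truncate if needed, then close with [SEP]
    n_real = len(ids)
    ids += [PAD_ID] * (max_len - n_real)      # pad up to the fixed length
    return {
        "input_ids": ids,
        "token_type_ids": [0] * max_len,      # single-segment input: all zeros
        "attention_mask": [1] * n_real + [0] * (max_len - n_real),
    }


encoded = toy_encode("Excellent product! Love it", max_len=8)
for name, values in encoded.items():
    print(f"{name:15s} {values}")
```

In this toy output the last two positions are padding, so they carry ID 0 in input_ids and 0 in attention_mask.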

For our current task, the input IDs already contain all the information needed to represent the text data. We therefore extract the input IDs from the dictionary returned by the tokenizer and convert our input data into numpy arrays that are suitable inputs for scikit-learn functions.

train_input_ids = np.array(train_input["input_ids"])
train_labels = np.array(train_data['label'])

test_input_ids = np.array(test_input["input_ids"])
test_labels = np.array(test_data['label'])

Define Neural Network Model

The next step is to define the Keras function that builds the Transformer model we would normally use to classify the Amazon reviews.

Here we are fine-tuning a pretrained BERT Transformer for classification, which requires input_ids, token_type_ids and attention_mask as inputs. Our model only takes input_ids as an input; however, we can internally derive the token_type_ids and attention_mask arrays from it to pass into the BERT model.

def build_model(model_name:str, max_len:int, n_classes:int):
    # define input ids, token type ids and attention mask
    input_ids = tf.keras.layers.Input(shape=(max_len,), dtype='int32', name='input_ids')
    token_type_ids = tf.keras.layers.Lambda(lambda x: x * 0, name='token_type_ids')(input_ids)
    attention_mask = tf.keras.layers.Lambda(lambda x: tf.cast(x != 0, tf.int32), name="attention_mask")(input_ids)

    # get bert main layer and add it to the NN, passing in inputs
    bert_layer = TFAutoModel.from_pretrained(model_name)
    layer = bert_layer(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)[1]
    output_layer = tf.keras.layers.Dense(n_classes, activation='softmax')(layer)

    # model instance
    model = tf.keras.Model(inputs=[input_ids], outputs=output_layer)
    return model
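The two Lambda layers in build_model can be mirrored in plain Python to check what they compute. This sketch assumes, as the model does, that padding uses ID 0 (which holds for bert-base-uncased) and that each input is a single text segment. The helper name `derive_bert_inputs` and the token IDs in the toy batch are hypothetical.

```python
def derive_bert_inputs(input_ids):
    """Rebuild token_type_ids and attention_mask from a batch of padded input IDs."""
    # Equivalent of `x * 0`: every position belongs to segment 0.
    token_type_ids = [[0 for _ in row] for row in input_ids]
    # Equivalent of `tf.cast(x != 0, tf.int32)`: 1 for a token, 0 for padding.
    attention_mask = [[int(tok != 0) for tok in row] for row in input_ids]
    return token_type_ids, attention_mask


batch = [
    [101, 2307, 4031, 102, 0, 0],        # short review with two padding positions
    [101, 2025, 2204, 2012, 2035, 102],  # review that fills the whole length
]
token_type_ids, attention_mask = derive_bert_inputs(batch)
print(attention_mask)  # [[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1]]
```

This shortcut only works because no real token shares the padding ID; with a tokenizer whose padding ID is not 0, the comparison would need to use that ID instead.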

Ordinarily you would instantiate this model using keras.Model with the above build_model function, but here we simply replace this with KerasWrapperModel instead. The resulting model object is now a Keras model that is scikit-learn compatible!

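model = KerasWrapperModel(  # this would typically be: keras.Model(
    model=build_model,
    model_kwargs={
        "model_name": MODEL_NAME,
        "max_len": 30,
        "n_classes": 2,
    },
    compile_kwargs={
        # arguments forwarded to model.compile(), such as the optimizer, loss and metrics
    },
)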
We can check out the summary of our wrapped Keras neural network by calling the summary() method, the same way we would for a regular Keras model.

Model: "model"
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_ids (InputLayer)         [(None, 30)]         0           []                               
 attention_mask (Lambda)        (None, 30)           0           ['input_ids[0][0]']              
 token_type_ids (Lambda)        (None, 30)           0           ['input_ids[0][0]']              
 tf_bert_model (TFBertModel)    TFBaseModelOutputWi  109482240   ['input_ids[0][0]',              
                                thPoolingAndCrossAt               'attention_mask[0][0]',         
                                tentions(last_hidde               'token_type_ids[0][0]']         
                                n_state=(None, 30,
                                768),
                                 pooler_output=(Non
                                e, 768),
                                 past_key_values=No
                                ne, hidden_states=N
                                one, attentions=Non
                                e, cross_attentions
                                =None)
 dense (Dense)                  (None, 2)            1538        ['tf_bert_model[0][1]']          
Total params: 109,483,778
Trainable params: 109,483,778
Non-trainable params: 0

Hyperparameter Tuning with GridSearch

Now that our Keras model is sklearn-compatible, we can use GridSearchCV to find the best hyperparameters for this classifier. The code for the grid search is the same as for any other sklearn model, and is shown below:

params = {
    'batch_size': [32, 64],
    'epochs': [5, 10],
}

gs = GridSearchCV(model, params, refit=False, cv=3, verbose=2, scoring='accuracy', error_score='raise')
gs.fit(train_input_ids, train_labels)

print("best score: {:.3f}, best params: {}".format(gs.best_score_, gs.best_params_))
best score: 0.834, best params: {'batch_size': 64, 'epochs': 5}
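
Conceptually, a grid search with cross-validation tries every parameter combination, scores each one by its average held-out accuracy across folds, and keeps the best. The following is a simplified pure-Python sketch of that loop, not scikit-learn's implementation. The toy nearest-neighbor classifier, its parameter `k` and the data are hypothetical.

```python
from itertools import product


def knn_predict(k, train_X, train_y, test_X):
    """Toy 1-D k-nearest-neighbors classifier for binary labels."""
    preds = []
    for x in test_X:
        nearest = sorted(range(len(train_X)), key=lambda i: abs(train_X[i] - x))[:k]
        votes = sum(train_y[i] for i in nearest)
        preds.append(int(votes * 2 > k))  # majority vote among the k neighbors
    return preds


def cross_val_accuracy(params, X, y, n_folds=3):
    """Mean held-out accuracy over n_folds interleaved folds."""
    scores = []
    for fold in range(n_folds):
        test_idx = [i for i in range(len(X)) if i % n_folds == fold]
        train_idx = [i for i in range(len(X)) if i % n_folds != fold]
        preds = knn_predict(
            params["k"],
            [X[i] for i in train_idx],
            [y[i] for i in train_idx],
            [X[i] for i in test_idx],
        )
        correct = sum(p == y[i] for p, i in zip(preds, test_idx))
        scores.append(correct / len(test_idx))
    return sum(scores) / n_folds


X = [0.1, 0.2, 0.3, 0.4, 0.45, 0.6, 0.7, 0.8, 0.9]
y = [0, 0, 0, 0, 0, 1, 1, 1, 1]
grid = {"k": [1, 3, 5]}

# Evaluate every combination in the grid (product handles grids with several parameters).
results = []
for values in product(*grid.values()):
    params = dict(zip(grid.keys(), values))
    results.append((cross_val_accuracy(params, X, y), params))

best_score, best_params = max(results, key=lambda r: r[0])
print(f"best score: {best_score:.3f}, best params: {best_params}")
```

With a 2 x 2 grid and cv=3, as in the search above, the same loop would fit the model 12 times, which is why grid searches over neural networks can be slow.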

After the grid search, we obtain the best parameters for the model and can use these parameters to train the final model. We then measure the accuracy of the model using a held-out test set.

model.fit(train_input_ids, train_labels, epochs=5, batch_size=64)

model_preds = model.predict(test_input_ids)
model_accuracy = np.mean(model_preds == test_labels)

print(f"Base model accuracy = {model_accuracy}")
Base model accuracy = 0.917

You can apply many other useful scikit-learn functionalities with this model now too.

Train a more robust model using CleanLearning

We’ve already obtained a good model by hyperparameter tuning above. However, there are other data-centric techniques we can use to improve the classifier’s performance that do not alter the model’s architecture or hyperparameters in any way.

Instead, we would like to focus on the data quality of our training data. Most datasets contain label errors, which will negatively impact the ability of the model to learn and hence its performance. Here, we demonstrate how to use CleanLearning to automatically identify label errors and train a more robust model.

CleanLearning is a wrapper that can be easily applied to any scikit-learn compatible model (which our model above is, because we wrapped it with the scikit-learn compatible KerasWrapperModel!). Once wrapped, the resulting model can still be used in the same manner as a regular sklearn model, but it will now train more robustly if the data have noisy labels.

Here, we wrap our Keras model in a CleanLearning object and call fit() to train the model:

cl = CleanLearning(clf=model, cv_n_folds=3, verbose=True)
cl.fit(train_input_ids, train_labels, clf_kwargs={"epochs": 5, "batch_size": 64})

cleanlab’s CleanLearning will automatically identify the label issues in the dataset and remove them, before training the final model on the remaining (clean) subset of the data, which will produce a more robust model.
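
The idea behind this step can be sketched in a few lines. The version below is a deliberately simplified illustration of confident-learning-style filtering and not cleanlab's actual algorithm. It assumes out-of-sample predicted probabilities are already available (CleanLearning obtains them via cross-validation), that every class appears at least once among the labels, and it uses a hypothetical function name and toy numbers.

```python
def find_label_issues_simple(labels, pred_probs):
    """Return indices whose given label disagrees with a confidently predicted class."""
    n_classes = len(pred_probs[0])
    # Per-class threshold: mean predicted probability of class k over examples labeled k.
    thresholds = []
    for k in range(n_classes):
        probs_k = [p[k] for p, y in zip(pred_probs, labels) if y == k]
        thresholds.append(sum(probs_k) / len(probs_k))

    issues = []
    for i, (y, p) in enumerate(zip(labels, pred_probs)):
        # Classes the model is confident about for this example.
        confident = [k for k in range(n_classes) if p[k] >= thresholds[k]]
        if confident:
            best = max(confident, key=lambda k: p[k])
            if best != y:
                issues.append(i)
    return issues


labels = [1, 1, 1, 0, 0, 0]
pred_probs = [  # [P(negative), P(positive)] from held-out predictions (toy values)
    [0.10, 0.90],
    [0.20, 0.80],
    [0.15, 0.85],
    [0.90, 0.10],
    [0.85, 0.15],
    [0.05, 0.95],  # labeled negative, but the model is confident it is positive
]
issues = find_label_issues_simple(labels, pred_probs)
clean_idx = [i for i in range(len(labels)) if i not in issues]
print(issues, clean_idx)  # [5] [0, 1, 2, 3, 4]
```

The final model would then be retrained on the examples in `clean_idx` only.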

Next, we can check the performance of our newly trained CleanLearning model by measuring its accuracy on the same test set as above:

cl_preds = cl.predict(test_input_ids)
cl_accuracy = np.mean(cl_preds == test_labels)

print(f"CleanLearning model accuracy = {cl_accuracy}")
CleanLearning model accuracy = 0.945

We see that the test accuracy has improved after filtering out detected label issues from the original dataset.

We can also check out the label issues identified by cleanlab by calling the get_label_issues() method on the CleanLearning object. Here, we will print the index of the top 10 issues and take a closer look at a few of them:

label_issues = cl.get_label_issues()
lowest_quality_issues = label_issues["label_quality"].argsort()[:10].to_numpy()

print(f"index of the top 10 most likely errors: \n {lowest_quality_issues}")
index of the top 10 most likely errors: 
[3477 4560 4516 3731 1003 4330 2997 2689 3839 1075]
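
The ranking step can be pictured as sorting examples by a label quality score and taking the lowest ones. One common choice of score is self-confidence, the out-of-sample predicted probability of the given label. The sketch below assumes that score and uses made-up probabilities, so it illustrates the ranking only and does not reproduce cleanlab's internals; the function name is hypothetical.

```python
def rank_by_label_quality(labels, pred_probs, top_k):
    """Return the top_k indices with the lowest self-confidence, worst first."""
    quality = [p[y] for y, p in zip(labels, pred_probs)]  # P(given label) per example
    order = sorted(range(len(labels)), key=lambda i: quality[i])
    return order[:top_k], quality


labels = [0, 1, 0, 1, 0]
pred_probs = [
    [0.95, 0.05],
    [0.30, 0.70],
    [0.08, 0.92],  # labeled 0, model strongly prefers 1
    [0.60, 0.40],  # labeled 1, model leans toward 0
    [0.80, 0.20],
]
worst, quality = rank_by_label_quality(labels, pred_probs, top_k=2)
print(worst)  # [2, 3]
```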

Let’s see if cleanlab correctly identified these label errors!

i = 3477  # first example in the list above
print(f"Example Label: {train_data.iloc[i]['label']}")
print(f"Example Text: {train_data.iloc[i]['review_text']}")
Example Label: 0
Example Text: Very satisfied. Love the magazine.l

This is clearly a positive review which has been mislabeled as a negative review.

i = 4560
print(f"Example Label: {train_data.iloc[i]['label']}")
print(f"Example Text: {train_data.iloc[i]['review_text']}")
Example Label: 0
Example Text: I rec'd the first issue exactly one month from the day that I signed up
for it. Excellent service. Very pleasantly surprised!!

Similarly, this is also a positive review which has been correctly identified by cleanlab to be a mislabeled example.


We’ve demonstrated how handy KerasWrapperModel is to make any TensorFlow/Keras model compatible with scikit-learn. While we only demonstrated the use of this classifier with GridSearchCV, it is also compatible with a ton of other scikit-learn functionality such as Pipelines and more.

By making your neural network sklearn-compatible, you can also easily use CleanLearning to identify label issues in your dataset and train a more robust version of the same model!

If you’re interested in seeing other examples using KerasWrapper, check out: this tutorial which wraps a Sequential Keras model, and this example notebook which shows how to pass TensorFlow datasets to the sklearn-compatible neural network model as inputs.

Other resources to learn more

Join our community of scientists/engineers to ask questions, see how others used this KerasWrapper, and help build the future of open-source Data-Centric AI: Cleanlab Slack Community