Python

Entrenamiento, validación y test con Scikit-learn

Entre las herramientas para la selección de modelos de Scikit-learn nos podemos encontrar con la función train_test_split. Una función que nos permite dividir un conjunto de datos en uno de entrenamiento y otro de test. En la bibliografía es habitual encontrar que se tiene que dividir los conjuntos de datos para el entrenamiento de los modelos en tres: entrenamiento, validación y test, no en dos. Actualmente no existe una función que haga esto de forma automática en Scikit-learn. Por eso, en esta entrada vamos a ver cómo se puede hacer para dividir un conjunto de datos en entrenamiento, validación y test con Scikit-learn.

La función train_test_split

El método train_test_split de Scikit-learn nos permite fácilmente dividir un conjunto de datos de una matriz o DataFrame en dos aleatorios con un tamaño dato. Una función que se puede llamar de la siguiente manera:

X_train, X_test, y_train, y_test = train_test_split(X, y)

En donde se ha pasado como parámetro a la función las variables independientes (X) y la variable dependiente (y), obteniéndose como resultados un conjunto de datos para entrenamiento (X_train e y_train) y otro para test (X_test e y_test). En este caso el 75% del todos los registros estarán en el conjunto de entrenamiento y el 25 restante en el de test.

La función admite diferentes propiedades interesantes con las adaptar el funcionamiento del método, entre las que se puede destacar:

  • test_size: el tamaño del conjunto de datos de para test que tiene que ser un valor entre 0 y 1. Una opción complementaria a esa es train_size la que se puede usar en lugar de esta para indicar el tamaño del conjunto de entrenamiento. Usar una u otra es solamente una cuestión de preferencia personal.
  • random_state: un entero con el que se indica la semilla utilizada para la selección de datos. Parámetro que es clave cuando necesitamos que los resultados sean repetibles, por lo que es aconsejable usarlos siempre.
  • stratify: una variable con la que se puede indicar como hacer una estratificación de los datos.

División de los datos en entrenamiento, validación y test

Como se ha visto no existe una forma directa de dividir el conjunto de datos en tres. Pero es algo que se puede hacer fácilmente anidando los valores. Por ejemplo si tenemos tres valores que suman la unidad (train_size, validation_size, test_size) es posible dividir este un conjunto de datos entre de la siguiente manera.

validation = validation_size / (test_size + validation_size)
x_train, x_test, y_train, y_test = train_test_split(data_x, data_y, train_size=train_size)
x_val, x_test, y_val, y_test = train_test_split(x_test, y_test, train_size=validation)

En donde en primer lugar se ha calculado en porcentaje del conjunto de datos de test más validación que corresponde a validación. Así en una primera división se divide el conjunto original en uno de entrenamiento y otro de test y validación. Una vez hecho esto lo que se hace es dividir el conjunto en dos: uno de test y otro de validación.

Conclusiones

Hoy hemos visto un pequeño truco para dividir un conjunto de datos en tres con train_test_split. Pudiendo de esta manera obtener fácilmente conjuntos de entrenamiento, validación y test con Scikit-learn

Imagen de Michael Drummond en Pixabay

¿Te ha parecido de utilidad el contenido?

Daniel Rodríguez

Share
Published by
Daniel Rodríguez

Recent Posts

Síndrome del objeto brillante en ciencia de datos: el error simétrico a los costes hundidos

Hace poco publiqué una entrada en la que trataba de un sesgo bien documentado: aferrarse…

4 días ago

De la Regresión Logística al Scorecard: La Transformación Matemática

En un entrada previa explicamos qué son el WOE y el IV y por qué…

6 días ago

Analytics Lane lanza la versión 1.1 del laboratorio con nuevas suites de CLV y Scoring

Seguimos evolucionando el laboratorio de Analytics Lane y hoy lanzamos la versión 1.1, disponible en:…

7 días ago

Interés compuesto: la fuerza que multiplica tu dinero (y los errores que la anulan)

“El interés compuesto es la octava maravilla del mundo. El que lo entiende lo gana…

2 semanas ago

Cómo comparar datos con barras en Matplotlib: agrupadas, apiladas y porcentuales

Tienes los datos de ventas de tres productos en dos años distintos y quieres saber…

2 semanas ago

Costes hundidos en ciencia de datos: cuándo mantener un modelo y cuándo migrar

Imagina la situación. Tu equipo lleva tres años con un modelo en producción. No es…

3 semanas ago

This website uses cookies.