11import pytest
22
33from pandas import DataFrame
4+ import pandas as pd
45from sklearn .datasets import load_iris
56from sklearn .pipeline import Pipeline
67from sklearn .svm import SVC
8+ from sklearn .feature_extraction .text import CountVectorizer
79import numpy as np
810
911from sklearn_pandas import (
@@ -27,6 +29,11 @@ def iris_dataframe():
2729 )
2830
2931
@pytest.fixture
def cars_dataframe():
    """Load the cars test dataset from its CSV file into a DataFrame."""
    # Path is relative to the working directory the tests are run from.
    csv_path = "tests/test_data/cars.csv.gz"
    frame = pd.read_csv(csv_path)
    return frame
35+
36+
3037def test_with_iris_dataframe (iris_dataframe ):
3138 pipeline = Pipeline ([
3239 ("preprocess" , DataFrameMapper ([
@@ -42,3 +49,16 @@ def test_with_iris_dataframe(iris_dataframe):
4249 scores = cross_val_score (pipeline , data , labels )
4350 assert scores .mean () > 0.96
4451 assert (scores .std () * 2 ) < 0.04
52+
53+
def test_with_car_dataframe(cars_dataframe):
    """Cross-validate a text-vectorising SVC pipeline on the cars data.

    The ``description`` column is passed through ``CountVectorizer`` via a
    ``DataFrameMapper`` and fed to a linear-kernel ``SVC`` that predicts the
    ``model`` column. The mean cross-validation score must exceed 0.30.
    """
    # Split the frame into the prediction target and the remaining columns.
    target = cars_dataframe["model"]
    features = cars_dataframe.drop("model", axis=1)

    # Only the "description" column is mapped into the feature matrix.
    mapper = DataFrameMapper([
        ("description", CountVectorizer()),
    ])
    classifier = SVC(kernel='linear')
    pipeline = Pipeline([
        ("preprocess", mapper),
        ("classify", classifier),
    ])

    mean_score = cross_val_score(pipeline, features, target).mean()
    assert mean_score > 0.30
0 commit comments