LogisticRegression in MLlib（PySpark + NumPy + Matplotlib 可视化）


参考《LogisticRegression in MLLib》（http://www.cnblogs.com/luweiseu/p/7809521.html）。
通过 PySpark MLlib 训练逻辑回归（logistic）模型，再利用 Matplotlib 作图画出分类边界。

from pyspark.sql import Row
from pyspark.sql import HiveContext
import pyspark
from IPython.display import display
import matplotlib
import matplotlib.pyplot as plt

import os
os.environ['SPARK_HOME'] ="C:\\Users\\software\\spark-2.1.0-bin-hadoop2.7"

%matplotlib inline 

sc = pyspark.SparkContext(master='local').getOrCreate()
sqlContext = HiveContext(sc)

# get data
irisData = sc.textFile("iris.txt")


from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.linalg import Vectors
from pyspark.mllib.classification import LogisticRegressionWithLBFGS


def toLabeledPoint(line):
    """Convert one whitespace-separated text line into a LabeledPoint.

    Column 2 is parsed with int() and used as the class label; columns 0
    and 1 are parsed as floats and packed into a dense feature vector.
    """
    fields = line.split()
    label = int(fields[2])
    features = Vectors.dense(float(fields[0]), float(fields[1]))
    return LabeledPoint(label, features)

# Parse every line into a LabeledPoint. `data` is cached because it is used
# more than once (randomSplit below, and collected again for plotting);
# without caching, each use re-reads and re-parses iris.txt.
data = irisData.map(toLabeledPoint).cache()

# Split data into training (60%) and test (40%); the fixed seed makes the
# split reproducible between runs.
splits = data.randomSplit([0.6, 0.4], seed=11)
training = splits[0].cache()  # the iterative L-BFGS optimizer reuses this RDD
test = splits[1]

# train is a classmethod, so no trainer instance is needed.
# numClasses=3 fits a multinomial model over the labels 0, 1 and 2.
model = LogisticRegressionWithLBFGS.train(training, intercept=True, numClasses=3)

# testdata
def predicTest(lp):
    """Score one LabeledPoint with the globally trained `model`.

    Returns a ``(prediction, label)`` tuple, with the prediction cast to
    float, in the order MulticlassMetrics expects.
    """
    predicted = model.predict(lp.features)
    return float(predicted), lp.label
from pyspark.mllib.evaluation import MulticlassMetrics

# Pair the predicted class of each held-out point with its true label.
predictionAndLabels = test.map(predicTest)

# Accuracy on the test split.
metrics = MulticlassMetrics(predictionAndLabels)
accuracy = metrics.accuracy
accuracy  # bare expression so the notebook cell echoes the value

# plot boundary
import numpy as np
from matplotlib.colors import ListedColormap

## Evaluation grid: 500 points along x (0..8) by 200 points along y (0..3.5).
x0, x1 = np.meshgrid(
    np.linspace(0, 8, 500).reshape(-1, 1),
    np.linspace(0, 3.5, 200).reshape(-1, 1),
)
X_new = np.c_[x0.ravel(), x1.ravel()]

## Predict the class of every grid point, one point at a time on the driver.
y_predict = [model.predict(Vectors.dense(X_new_i)) for X_new_i in X_new]

## Collect the data set once (a single Spark job) and split it locally
## into labels and two-column feature rows.
points = data.collect()
y = np.array([p.label for p in points])
X = np.array([[p.features[0], p.features[1]] for p in points])

## draw
zz = np.array(y_predict).reshape(x0.shape)

plt.figure(figsize=(10, 4))
plt.plot(X[y == 2, 0], X[y == 2, 1], "g^", label="Iris-Virginica")
plt.plot(X[y == 1, 0], X[y == 1, 1], "bs", label="Iris-Versicolor")
plt.plot(X[y == 0, 0], X[y == 0, 1], "yo", label="Iris-Setosa")

# One background colour per predicted class (0, 1, 2).
custom_cmap = ListedColormap(['#fafab0', '#9898ff', '#a0faa0'])

# Filled contours have no outline width, and `linewidth` is not a supported
# contourf keyword (it was ignored with a warning), so it is not passed.
plt.contourf(x0, x1, zz, cmap=custom_cmap)
plt.xlabel("Petal length", fontsize=14)
plt.ylabel("Petal width", fontsize=14)
plt.legend(loc="center left", fontsize=14)
plt.axis([0, 7, 0, 3.5])
plt.show()

最终结果:

智能推荐

注意!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系我们删除。



 
© 2014-2019 ITdaan.com 粤ICP备14056181号  

赞助商广告