Skip to content

一元线性回归的代码问题 #195

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
pineconebig opened this issue Jan 6, 2025 · 1 comment
Open

一元线性回归的代码问题 #195

pineconebig opened this issue Jan 6, 2025 · 1 comment

Comments

@pineconebig
Copy link

大佬您好,Book1_Ch06_Python_Codes的Bk1_Ch06_12.ipynb中 #一元线性回归 slope, intercept = statistics.linear_regression(x_data, y_data 没法线性回归,我不知道是不是因为这个statistics模块更新了,删掉了这个函数

@pineconebig
Copy link
Author

pineconebig commented Jan 6, 2025

下面是我用sklearn写的,
############# 线性回归 #############
import random
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LinearRegression

产生数据

num = 50
random.seed(0) #seed(0)调试、复现结果或者需要稳定测试的场景,因为它保证了每次执行的随机操作是相同的。
x_data = [random.uniform(0, 10) for _ in range(num)]

噪音

noise = [random.gauss(0,1) for _ in range(num)]
y_data = [0.5 * x_data[idx] + 1 + noise[idx] for idx in range(num)]

绘制散点图

fig, ax = plt.subplots()
ax.scatter(x_data, y_data,label='数据点')
ax.set_xlabel('x'); ax.set_ylabel('y')
ax.set_aspect('equal', adjustable='box') #设置坐标轴的纵横比为相等,即 X 轴和 Y 轴的单位长度相同。adjustable='box' 表示如果图形区域的大小发生变化,坐标轴的比例也会自动调整。
ax.set_xlim(0,10); ax.set_ylim(-2,8)
ax.grid()
ax.legend(fontsize=12)

线性回归

x = np.array(x_data).reshape(-1, 1) # 将 x_data 转换为二维数组,因为 scikit-learn 需要二维输入,每一行是一个样本,每列是一个特征
y = np.array(y_data) # 目标变量
model = LinearRegression() # 创建线性回归模型
model.fit(x, y) # 拟合模型
slope = model.coef_[0] # 获取回归线的斜率和截距
intercept = model.intercept_

绘图

fig, ax = plt.subplots()
ax.scatter(x_data, y_data, label='数据点')
ax.plot(x_data, model.predict(X), color='red', label=f'线性回归: y = {slope:.2f}x + {intercept:.2f}') # 绘制回归线
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_aspect('equal', adjustable='box') # 设置坐标轴的纵横比为相等
ax.set_xlim(0, 10)
ax.set_ylim(-2, 8)
ax.grid()
ax.legend(fontsize=12)
plt.show()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant