欢迎您访问365答案网,请分享给你的朋友!
生活常识 学习资料

matplot画图之plt.scatter函数

时间:2023-06-04

 函数原型:

def scatter( x, y, s=None, c=None, marker=None, cmap=None, norm=None, vmin=None, vmax=None, alpha=None, linewidths=None, *, edgecolors=None, plotnonfinite=False, data=None, **kwargs): __ret = gca().scatter( x, y, s=s, c=c, marker=marker, cmap=cmap, norm=norm, vmin=vmin, vmax=vmax, alpha=alpha, linewidths=linewidths, edgecolors=edgecolors, plotnonfinite=plotnonfinite, **({"data": data} if data is not None else {}), **kwargs) sci(__ret) return __ret

x,y:表示的是大小为(x,y)的数组,绘制散点图的数据点

s:是一个实数或者是一个数组大小为(n,),这个是一个可选的参数。

c:表示的是颜色,默认是蓝色'b',表示的是标记的颜色

marker:表示的是绘制标记的样式,默认的是'o'圆圈,改成'x'则变成字符X。

cmap:Colormap实体或者colormap的名字,cmap当c是一个浮点数数组的时候才使用。如果没有申明就是image.cmap

norm:Normalize实体来将数据亮度转化到0-1之间,只有c是一个浮点数的数组的时候才使用。如果没有申明,就是默认为colors.Normalize。

vmin,vmax:实数,当norm存在的时候忽略。用来进行亮度数据的归一化。

alpha:实数,0-1之间。

linewidths:也就是标记点的长度。

下图是根据csv文件导入的数据所绘制的图像。

# PyTorchimport torchimport torch.nn as nnfrom torch.utils.data import Dataset, DataLoader# For data preprocessimport numpy as npimport csvimport os# For plottingimport matplotlib.pyplot as pltfrom matplotlib.pyplot import figure#下面三个包是新增的from sklearn.model_selection import train_test_splitimport pandas as pdimport pprint as pppd.set_option('display.max_rows', 200) # 200行pd.set_option('display.max_columns', 200) # 200列myseed = 42069 # set a random seed for reproducibilitytorch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = Falsenp.random.seed(myseed)torch.manual_seed(myseed)if torch.cuda.is_available(): torch.cuda.manual_seed_all(myseed)tr_path = 'covid.train.csv' # path to training datatt_path = 'covid.test.csv' # path to testing datadata_tr = pd.read_csv(tr_path) #读取训练数据data_tt = pd.read_csv(tt_path) #读取测试数据#print(data_tt.head(3))#print(data_tr.head(3))#print(data_tr.columns) #查看有多少列特征data_tr.drop(['id'],axis = 1, inplace = True) #由于id列用不到,删除id列data_tt.drop(['id'],axis = 1, inplace = True)cols = list(data_tr.columns) #拿到特征列名称#pp.pprint(data_tr.columns)#pp.pprint(data_tr.info()) #看每列数据类型和大小WI_index = cols.index('WI') # WI列是states one-hot编码最后一列,取值为0或1,后面特征分析时需要把states特征删掉WI_index #wi列索引#one-hot编码。one-hot编码的定义是用N位状态寄存器来对N个状态进行编码。# 比如[0,0.3],(0.3,0.6],(0.6,1],有3个分类值,因此N为3,对应的one-hot编码可以表示为100,010,001。#使用步骤:比如用LR算法做模型,在数据处理过程中,可以先对连续变量进行离散化处理,# 然后对离散化后数据进行one-hot编码,最后放入LR模型中。# 这样可以增强模型的非线性能力。#print(data_tr.iloc[:, 40:].describe()) #从上面可以看出wi 列后面是cli, 所以列索引从40开始, 并查看这些数据分布#print(data_tt.iloc[:, 40:].describe()) #查看测试集数据分布,并和训练集数据分布对比,两者特征之间数据分布差异不是很大plt.scatter(data_tr.loc[:, 'cli'], data_tr.loc[:, 'tested_positive.2']) #肉眼分析cli特征与目标之间相关性plt.show()

Copyright © 2016-2020 www.365daan.com All Rights Reserved. 365答案网 版权所有 备案号:

部分内容来自互联网,版权归原作者所有,如有冒犯请联系我们,我们将在三个工作时内妥善处理。