Pytorch常用的函数(五)np.meshgrid()和torch.meshgrid()函数解析

2023-12-23 04:44

本文主要是介绍Pytorch常用的函数(五)np.meshgrid()和torch.meshgrid()函数解析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Pytorch常用的函数(五)np.meshgrid()和torch.meshgrid()函数解析

我们知道torch.meshgrid()函数的功能是生成网格,可以用于生成坐标;

在numpy中也有一样的函数np.meshgrid(),但是用法不太一样,我们直接上代码进行解释。

1、两者在用法上的区别

比如:我要生成下图的xy坐标点,看下两者的实现方式:

在这里插入图片描述

np.meshgrid()

>>> import numpy as np
>>> w, h = 4, 2
# 注意,此时输入的是由w和h生成的一维数组
#      此时输出的是网格x的坐标grid_x以及网格y的坐标grid_y
>>> grid_x, grid_y  = np.meshgrid(np.arange(w), np.arange(h)) >>> grid_x
array([[0, 1, 2, 3],  [0, 1, 2, 3]])
>>> grid_y
array([[0, 0, 0, 0],[1, 1, 1, 1]])

torch.meshgrid()

>>> import torch
# 注意,此时输入的是由h和w生成的一维数组(和numpy中的输入顺序相反)
#      此时输出的是网格y的坐标grid_y以及网格x的坐标grid_x(和numpy中的输出顺序相反)
>>> grid_y, grid_x =  torch.meshgrid(
...         torch.arange(h),
...         torch.arange(w)
...     )
>>> grid_x
tensor([[0, 1, 2, 3],[0, 1, 2, 3]])
>>> grid_y
tensor([[0, 0, 0, 0],[1, 1, 1, 1]])

2、应用案例

2.1 利用np.meshgrid()来画决策边界

我们可以利用np.meshgrid()来画等高线图

# 等高线图
import numpy as np
import matplotlib.pyplot as plt# 模拟海拔高度
def fz(x, y):z = (1 -x / 2 + x**5 + y**3) * np.exp(-x**2-y**2)return zw = np.linspace(-4, 4, 100)
h = np.linspace(-2, 2, 100)grid_x, grid_y = np.meshgrid(w, h)
z = fz(grid_x, grid_y)plt.figure('Contour Chart',facecolor='lightgray')
plt.title('contour',fontsize=16)
plt.grid(linestyle=':')cntr = plt.contour(grid_x, # 网格坐标矩阵的x坐标(2维数组)grid_y, # 网格坐标矩阵的y坐标(2维数组)z,      # 网格坐标矩阵的z坐标(2维数组)8,      # 等高线绘制8部分colors = 'black', # 等高线图颜色linewidths = 0.5 # 等高线图线宽
)
# 设置标签
plt.clabel(cntr, inline_spacing = 1, fmt='%.2f', fontsize=10)
# 填充颜色  大的是红色  小的是蓝色
plt.contourf(grid_x, grid_y, z, 8, cmap='jet')plt.legend()
plt.show()

在这里插入图片描述

我们可以利用np.meshgrid()来画决策边界。

from sklearn.datasets import make_moons
import matplotlib.pyplot as plt
import numpy as npfrom sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC# 使用sklearn自带的moon数据
X, y = make_moons(n_samples=100,noise=0.15,random_state=42)# 绘制生成的数据
def plot_dataset(X,y,axis):plt.plot(X[:,0][y == 0],X[:,1][y == 0],'bs')plt.plot(X[:,0][y == 1],X[:,1][y == 1],'go')plt.axis(axis)plt.grid(True,which='both')# 画出决策边界
def plot_pred(clf,axes):w = np.linspace(axes[0],axes[1], 100)h = np.linspace(axes[2],axes[3], 100)grid_x, grid_y = np.meshgrid(w, h)# grid_x 和 grid_y 被拉成一列,然后拼接成10000行2列的矩阵,表示所有点grid_xy = np.c_[grid_x.ravel(), grid_y.ravel()]# 二维点集才可以用来预测y_pred = clf.predict(grid_xy).reshape(grid_x.shape)# 等高线plt.contourf(grid_x, grid_y,y_pred,alpha=0.2)ploy_kernel_svm_clf = Pipeline(steps=[("scaler",StandardScaler()),("svm_clf",SVC(kernel='poly', degree=3, coef0=1, C=5))]
)ploy_kernel_svm_clf.fit(X,y)plot_pred(ploy_kernel_svm_clf,[-1.5, 2.5, -1, 1.5])
plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])
plt.show()

在这里插入图片描述

2.2 利用torch.meshgrid()生成网格所有坐标的矩阵

在目标检测YOLO中将图像划分为单元网格的部分就用到了torch.meshgrid()函数。

import torch
import numpy as npdef create_grid(input_size, stride=32):# 1、获取原始图像的w和hw, h = input_size, input_size# 2、获取经过32倍下采样后的feature mapws, hs = w // stride, h // stride# 3、生成网格的y坐标和x坐标grid_y , grid_x = torch.meshgrid([torch.arange(hs),torch.arange(ws)])# 4、将grid_x和grid_y进行拼接,拼接后的维度为【H, W, 2】grid_xy = torch.stack([grid_x, grid_y], dim=-1).float()# 【H, W, 2】 -> 【HW, 2】grid_xy = grid_xy.view(-1, 2)return grid_xyif __name__ == '__main__':print(create_grid(input_size=32*4))
# 生成网格所有坐标的矩阵
tensor([[0., 0.],[1., 0.],[2., 0.],[3., 0.],[0., 1.],[1., 1.],[2., 1.],[3., 1.],[0., 2.],[1., 2.],[2., 2.],[3., 2.],[0., 3.],[1., 3.],[2., 3.],[3., 3.]])

这篇关于Pytorch常用的函数(五)np.meshgrid()和torch.meshgrid()函数解析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/526760

相关文章

Linux中shell解析脚本的通配符、元字符、转义符说明

《Linux中shell解析脚本的通配符、元字符、转义符说明》:本文主要介绍shell通配符、元字符、转义符以及shell解析脚本的过程,通配符用于路径扩展,元字符用于多命令分割,转义符用于将特殊... 目录一、linux shell通配符(wildcard)二、shell元字符(特殊字符 Meta)三、s

Java 字符数组转字符串的常用方法

《Java字符数组转字符串的常用方法》文章总结了在Java中将字符数组转换为字符串的几种常用方法,包括使用String构造函数、String.valueOf()方法、StringBuilder以及A... 目录1. 使用String构造函数1.1 基本转换方法1.2 注意事项2. 使用String.valu

PyTorch使用教程之Tensor包详解

《PyTorch使用教程之Tensor包详解》这篇文章介绍了PyTorch中的张量(Tensor)数据结构,包括张量的数据类型、初始化、常用操作、属性等,张量是PyTorch框架中的核心数据结构,支持... 目录1、张量Tensor2、数据类型3、初始化(构造张量)4、常用操作5、常用属性5.1 存储(st

VUE动态绑定class类的三种常用方式及适用场景详解

《VUE动态绑定class类的三种常用方式及适用场景详解》文章介绍了在实际开发中动态绑定class的三种常见情况及其解决方案,包括根据不同的返回值渲染不同的class样式、给模块添加基础样式以及根据设... 目录前言1.动态选择class样式(对象添加:情景一)2.动态添加一个class样式(字符串添加:情

使用Python实现批量访问URL并解析XML响应功能

《使用Python实现批量访问URL并解析XML响应功能》在现代Web开发和数据抓取中,批量访问URL并解析响应内容是一个常见的需求,本文将详细介绍如何使用Python实现批量访问URL并解析XML响... 目录引言1. 背景与需求2. 工具方法实现2.1 单URL访问与解析代码实现代码说明2.2 示例调用

SSID究竟是什么? WiFi网络名称及工作方式解析

《SSID究竟是什么?WiFi网络名称及工作方式解析》SID可以看作是无线网络的名称,类似于有线网络中的网络名称或者路由器的名称,在无线网络中,设备通过SSID来识别和连接到特定的无线网络... 当提到 Wi-Fi 网络时,就避不开「SSID」这个术语。简单来说,SSID 就是 Wi-Fi 网络的名称。比如

SpringCloud配置动态更新原理解析

《SpringCloud配置动态更新原理解析》在微服务架构的浩瀚星海中,服务配置的动态更新如同魔法一般,能够让应用在不重启的情况下,实时响应配置的变更,SpringCloud作为微服务架构中的佼佼者,... 目录一、SpringBoot、Cloud配置的读取二、SpringCloud配置动态刷新三、更新@R

使用Java解析JSON数据并提取特定字段的实现步骤(以提取mailNo为例)

《使用Java解析JSON数据并提取特定字段的实现步骤(以提取mailNo为例)》在现代软件开发中,处理JSON数据是一项非常常见的任务,无论是从API接口获取数据,还是将数据存储为JSON格式,解析... 目录1. 背景介绍1.1 jsON简介1.2 实际案例2. 准备工作2.1 环境搭建2.1.1 添加

Oracle的to_date()函数详解

《Oracle的to_date()函数详解》Oracle的to_date()函数用于日期格式转换,需要注意Oracle中不区分大小写的MM和mm格式代码,应使用mi代替分钟,此外,Oracle还支持毫... 目录oracle的to_date()函数一.在使用Oracle的to_date函数来做日期转换二.日

Java 枚举的常用技巧汇总

《Java枚举的常用技巧汇总》在Java中,枚举类型是一种特殊的数据类型,允许定义一组固定的常量,默认情况下,toString方法返回枚举常量的名称,本文提供了一个完整的代码示例,展示了如何在Jav... 目录一、枚举的基本概念1. 什么是枚举?2. 基本枚举示例3. 枚举的优势二、枚举的高级用法1. 枚举