I'm a bit confused about how this code works:

fig, axes = plt.subplots(nrows=2, ncols=2)
plt.show()

How does fig, axes work in this case? What does it do?

Also, why wouldn't this work to do the same thing:

fig = plt.figure()
axes = fig.subplots(nrows=2, ncols=2)

Current answer

Convert the axes array to 1D

Generating subplots with plt.subplots(nrows, ncols), where both nrows and ncols are greater than 1, returns a nested (2-D) array of <AxesSubplot:> objects.

It is not necessary to flatten axes when either nrows=1 or ncols=1, because axes will already be 1-dimensional as a result of the default parameter squeeze=True. (With nrows=1 and ncols=1, a single axes object is returned rather than an array.)

The easiest way to access the objects is to convert the array to 1 dimension with .ravel(), .flatten(), or .flat.

.ravel vs. .flatten: flatten always returns a copy, while ravel returns a view of the original array whenever possible.

Once the array of axes is converted to 1-D, there are a number of ways to plot.

This answer is also relevant to seaborn axes-level plots, which have the ax= parameter (e.g. sns.barplot(…, ax=ax[0])). seaborn is a high-level API for matplotlib. See "Figure-level vs. axes-level functions" and "seaborn is not plotting within defined subplots".

import matplotlib.pyplot as plt
import numpy as np  # sample data only

# example of data
rads = np.arange(0, 2*np.pi, 0.01)
y_data = np.array([np.sin(t*rads) for t in range(1, 5)])
x_data = [rads, rads, rads, rads]

# Generate figure and its subplots
fig, axes = plt.subplots(nrows=2, ncols=2)

# axes before: a 2 x 2 array
# array([[<AxesSubplot:>, <AxesSubplot:>],
#        [<AxesSubplot:>, <AxesSubplot:>]], dtype=object)

# convert the array to 1 dimension
axes = axes.ravel()

# axes after: a flat array of 4 axes
# array([<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>],
#       dtype=object)
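
As a quick check of the shapes described above, here is a minimal sketch. It assumes the non-interactive Agg backend so that it runs without a display, and the variable names are illustrative only. It shows how squeeze affects the returned array and how ravel and flatten differ.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, no window is opened
import matplotlib.pyplot as plt
import numpy as np

# 2 x 2 grid: a 2-D object array of axes
fig, axes_2d = plt.subplots(nrows=2, ncols=2)
shape_2d = axes_2d.shape

# a single row: squeeze=True (the default) drops the length-1 dimension
fig_row, axes_row = plt.subplots(nrows=1, ncols=3)
shape_row = axes_row.shape

# squeeze=False always returns a 2-D array, even for a single row
fig_ns, axes_ns = plt.subplots(nrows=1, ncols=3, squeeze=False)
shape_ns = axes_ns.shape

# ravel returns a view when possible; flatten always returns a copy
is_view = np.shares_memory(axes_2d.ravel(), axes_2d)
is_copy = not np.shares_memory(axes_2d.flatten(), axes_2d)

plt.close("all")
```

Passing squeeze=False can be handy when nrows and ncols are computed at run time, since the result can then always be flattened the same way.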

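Iterate through the flattened array. If there are more subplots than data, this will result in IndexError: list index out of range. In that case, try the third option (index the data and axes) instead, or select a subset of the axes (e.g. axes[:-2]).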

for i, ax in enumerate(axes):
    ax.plot(x_data[i], y_data[i])

Access each axes by index

axes[0].plot(x_data[0], y_data[0])
axes[1].plot(x_data[1], y_data[1])
axes[2].plot(x_data[2], y_data[2])
axes[3].plot(x_data[3], y_data[3])

Index the data and axes

for i in range(len(x_data)):
    axes[i].plot(x_data[i], y_data[i])

Zip the axes and data together, then iterate through the list of tuples.

for ax, x, y in zip(axes, x_data, y_data):
    ax.plot(x, y)
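
When there are more subplots than datasets, zip stops at the shorter sequence, so no IndexError is raised and the leftover axes simply stay empty. The sketch below assumes three datasets on a 2 x 2 grid and hides the unused axes with set_visible(False); hiding them is one possible choice and not part of the answer above.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, no window is opened
import matplotlib.pyplot as plt
import numpy as np

rads = np.arange(0, 2 * np.pi, 0.01)
y_data = [np.sin(t * rads) for t in range(1, 4)]  # only 3 datasets
x_data = [rads] * len(y_data)

fig, axes = plt.subplots(nrows=2, ncols=2)  # 4 axes
axes = axes.ravel()

# zip ends after the third pair, so the fourth axes is never touched
for ax, x, y in zip(axes, x_data, y_data):
    ax.plot(x, y)

# hide every axes that received no data
for ax in axes[len(y_data):]:
    ax.set_visible(False)

n_plotted = sum(len(ax.lines) > 0 for ax in axes)
n_hidden = sum(not ax.get_visible() for ax in axes)
plt.close(fig)
```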

Another option is to assign each axes to its own variable: fig, (ax1, ax2, ax3) = plt.subplots(1, 3). As written, however, this only works when either nrows=1 or ncols=1, because the unpacking pattern must match the shape of the array returned by plt.subplots. For a 2 x 2 array it becomes fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2), which quickly gets cumbersome. This option is most useful for two subplots (e.g. fig, (ax1, ax2) = plt.subplots(1, 2) or fig, (ax1, ax2) = plt.subplots(2, 1)). For more subplots, it's more efficient to flatten and iterate through the array of axes.
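
To illustrate why the unpacking pattern has to mirror the array shape, here is a small sketch. It assumes the Agg backend, and the names a, b, c, d are illustrative.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, no window is opened
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2)

# nested unpacking mirrors the 2 x 2 shape: one tuple per row
(ax1, ax2), (ax3, ax4) = axes
matches_grid = ax1 is axes[0, 0] and ax4 is axes[1, 1]

# a flat pattern fails, because iterating a 2-D array yields 2 rows, not 4 axes
try:
    a, b, c, d = axes
    flat_unpack_failed = False
except ValueError:
    flat_unpack_failed = True

# flattening first makes the flat pattern valid
a, b, c, d = axes.flat
flat_ok = a is ax1 and d is ax4

plt.close(fig)
```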

Other answers

You can also unpack the axes in the subplots call, and set whether you want to share the x and y axes between the subplots.

Like this:

import matplotlib.pyplot as plt
# fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)
fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)
ax1, ax2, ax3, ax4 = axes.flatten()

ax1.plot(range(10), 'r')
ax2.plot(range(10), 'b')
ax3.plot(range(10), 'g')
ax4.plot(range(10), 'k')
plt.show()
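
To see what sharing the axes means in practice, the following sketch compares a shared grid with an unshared one. It assumes the Agg backend, and the names u1 and u2 are illustrative. Changing the x limits on one shared axes is reflected on the others.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, no window is opened
import matplotlib.pyplot as plt

# shared: all four axes use the same x and y limits
fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)
ax1, ax2, ax3, ax4 = axes.flatten()
ax1.plot(range(10), 'r')
ax1.set_xlim(0, 5)
shared_xlim = tuple(ax4.get_xlim())  # follows ax1

# not shared: each axes keeps its own limits
fig2, (u1, u2) = plt.subplots(nrows=1, ncols=2)
u1.set_xlim(0, 5)
unshared_xlim = tuple(u2.get_xlim())  # unaffected by u1

plt.close("all")
```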

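If you really want to use a loop, do the following:

import matplotlib.pyplot as plt

def plot(data):
    # Assumes data is a dict of dicts, e.g. {"a": {"x": 1, "y": 2}, ...}.
    # For a dict of pandas Series, use data[k].values (no call) instead.
    fig = plt.figure(figsize=(100, 100))
    for idx, k in enumerate(data.keys(), 1):
        x, y = list(data[k].keys()), list(data[k].values())
        # 63 x 10 grid (up to 630 subplots); the subplot index is 1-based
        plt.subplot(63, 10, idx)
        plt.bar(x, y)
    plt.show()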

You can use the following:

import numpy as np
import matplotlib.pyplot as plt

fig, _ = plt.subplots(nrows=2, ncols=2)

for i, ax in enumerate(fig.axes):
  ax.plot(np.sin(np.linspace(0,2*np.pi,100) + np.pi/2*i))

Or, alternatively, use the second variable that plt.subplots returns:

fig, ax_mat = plt.subplots(nrows=2, ncols=2)
for i, ax in enumerate(ax_mat.flatten()):
    ...

ax_mat is a matrix of the axes. Its shape is nrows x ncols.
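
The relationship between fig.axes and ax_mat can be checked with a short sketch. It assumes the Agg backend and a 2 x 3 grid chosen only for illustration. fig.axes lists the axes in the order they were created, which for plt.subplots is row by row.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, no window is opened
import matplotlib.pyplot as plt

fig, ax_mat = plt.subplots(nrows=2, ncols=3)
shape = ax_mat.shape  # (nrows, ncols)

# fig.axes is a flat list holding the same objects, in row-major order
same_order = all(a is b for a, b in zip(fig.axes, ax_mat.flatten()))

# converting a flat index to (row, col) for a grid with 3 columns
row, col = divmod(4, ax_mat.shape[1])
same_axes = ax_mat[row, col] is fig.axes[4]

plt.close(fig)
```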

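There are several ways to do it. The subplots method creates the figure along with the subplots, which are then stored in the ax array. For example: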

import matplotlib.pyplot as plt

x = range(10)
y = range(10)

fig, ax = plt.subplots(nrows=2, ncols=2)

for row in ax:
    for col in row:
        col.plot(x, y)

plt.show()

However, something like this will also work. It's not as "clean", though, since you are creating a figure with subplots and then adding on top of them:

fig = plt.figure()

plt.subplot(2, 2, 1)
plt.plot(x, y)

plt.subplot(2, 2, 2)
plt.plot(x, y)

plt.subplot(2, 2, 3)
plt.plot(x, y)

plt.subplot(2, 2, 4)
plt.plot(x, y)

plt.show()
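
The four repeated plt.subplot calls above can also be written as a loop. This is only a sketch, assuming the Agg backend; plt.subplot takes a 1-based position that counts across each row first.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, no window is opened
import matplotlib.pyplot as plt

x = range(10)
y = range(10)

fig = plt.figure()

# positions 1..4 fill the 2 x 2 grid row by row
for idx in range(1, 5):
    ax = plt.subplot(2, 2, idx)  # returns the axes created at that position
    ax.plot(x, y)

n_axes = len(fig.axes)
plt.close(fig)
```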