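To draw on a specific subplot, two indices are needed (row, column), so axes[0, 0] for the first subplot. With a 3x3 grid, plt.subplots() returns a 2D NumPy array of Axes, so axes[0] is the whole first row (itself a NumPy array), not a single Axes. The error message comes from using ax1 = axes[0] instead of ax1 = axes[0, 0].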
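To see why axes[0] fails, it can help to inspect what plt.subplots() returns. This is a minimal sketch, assuming a 3x3 grid as in the question; the Agg backend is an assumption added only so that it runs without a display:

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend, so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes

fig, axes = plt.subplots(3, 3)
print(type(axes), axes.shape)        # the full 3x3 grid: a 2D NumPy array of Axes
print(type(axes[0]), axes[0].shape)  # one index selects a whole row: still a NumPy array
print(isinstance(axes[0, 0], Axes))  # two indices select a single Axes

ax1 = axes[0, 0]
ax1.hist([1, 2, 2, 3])  # works: Axes has a .hist() method
# axes[0].hist(...) would raise AttributeError: a NumPy array has no .hist() method
```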
Now, to create a stacked histogram via ax.hist(), all the y-data need to be provided in a single call, as a list with one array per group, together with stacked=True. The code below shows how this can be done starting from the result of groupby. Also note that when your values are discrete, it is important to set the bin boundaries explicitly, making sure that every value falls strictly between two boundaries. Setting the boundaries at the halves (for example np.arange(0.5, 10) for the integers 1 to 9) is one way.
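The effect of the bin boundaries, and what the stacked bars add up to, can be checked with np.histogram alone. This sketch reuses the example values from the code; splitting the groups by hand is only for illustration and is not meant to describe how ax.hist() works internally:

```python
import numpy as np

# same discrete example values and categories as in the code
y = [1, 5, 9, 2, 4, 2, 5, 6, 1]
cat = ['A', 'B', 'B', 'B', 'A', 'B', 'B', 'B', 'B']

# boundaries at the halves (0.5, 1.5, ..., 9.5): exactly one bin per integer 1..9
half_edges = np.arange(0.5, 10)
counts, edges = np.histogram(y, bins=half_edges)
print(counts)  # how often each of the values 1..9 occurs

# default: 10 equal-width bins between min(y) and max(y);
# these edges do not line up with the integer values
_, default_edges = np.histogram(y, bins=10)
print(default_edges)

# stacking: each group is counted with the same bins and the bars are drawn
# on top of each other, so the top of each stack equals the overall count
group_counts = {
    g: np.histogram([v for v, c in zip(y, cat) if c == g], bins=half_edges)[0]
    for g in ('A', 'B')
}
print(group_counts['A'] + group_counts['B'])
```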
Things can be simplified a lot using seaborn's histplot(). Here is a breakdown of the parameters used:
- data=df: the dataframe
- y='y': the dataframe column for the histogram; use x= (instead of y=) for a vertical histogram
- hue='cat': the dataframe column used to split the data into multiple groups
- palette=mycolorsdict: defines the coloring; there are many ways to specify a palette, one of which is a dictionary mapping each hue value to a color
- discrete=True: tells seaborn the data are discrete, so it uses a bin width of 1 with each bar centered on its value (i.e. boundaries at the halves)
- multiple='stack': stacks the bars of the different hue categories on top of each other
- alpha=1: by default seaborn sets an alpha of 0.75 here; optionally this can be changed, e.g. to 1 for fully opaque bars
- ax=axes[0, 1]: draw on the 2nd subplot of the 1st row
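import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# the 'seaborn-whitegrid' Matplotlib style name was deprecated in Matplotlib 3.6 and later removed;
# setting the style through seaborn works across versions
sns.set_style('whitegrid')

y = [1, 5, 9, 2, 4, 2, 5, 6, 1]
cat = ['A', 'B', 'B', 'B', 'A', 'B', 'B', 'B', 'B']
df = pd.DataFrame({'y': y, 'cat': cat})

fig, axes = plt.subplots(3, 3, figsize=(20, 10), constrained_layout=True)
fig.suptitle('Histograms')
mycolorsdict = {'A': 'magenta', 'B': 'blue'}

# group by the column name (not a one-element list): with pandas >= 2.0 a list
# would yield tuple keys like ('A',), which breaks the mycolorsdict lookup
groups = df.groupby('cat')

# matplotlib: one array per group in a single call; stacked=True stacks them
axes[0, 0].hist([batch['y'] for _, batch in groups],
                label=[key for key, _ in groups],
                color=[mycolorsdict[key] for key, _ in groups],
                edgecolor='black', orientation='horizontal', stacked=True,
                bins=np.arange(0.5, 10))  # boundaries at the halves: 0.5, 1.5, ..., 9.5
axes[0, 0].legend()

# seaborn: the same stacked histogram in a single call
sns.histplot(data=df, y='y', hue='cat', palette=mycolorsdict, discrete=True,
             multiple='stack', alpha=1, ax=axes[0, 1])
plt.show()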