Scatter plots are among the most commonly used data visualization tools in data science and machine learning. They let us see whether two or more variables are related and how strong that relationship is.

In this pandas tutorial, I’ll show you a simple method to draw multiple scatter plots in a single plot.

We use the Medical Cost Personal Datasets from Kaggle. The data set has 1338 rows and 7 columns:

import pandas as pd
from matplotlib import pyplot as plt
from matplotlib.pyplot import figure

# Load the data set and drop any rows with missing values
df = pd.read_csv("/content/insurance.csv")
df = df.dropna()
df.head()
Pandas Data frame for scatter plot
  • age: age of the primary beneficiary
  • sex: gender of the insurance contractor (female or male)
  • bmi: body mass index (kg / m ^ 2), the weight in kilograms divided by the square of the height in metres; an objective index of whether body weight is high or low relative to height, ideally 18.5 to 24.9
  • children: number of children covered by health insurance (number of dependents)
  • smoker: whether the beneficiary smokes (yes or no)
  • region: the beneficiary's residential area in the US (northeast, southeast, southwest, or northwest)
  • charges: individual medical costs billed by health insurance

Now that the data is loaded, we can create the scatter plot of our insurance data. In particular, we will plot charges against age and bmi for the medical cost analysis, and use the sex, smoker, and children columns to style each point.

We will now plot multiple scatter plots by superimposing them. This method needs no special function: we draw the point sets one on top of the other on the same axes, so they share a single scale. The following code creates the plot.

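# Fill colour encodes sex; edge colour encodes smoker status
color = df['sex'].apply(lambda x: 'navy' if x == 'male' else 'gold')
bcolor = df['smoker'].apply(lambda x: 'red' if x == 'yes' else 'green')

# Set the style before creating the figure so it applies to the whole figure
plt.style.use('ggplot')
figure(figsize=(10, 8), dpi=80)

plt.title('Relation between age, bmi and charges')
plt.xlabel('age and bmi')
plt.ylabel('charges')

# Circles: age vs. charges; hexagons: bmi vs. charges.
# The edge width of each marker grows with the number of children.
plt.scatter(x=df['age'], y=df['charges'], s=100, c=color, alpha=0.6, marker='o', edgecolors=bcolor, linewidth=df['children'], label='age vs. charges')
plt.scatter(x=df['bmi'], y=df['charges'], s=100, c=color, alpha=0.5, marker='h', edgecolors=bcolor, linewidth=df['children'], label='bmi vs. charges')

# The labels passed to scatter() give the legend its entries
plt.legend(loc='upper right')
plt.tight_layout()
plt.show()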

That is only a handful of lines of code, and the two plotting calls are nearly identical: you simply call the function for the chart type you want, which here is scatter. When drawing a scatter plot with matplotlib, you always have to pass the x and y values as parameters.
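pandas also has its own scatter wrapper, DataFrame.plot.scatter, which takes the x and y column names and can draw onto an existing Axes through its ax parameter. The sketch below is a minimal illustration of that route, not the code used for the figure in this tutorial. The toy DataFrame is hypothetical, with made-up values standing in for insurance.csv.

```python
import pandas as pd
from matplotlib import pyplot as plt

# Hypothetical toy rows standing in for insurance.csv (values are made up).
toy = pd.DataFrame({
    "age": [19, 33, 46, 60],
    "bmi": [27.9, 22.7, 33.4, 25.8],
    "charges": [16884.9, 21984.5, 8240.6, 28923.1],
})

fig, ax = plt.subplots(figsize=(10, 8))

# Both calls draw onto the same Axes because ax=ax is passed each time.
toy.plot.scatter(x="age", y="charges", ax=ax, color="navy",
                 marker="o", label="age vs. charges")
toy.plot.scatter(x="bmi", y="charges", ax=ax, color="gold",
                 marker="h", label="bmi vs. charges")

# pandas labels the x-axis with the last column plotted, so set it explicitly.
ax.set_xlabel("age and bmi")
ax.set_ylabel("charges")
ax.legend(loc="upper right")
```

In a notebook or script, call plt.show() afterwards to display the figure.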

This scatter plot shows two relationships for the people in the data set: age against charges, and bmi against charges.

Multiple scatter plot

The y-axis shows the charges and the x-axis shows both age and bmi. Each person in the data set appears twice: as a circle at their age and as a hexagon at their bmi.
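The fill and edge colours of each marker come from the sex and smoker columns. As a hedged alternative to the lambda-based mapping, the sketch below builds the same kind of per-point colour lists with Series.map and a dictionary. The toy rows are hypothetical. One difference to keep in mind: a dictionary maps any value it does not list to NaN, whereas the lambda sends every non-matching value to the fallback colour.

```python
import pandas as pd

# Hypothetical toy rows; the real data set has the same two columns.
toy = pd.DataFrame({
    "sex": ["male", "female", "female"],
    "smoker": ["yes", "no", "yes"],
})

# One colour per row, looked up from the category value.
fill_colors = toy["sex"].map({"male": "navy", "female": "gold"})
edge_colors = toy["smoker"].map({"yes": "red", "no": "green"})

print(fill_colors.tolist())
print(edge_colors.tolist())
```

The resulting Series can be passed to the c and edgecolors parameters of scatter in the same way as color and bcolor.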

We drew the two scatter plots above by calling scatter() twice, which illustrates that we can paint any number of charts onto the same canvas.

You can keep adding plt.scatter calls as many times as you like. With multiple scatter plots, give each one a different color, marker, edgecolor, or alpha so that you can tell them apart.
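When colours carry meaning, a legend that explains them helps the reader tell the groups apart. One possible approach, sketched here under the assumption that the fill colour encodes sex and the edge colour encodes smoker status, is to pass hand-built proxy handles to legend(). The proxy helper and the two plotted points are hypothetical and only serve the illustration.

```python
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D

fig, ax = plt.subplots(figsize=(10, 8))

# Two made-up points, only so the axes are not empty.
ax.scatter([25, 40], [3000, 20000], s=100, c=["navy", "gold"],
           edgecolors=["green", "red"], linewidths=2)

def proxy(label, face, edge):
    """Hypothetical helper: a handle that shows only a marker, no line."""
    return Line2D([], [], linestyle="none", marker="o", markersize=10,
                  markerfacecolor=face, markeredgecolor=edge,
                  markeredgewidth=2, label=label)

# One legend entry per encoding rather than per scatter call.
handles = [
    proxy("male (fill)", "navy", "none"),
    proxy("female (fill)", "gold", "none"),
    proxy("smoker (edge)", "none", "red"),
    proxy("non-smoker (edge)", "none", "green"),
]
ax.legend(handles=handles, loc="upper right")
```

Because the handles are built by hand, the legend stays correct no matter how many scatter calls are layered on the axes.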
