บันทึก training data science EP 3: Matplotlib & Seaborn – แผนที่ดูแพง

บันทึก training data science EP 3: Matplotlib & Seaborn – แผนที่ดูแพง

ตอนก่อนหน้า: บันทึก training data science EP 2: Pandas & Matplotlib – ดูแผนที่ก่อนเดินทาง

ยะโฮ่ ทุกคนฮะ

ครั้งที่แล้วใน EP 2 ผมได้เล่าถึงการดูภาพรวมของข้อมูลไปแล้วเนอะ คราวนี้ เรามาลองเล่นกราฟแบบแพงๆ กันบ้างดีกว่าฮะ

ผมขอนิยามคำว่าแพงเป็นว่า ดูดี less is more อะไรแบบนี้ แต่มันไม่ได้ less จนคนดูไม่ได้อะไรจากกราฟเลยนะฮะ กลับกัน กราฟที่ดีที่แพง ควรจะแสดงข้อมูลให้คนใช้งานเห็นปราดเดียวเข้าใจภาพรวม และปราดที่สองที่สาม เค้าได้ insight ไปฮะ

ทวนกันนิดนึง กราฟที่ผมใช้ไปตอนที่แล้ว คือ lib matplotlib เนอะ เราจะลองอัพเกรดกราฟกันฮะ

เทรนด์มากกว่าหนึ่ง

กราฟเส้นมักใช้แสดงผลข้อมูลในเชิงความเปลี่ยนแปลงใช่มั้ยฮะ ถ้าเรามีข้อมูลพวกนี้ในหลายๆ มิติและอยากจะแสดงเส้นเทรนด์ของแต่ละมิติ หรือเรียกว่า breakdown ลองใช้วิธีนี้ดูนะฮะ

จากข้อมูล titanic ที่แล้ว เราเลือก column sex, age, และ fare โดยที่ age มีค่าฮะ

selector = titanic[titanic.Age.isna() == False][[titanic.Sex.name, titanic.Age.name, titanic.Fare.name]]
selector

จากนั้นหาค่าเฉลี่ยของ fare จากการแบ่งกลุ่มของ sex และ age ตรงนี้จะได้ผลลัพท์คือ DataFrameGroupByเราเอามาแปลงให้เป็น DataFrame ซึ่งตอนแรกจะมี index เป็น sex กับ age ผมไม่ต้องการแบบนั้นจึงใช้ .reset_index() เพราะผมต้องการอ้างอิง column ด้วยชื่อ แทนที่จะเป็น .index ฮะ

grouper = pd.DataFrame(selector.groupby([titanic.Sex.name, titanic.Age.name]).Fare.mean()).reset_index()
grouper

ถึงตรงนี้ ผมจะใช้ plt.subplots() เพื่อสร้างสองกราฟมารวมกันฮะ ฟังก์ชันนี้จะส่งค่ากลับมาสองตัวคือ fig หรือพื้นหลังกราฟ และ ax คือ เบื้องหน้าของกราฟ

(fig, ax) = plt.subplots()

ผมใช้ ax สร้างกราฟอันแรกเป็น age แกนนอน และ fare แกนตั้งของ female กำหนดสีเป็น red และตั้งชื่อกราฟว่า female ส่วนกราฟที่สองทำเหมือนกันเลยแต่เป็นข้อมูลของ male ใช้สี blue และชื่อกราฟคือ male

ax.plot(grouper[grouper.Sex == 'female'].Age, grouper[grouper.Sex == 'female'].Fare, c='red', label='female')
ax.plot(grouper[grouper.Sex == 'male'].Age, grouper[grouper.Sex == 'male'].Fare, c='blue', label='male')

.legend() เอาไว้แปะป้ายชื่อกราฟแต่ละเส้นที่เห็นเป็นกรอบเล็กๆ ที่ขวาบนน่ะฮะ

plt.legend()

.xlabel() และ .ylabel() ใช้ตั้งชื่อแกนนอนและแกนตั้งตามลำดับฮะ

plt.xlabel('Age')
plt.ylabel('Fare')

ว่าไป ไม่มีลูกเรือเป็นผู้หญิงเกิน 60 กว่าปีเลยเน้อ

กระจายแบบสเปกตรัม

Scatter graph เป็นกราฟลักษณะเป็นจุดกระจายๆ ไปตามค่าของข้อมูลนะฮะ โดยพื้นฐาน มันจะบอกข้อมูลสองข้อจากพิกัดแกน x และ y แต่คราวนี้ เราจะมีข้อที่สามเข้ามาฮะ

selector = titanic[titanic.Age.isna() == False][[titanic.Pclass.name, titanic.Age.name, titanic.Fare.name]]
selector

เราเลือกแถวที่ Age มีค่า จากนั้นเลือก column Pclass, Age และ Fare

selector.plot.scatter(x = selector.Age.name, y = selector.Fare.name, c = selector.Pclass, cmap = plt.cm.rainbow)

จากนั้นใช้คำสั่ง .plot.scatter() จาก DataFrame ได้เลย ระบุแกน x และ y

ทีนี้ ผมใส่ parameter c และ cmap เข้ามาด้วย มันมีความหมายแบบนี้ฮะ

  • c
    color คือเราจะระบุ column ที่ใช้แยกเป็นสีๆ
  • cmap
    ใช้ class ของ matplotlib.cm เอาไว้ระบุสเกลของสีที่เราระบุในตัวแปร c ข้างบนฮะ

ด้วยความที่ Pclass มีสามค่า เลยกลายเป็นกราฟกระจายที่มีสามสี ม่วง แดง และเขียวฮะ


ความสัมพันธ์ทับซ้อน

คราวนี้มาลองเล่น stacked bar graph บ้าง มันคือกราฟแท่งที่ซ้อนกันเพื่อแสดงค่าผลรวมของแต่ละ breakdown ได้ฮะ

selector = titanic[[titanic.Pclass.name, titanic.Sex.name]]
selector

เริ่มต้นด้วยว่า ผมอยากจะดูกราฟระหว่าง Pclass และ sex

จากนั้นเอามา .groupby() ตามสอง column นั้นและหาจำนวนสมาชิกของแต่ละกลุ่ม

grouper = pd.DataFrame(selector.groupby([selector.Pclass.name, selector.Sex.name]).Sex.count()) \
  :

ในตอนนี้ เราจะได้ DataFrameGroupBy ก็เอามาทำให้เป็น DataFrame ด้วย .DataFrame() ซึ่งมันจะมี index เพียงที่มี columns เพียงตัวเดียวคือ Sex ที่ได้จาก .count() ส่วน Index คือ column ที่ใช้ .groupby() ฮะ ได้แก่ Pclass และ Sex

  :
    .add_suffix("_count") \
    .reset_index() \
    .set_index(selector.Pclass.name)

ทีนี้ ผมจะสร้าง stacked bar โดยให้ Pclass เป็นฐานอ้างอิงหรือก็คือ แกน x ผมจะต้องสั่ง .reset_index()แล้วค่อย .set_index("Pclass") แต่เราจะทำแบบนั้นทันทีไม่ได้ฮะ เพราะถ้าสั่ง .reset_index() เราจะเจอว่ามี column Sex ซ้ำกัน จึงใช้ .add_suffix() เพื่อเปลี่ยน column Sex ในตอนแรกให้มีคำว่า “_count” ต่อท้าย กลายเป็น “Sex_count” ถึงจะสั่ง .reset_index() ได้นั่นเองฮะ

อ่านเพิ่มเรื่อง .add_suffix() ได้ที่นี่ฮะ

แล้วค่อยสั่ง .set_index() ทีนี้เราจะได้ผลลัพท์ตามรูปข้างล่างฮะ

เมื่อเราคำนวณเสร็จแล้ว ก็เอามาทำกราฟได้แล้วล่ะ

(fig, ax) = plt.subplots()
bar_width = 0.5
_class = grouper[grouper.Sex == 'male'].index
_male_count = grouper[grouper.Sex == 'male'].Sex_count
_female_count = grouper[grouper.Sex == 'female'].Sex_count

เราใช้ .subplots() เพื่อประกอบร่างกราฟขึ้นมาฮะ จากนั้นผมสร้างตัวแปรมาอีกสี่ตัวคือ

  • bar_width
    เก็บค่าความกว้างของแท่งกราฟ
  • _class
    สำหรับเก็บ Index หรือก็คือค่า Pclass
  • _male_count
    เก็บค่า Sex_count ของแถวที่มี Sex เป็น male
  • _female_count
    เก็บค่า Sex_count ของแถวที่มี Sex เป็น female
# add male
ax.bar(_class, _male_count, bar_width, label = 'male')

ผมวางกราฟของ male ไปก่อนด้วยคำสั่ง ax.bar() โดยมี parameters สี่ตัว ได้แก่

  1. ค่าแกน x ก็คือ ตัวแปร _class
  2. ค่าแกน y คือ ตัวแปร _male_count
  3. ค่าความกว้างของแท่ง คือตัวแปร bar_width ที่ประกาศเตรียมไว้แล้ว
  4. label = “male” อันนี้ไว้สำหรับใช้ในกล่องอธิบายกราฟแต่ละแท่งฮะ
# add female
ax.bar(_class, _female_count, bar_width, bottom = list(_male_count), label = 'female')

เสร็จจาก male ก็มาที่ female แต่รอบนี้จะต้องวางแท่ง female ไว้บน male จะได้กลายเป็น stacked bar chart นะฮะ ดังนั้นเราจะมี parameters เพิ่ม ดังนี้

  1. ค่าแกน x เหมือนเดิมก็คือ _class
  2. ค่าแกน y จะเป็น _female_count
  3. ความกว้างของแท่ง เป็น bar_width
  4. bottom = x ตรงนี้แหละจะเป็นตัวบอกว่าแท่งของเราจะยกขึ้นเท่าไหร่ ต้องการเป็นตัวแปรประเภทค่าคงที่ Int หรือ Float ไม่ก็เป็น Array ทีนี้เราต้องการให้มันยกสูงเท่ากับค่า _male_count แต่มันเป็น Series เราต้องเปลี่ยนให้เป็น Array ด้วยคำสั่ง list() ฮะ
  5. label = “female” สำหรับกล่องอธิบายกราฟ

จากนั้น ผมก็ใช้ .set_xticks() เพื่อแปะค่าแกน x ตามตัวแปร _class

ax.set_xticks(_class)

ใช้ .legend() เพื่อวาดกล่องอธิบายกราฟ ที่เราเห็นตรงซ้ายบนน่ะแหละฮะ

ax.legend()

จากนั้นก็แปะตัวเลขของแต่ละกล่องลงไปในกราฟฮะ

for p in ax.patches:
    x, y = p.get_xy()
    h = p.get_height()
    ax.annotate(str(h), ((x + p.get_width()/2), (h+y)), ha='center', va='top', color='white')

.patches จะได้เป็น list ของชิ้นส่วนกราฟ ในเคสนี้ มันคือเจ้ากล่องสี่เหลี่ยมผืนผ้าที่ประกอบๆ กันในกราฟนี้ฮะ เราจะเช็คชิ้นส่วนที่ว่าแต่ละชิ้น แล้วใช้คำสั่ง .annotate() เพื่อแปะข้อความไปฮะ โดย parameters ที่ผมใช้ คือ

  • ตัวแปรแรก เป็น str(h) คือแปลงค่า h เป็นข้อความ โดยที่ h มาจาก p.get_height() หรือก็คือ ความสูงของชิ้นส่วนนั้น มองย้อนกลับไป มันก็คือค่า _male_count หรือ _female_count นั่นเองฮะ
  • ตัวแปรที่สอง เป็นพิกัดของข้อความ ประกอบด้วย
  • x มีค่า x + p.get_width()/2 แปลว่า หาค่าแกน x ที่มุมซ้าย บวกด้วยความกว้างของชิ้นส่วน ค่าที่ได้คือกึ่งกลางของแท่งกราฟฮะ
  • y มีค่า h + y แปลว่าเอาความสูงของแท่งกราฟนั้น มาบวกกับค่าแกน y ที่มุมซ้าย หรือฐานของแท่ง ค่าที่ได้คือ ขอบบนของแท่งฮะ
  • ha คือ horizontal alignment ผมใช้ center เพื่อวางให้ข้อความอยู่ตรงกลาง
  • va คือ vertical alignment ผมใช้ top เพื่อวางให้ข้อความอยู่ขอบบน
plt.xlabel('Ticket class')
plt.ylabel('crew count')

จากนั้นก็ .xlabel() และ .ylabel() แปะข้อความประจำแกน x และ y ฮะ

สำหรับข้อมูลเพิ่มเติม ลองอ่านตามลิงก์พวกนี้ฮะ

Seaborn

seaborn เป็น library อีกตัวที่เอาไว้สร้างกราฟสวยๆ เรามักจะย่อชื่อเป็น sns ซึ่งมาจากชื่อ Samuel Norman Seaborn เป็นการล้อชื่อตัวละครเพราะมีคำว่า Seaborn นั่นเองฮะ (ที่มา)

เราลองมาดูตัวอย่างกันคร่าวๆ ฮะ

Joint plot

sns.jointplot() สร้างกราฟจากข้อมูลสองชุด เพื่อหาความสัมพันธ์ในสองรูปแบบ ที่กำหนดด้วย parameter “kind”

kde (kernel density estimate) ระบุความหนาแน่นของข้อมูลพร้อมกราฟเส้น

scatter ระบุการกระจายของข้อมูลพร้อมกราฟฮิสโทแกรม

reg (regression) ระบุเส้นแนวโน้ม พร้อมฮิสโทรแกรมและกราฟเส้น

resid (residual) บอกแนวโน้ม และฮิสโทแกรมของค่าผิดพลาดของเส้นแนวโน้ม

hex (hexagon) เป็นกึ่ง scatter กึ่ง heatmap ที่แสดงผลเป็นรูปหกเหลี่ยม

ข้อมูลอ้างอิง http://alanpryorjr.com/visualizations/seaborn/jointplot/jointplot/

Pair plot

sns.pairplot() หยิบทุก column ที่เป็นตัวเลขมา plot graph เป็นคู่ๆ

Relational plot

sns.relplot() ขั้นกว่าของ scatter plot ที่เราสามารถ breakdown ได้อีกสองระดับด้วย parameter “hue” ที่แยกสีและ “size” ที่แยกตามขนาด

Violin plot

sns.violinplot() ที่แสดงความหนาแน่นในแต่ละ breakdown

Heat map

sns.heatmap() ที่สร้างแผนภูมิความถี่ด้วยสี

ด้วยความที่ ฟังก์ชันนี้ต้องการตัวแปรที่มีค่าเป็นตัวเลขทั้งหมด ผมจำเป็นต้องประมวลผลค่านิดนึงฮะ

เริ่มจาก ผมต้องการดูว่าแต่ละ Pclass มี Sex ไหนบ้างที่ Survived อยู่จำนวนเท่าไหร่ จึงต้อง .groupby() และ .sum() ค่า Survived

จากนั้นใช้ .pivot() เพื่อสร้าง pivot table ฮะ

และสุดท้ายก็จะได้ heatmap พร้อมตัวเลขระบุ ที่ได้จากการใส่ parameter annot เป็น True

ข้อมูลเพิ่มเติม https://seaborn.pydata.org/api.html


นึกว่าจะเขียนสั้น แต่เอาเข้าจริง เนื้อหาเยอะเหมือนกันนะฮะ

สรุปว่า seaborn ใช้ทำกราฟง่าย และสวย แต่ถ้าอยาก custom มันเยอะๆ ผมว่า matplotlib ตอบโจทย์กว่าฮะ

พอก่อนเนอะ เริ่มเหนื่อยละ

คราวหน้าจะเป็นอะไร จะมาเล่าให้อ่านนะฮะ

บาย~

ตอนต่อไป: บันทึก training data science EP 4: Scikit-learn & Linear Regression – แนวโน้มของเส้นตรง

Show Comments