AI | ML/코드

[Data Science] matplotplib을 이용한 데이터셋 가시화

깜태 2021. 2. 17. 17:55
728x90

1. 이미지를 읽어오는 시간이 느려서 thread를 이용해 동시성을 증가시켰습니다. ( thread가 데이터 개수만큼 생기니 발열 주의 )

2. 데이터셋마다 평균 표준편차의 값이 다를 수 있으니, np.arange부분의 x축 설정을 주의하셔야합니다.

 

import cv2
import numpy as np
import glob
import argparse
import matplotlib.pyplot as plt
from scipy import stats
from threading import Thread

parser = argparse.ArgumentParser()
parser.add_argument('--path', help='input your dataset folder')
args = parser.parse_args()

legend = list()

def calc_mean_std(data, mean_list, std_list): # 이미지당 mean, std 계산
    # print(data)
    img = cv2.imread(data, cv2.IMREAD_COLOR)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)/255.0
    if len(img.shape) <2: # 흑백 이미지는 제외
        return
    mean,std = np.mean(img, axis=(0, 1)), np.std(img, axis=(0, 1))
    mean_list.append(mean)
    std_list.append(std)
    # return mean, std

def plot_mean_std(mean, std): # mean,std 를 plt로 plot
    # mean = np.mean(mean, axis=0)
    # std = np.std(std, axis=0)
    x_data = np.arange(-1, 1, 0.001) # x축 범위를 설정
    color = ['r', 'g', 'b']
    global legend
    for i, (_mean, _std) in enumerate(zip(mean, std)):
        legend.append("N({:.3f}, {:.3f})".format(_mean,_std))
        y_data = stats.norm.pdf(x_data, _mean, _std)
        plt.plot(x_data, y_data, color[i], label="{}_{}".format(_mean, _std) )
    
if __name__=="__main__":
    path = args.path
    data_list = tuple(glob.glob(path+'/*.png')) # 불러올 이미지 설정
    print(len(data_list))
    mean_list = list()
    std_list = list()
    # Thread를 이용한 속도개선
    for idx, data in enumerate(data_list):
        thread = Thread(target=calc_mean_std, args=(data, mean_list, std_list))
        thread.start()
        if idx % 1000 == 0:
            print(idx, "is Started")
        # thread.join()
    
    thread.join()
    print(len(mean_list), len(std_list))
    
    mean_list, std_list = np.asarray(mean_list), np.asarray(std_list)
    mean, std = np.mean(mean_list, axis=0), np.std(std_list, axis=0)
    print("Mean : " , mean, "std : ", std)
    
    # 가시화 부분
    plot_mean_std(mean, std)
    plt.title('{}'.format(path.split('/')[-1]))
    plt.legend(legend)
    #plt.imshow()
    plt.savefig('distribution.png')

thread 적용 전
thread 적용 후

[참고] 

[1] matplotlib : m.blog.naver.com/PostView.nhn?blogId=parksehoon1971&logNo=221576978566&proxyReferer=https:%2F%2Fwww.google.com%2F

[2] thread : stackoverflow.com/questions/6893968/how-to-get-the-return-value-from-a-thread-in-python

[3] 

728x90