135 lines
5.1 KiB
Python
135 lines
5.1 KiB
Python
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import os
|
|
import pandas as pd
|
|
import seaborn as sns
|
|
import cv2
|
|
# from .scipy_misc import toimage
|
|
|
|
def plot_matrix(mat, path=None, xticks=None, yticks=None, xlim=None, ylim=None, figsize=(6,4), title=None, xlabel=None, ylabel=None, fontsize=20, cmap="YlGnBu"):
|
|
plt.figure(figsize=figsize)
|
|
# vis = vote_map.reshape(20, -1, 4).max(2)
|
|
with sns.axes_style("white"):
|
|
ax = sns.heatmap(mat, cmap=cmap)
|
|
if xticks is not None:
|
|
plt.xticks(xticks)
|
|
if yticks is not None:
|
|
plt.yticks(yticks)
|
|
if xlim is not None:
|
|
plt.xlim(xlim)
|
|
if ylim is not None:
|
|
plt.ylim(ylim)
|
|
if title is not None:
|
|
plt.title(title, fontsize=fontsize)
|
|
if xlabel is not None:
|
|
plt.xlabel(xlabel, fontsize=fontsize)
|
|
if ylabel is not None:
|
|
plt.ylabel(ylabel, fontsize=fontsize)
|
|
plt.tight_layout()
|
|
if path is not None:
|
|
plt.savefig(path)
|
|
plt.close()
|
|
else:
|
|
plt.show()
|
|
|
|
|
|
|
|
|
|
def plot_shaded(mid, high, low, path=None, xticks=None, yticks=None, xlim=None, ylim=None, figsize=(6,4), title=None, xlabel=None, ylabel=None, fontsize=20, color = '#5f87bc'):
|
|
dim = len(mid)
|
|
plt.figure(figsize=figsize)
|
|
lw = 1
|
|
plt.plot(mid, linewidth=lw, color=color)
|
|
plt.plot(np.zeros(dim), linewidth=1/2., color=color)
|
|
plt.fill_between(range(dim), low, high, linewidth=0.1, alpha=0.5, color=color)
|
|
if xticks is not None:
|
|
plt.xticks(xticks)
|
|
if yticks is not None:
|
|
plt.yticks(yticks)
|
|
if xlim is not None:
|
|
plt.xlim(xlim)
|
|
if ylim is not None:
|
|
plt.ylim(ylim)
|
|
if title is not None:
|
|
plt.title(title, fontsize=fontsize)
|
|
if xlabel is not None:
|
|
plt.xlabel(xlabel, fontsize=fontsize)
|
|
if ylabel is not None:
|
|
plt.ylabel(ylabel, fontsize=fontsize)
|
|
plt.tight_layout()
|
|
if path is not None:
|
|
plt.savefig(path)
|
|
else:
|
|
plt.show()
|
|
|
|
def plot_mean_std(mean, std, path=None, xticks=None, yticks=None, xlim=None, ylim=None, figsize=(6,4), title=None, xlabel=None, ylabel=None):
|
|
plot_shaded(mean, mean+std, mean-std, path=path, xticks=xticks, yticks=yticks, xlim=xlim, ylim=ylim, figsize=figsize, title=title, xlabel=xlabel, ylabel=ylabel)
|
|
|
|
|
|
def plot_lines(lines, path=None, legends=[], xticks=None, yticks=None, xlim=None, ylim=None, figsize=(6,4), linewidth=1, title=None, xlabel=None, ylabel=None, fontsize=20):
|
|
# color = '#5f87bc'
|
|
plt.figure(figsize=figsize)
|
|
if len(legends) > 0:
|
|
for line, legend in zip(lines, legends):
|
|
plt.plot(line, label=legend, linewidth=linewidth)
|
|
plt.legend()
|
|
else:
|
|
plt.plot(lines, linewidth=linewidth)
|
|
if xticks is not None:
|
|
plt.xticks(xticks)
|
|
if yticks is not None:
|
|
plt.yticks(yticks)
|
|
if xlim is not None:
|
|
plt.xlim(xlim)
|
|
if ylim is not None:
|
|
plt.ylim(ylim)
|
|
if title is not None:
|
|
plt.title(title, fontsize=fontsize)
|
|
if xlabel is not None:
|
|
plt.xlabel(xlabel, fontsize=fontsize)
|
|
if ylabel is not None:
|
|
plt.ylabel(ylabel, fontsize=fontsize)
|
|
plt.tight_layout()
|
|
if path is not None:
|
|
plt.savefig(path)
|
|
else:
|
|
plt.show()
|
|
|
|
def plot_distrib(dist1, dist2=None, path=None, figsize=(6,4), xticks=None, yticks=None, xlim=None, ylim=None, linewidth=1, title=None, xlabel=None, ylabel=None, fontsize=20, fpr_shade=True, color='#5f87bc'):
|
|
plt.figure(figsize=figsize)
|
|
g1 = sns.distplot(pd.Series(data=dist1, name=''), color=color, hist=False, kde_kws={"shade": True, 'linewidth': linewidth})
|
|
if xlim is not None:
|
|
g1.set(xlim=xlim)
|
|
# g1.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[], xlim=xlim, ylim=ylim)
|
|
# g1.set(xlabel=None, ylabel=None)
|
|
if dist2 is not None:
|
|
g2 = sns.distplot(pd.Series(data=dist2, name=''), color='#444444', hist=False, kde_kws={"shade": True, 'linewidth': linewidth},)
|
|
# g2.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[], xlim=xlim, ylim=ylim)
|
|
# g2.set(xlabel=None, ylabel=None)
|
|
if fpr_shade:
|
|
arr = g2.get_children()[1].get_paths()[0].vertices
|
|
x, y = arr[:, 0], arr[:, 1]
|
|
x, y = x[y > 0], y[y > 0]
|
|
x, y = x[x.argsort()], y[x.argsort()]
|
|
mask = x > np.percentile(dist1, 5)
|
|
g2.fill_between(x[mask], y1=y[mask], y2=0, alpha=0.3, facecolor='#444444', hatch='////')
|
|
# sns.despine(bottom=True, left=True)
|
|
if title is not None:
|
|
plt.title(title, fontsize=fontsize)
|
|
if xlabel is not None:
|
|
plt.xlabel(xlabel, fontsize=fontsize)
|
|
if xlabel is not None:
|
|
plt.xlabel(xlabel, fontsize=fontsize)
|
|
if ylabel is not None:
|
|
plt.ylabel(ylabel, fontsize=fontsize)
|
|
plt.tight_layout()
|
|
if path is not None:
|
|
plt.savefig(path)
|
|
else:
|
|
plt.show()
|
|
|
|
|
|
# def colormap(img, mode=cv2.COLORMAP_JET):
|
|
# img = toimage(img)
|
|
# colormask = cv2.applyColorMap(np.array(img), mode)[:,:,::-1]
|
|
# return colormask |