Simple Python script for linear fits and plots

I've recently spent an unhealthy amount of time plotting, mostly linear plots: some for looking at the behaviour of linear circuits in the Physics 111 lab (as you can see, things get messy) and some for photometric analysis on the RC3 pipeline. To simplify the long plotting scripts in my IPython Notebooks, I wrote this handy function that runs a linear regression (with numpy's polyfit) and plots the result. I have added many more optional arguments since I first wrote it for the solar scan analysis in the Astro120 optical lab.

For tasks like these, this is how much plotting code it takes to generate this plot:

In [3]:
import numpy as np
import matplotlib.pyplot as plt

# input_mag and output_mag are the magnitude arrays defined earlier (not shown here)
fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.plot(input_mag, output_mag, ',')
plt.xlabel("Input Magnitude")
plt.ylabel("Output Magnitude")
plt.title("Photometric Analysis")
# Degree-1 least-squares fit: z[0] is the slope, z[1] is the intercept
z = np.polyfit(input_mag, output_mag, 1)
p = np.poly1d(z)
# Draw the fitted line over the range of the data
a = np.linspace(min(input_mag), max(input_mag))
ax1.plot(a, p(a), color="red")
ax1.text(0.03, 0.85, "y= %.5f x + %.5f" % (z[0], z[1]), fontsize=13, transform=ax1.transAxes)
plt.tick_params(axis='both', which='major', labelsize=12)
plt.tick_params(axis='both', which='minor', labelsize=12)

Obviously this gets tedious when you are trying to plot lots of these.

Introducing the fit_and_plot function:

In [4]:
def fit_and_plot(x, y, xlabel="", ylabel="", title="", zeroed=False,
                 annotate_fit=True, right_words=False, error_bar=None,
                 sci_lim=False, annotate="", right_annotate=False, marker='o'):
    """Plot y against x, overlay a degree-1 least-squares fit, and return the fit as a poly1d."""
    fig = plt.figure()
    ax1 = fig.add_subplot(111)
    ax1.plot(x, y, marker)
    # Degree-1 fit: z[0] is the slope, z[1] is the intercept
    z = np.polyfit(x, y, 1)
    p = np.poly1d(z)
    # zeroed=True starts the fitted line at x = 0 instead of at min(x)
    if zeroed:
        a = np.linspace(0, max(x))
    else:
        a = np.linspace(min(x), max(x))
    # Evaluate the fit on the same grid it is drawn against
    ax1.plot(a, p(a), color="red")
    if annotate_fit:
        slope, intercept = z
        # right_words moves the fit equation towards the right of the axes
        text_x = 0.48 if right_words else 0.03
        ax1.text(text_x, 0.85, "y= %.5f x + %.5f" % (slope, intercept), fontsize=13, transform=ax1.transAxes)
    if title != "":
        plt.title(title, fontsize=13)
    plt.xlabel(xlabel, fontsize=12)
    plt.ylabel(ylabel, fontsize=12)
    # Note: the extra annotation is only drawn when right_annotate is True
    if annotate != "" and right_annotate:
        ax1.text(0.48, 0.85, annotate, fontsize=13, transform=ax1.transAxes)
    # error_bar defaults to None; comparing an array against "" is ambiguous in numpy
    if error_bar is not None:
        ax1.errorbar(x, y, yerr=error_bar, fmt='o')
    if sci_lim:
        # Scientific notation on the x-axis tick labels
        plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
    plt.tick_params(axis='both', which='major', labelsize=12)
    plt.tick_params(axis='both', which='minor', labelsize=12)
    return p

Now the same plot takes a single line:

In [5]:
fit_and_plot(input_mag,output_mag,xlabel="Input Magnitude",ylabel="Output Magnitude",title="Photometric Analysis",marker=',')
Out[5]:
poly1d([ 0.94088895,  1.65621614])
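
The two numbers in the output above are the coefficients of the fitted line, highest power first: the slope, then the intercept. To make it clearer what a degree-1 polyfit works out, here is a minimal pure-Python sketch of the ordinary least-squares line. The helper name least_squares_line and the sample data are made up for illustration; this is not how numpy implements polyfit internally, just the closed-form result it should agree with for a straight-line fit.

```python
def least_squares_line(x, y):
    """Return (slope, intercept) of the ordinary least-squares line through the points."""
    n = len(x)
    if n != len(y) or n < 2:
        raise ValueError("x and y must have the same length, with at least two points")
    mean_x = sum(x) / n
    mean_y = sum(y) / n
    # Sum of squared deviations of x, and the x-y cross term
    sxx = sum((xi - mean_x) ** 2 for xi in x)
    sxy = sum((xi - mean_x) * (yi - mean_y) for xi, yi in zip(x, y))
    if sxx == 0:
        raise ValueError("all x values are identical, so the slope is undefined")
    slope = sxy / sxx
    intercept = mean_y - slope * mean_x
    return slope, intercept


# Made-up, noise-free data lying exactly on y = 2x + 1
xs = [0.0, 1.0, 2.0, 3.0]
ys = [1.0, 3.0, 5.0, 7.0]
slope, intercept = least_squares_line(xs, ys)
print("y = %.5f x + %.5f" % (slope, intercept))
```

With real, noisy data the recovered slope and intercept will of course not be round numbers, but they should match the poly1d coefficients returned by fit_and_plot up to floating-point rounding.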