#!/usr/bin/env python3
# 2024/09/16: (c) Sean Parkin
# A Python program that draws Diederichs plots of an unmerged SHELXL *.hkl file.

import numpy as np
import matplotlib.pyplot as plt
import sys
import os

# Function to read data from the *.hkl file
def _parse_fixed_width(record):
    """Parse one reflection record by fixed column positions.

    The columns follow the SHELX HKLF 4 layout (3I4, 2F8.2, I4): h, k, l in
    four characters each, I and sigma(I) in eight characters each, and an
    optional four-character batch number. Anything after column 32 is read
    as extra whitespace-separated numbers.

    Returns the values as a list of floats, or None when the record does not
    fit this layout (a field is not numeric, or fewer than the five mandatory
    fields are present).
    """
    values = []
    start = 0
    try:
        for index, width in enumerate((4, 4, 4, 8, 8, 4)):
            field = record[start:start + width].strip()
            start += width
            if not field:
                break
            # h, k, l must be integers; the remaining fields may be decimal
            values.append(float(int(field)) if index < 3 else float(field))
        else:
            # All six standard fields were present: keep any trailing columns
            values.extend(float(token) for token in record[start:].split())
    except ValueError:
        return None
    return values if len(values) >= 5 else None


def _parse_hkl_record(record):
    """Convert one non-blank line of the reflection file to a list of floats.

    Whitespace-separated parsing is tried first. Its result is used directly
    only when it yields at least five numbers whose first three (h, k, l) are
    written as integers. Otherwise adjacent fixed-width fields may have run
    together (e.g. '   312345.67'), so the record is re-read by column
    position. If that also fails, a successful whitespace parse is returned
    unchanged so free-format files behave as before.

    Raises:
        ValueError: if the record cannot be parsed either way.
    """
    tokens = record.split()
    try:
        free_values = [float(token) for token in tokens]
    except ValueError:
        free_values = None

    if free_values is not None and len(free_values) >= 5:
        try:
            for token in tokens[:3]:
                int(token)
        except ValueError:
            pass  # non-integer index: columns may have merged, try fixed width
        else:
            return free_values

    fixed_values = _parse_fixed_width(record)
    if fixed_values is not None:
        return fixed_values
    if free_values is not None:
        return free_values
    raise ValueError(
        "could not parse reflection record: {!r}".format(record))


def read_data(file_path):
    """Read reflection records from a *.hkl file into a 2-D array.

    Reading stops at the first blank line or at the first record whose values
    are all zero (the end-of-data marker); nothing after that point is read.

    Args:
        file_path: path of the reflection file to read.

    Returns:
        numpy array with one row per record read. An empty array of shape
        (0,) is returned when no record precedes the terminator.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if a record cannot be converted to numbers.
    """
    data = []
    with open(file_path, 'r') as f:
        for line in f:
            # Keep leading blanks: they matter for fixed-column parsing
            record = line.rstrip('\r\n')
            if not record.strip():
                break
            values = _parse_hkl_record(record)
            # A record with every item zero terminates the reflection list
            if not any(values):
                break
            data.append(values)
    return np.array(data)

# Function to plot the Diederichs scatter plot
def plot_scatter(data, filename):
    """Draw the Diederichs scatter plot of I/sigma(I) against log10(I).

    Only reflections with I > 0, sigma(I) != 0 and log10(I) > -1 are plotted.
    The figure is written to '<filename>.png' at 400 dpi and then shown.

    Args:
        data: 2-D numpy array whose 4th column holds I and whose 5th column
            holds sigma(I) (column indices 3 and 4).
        filename: output name without extension; '.png' is appended.
    """
    intensity = data[:, 3]  # 4th column (I values)
    sigma = data[:, 4]      # 5th column (sigma(I) values)

    # log10 needs I > 0, and a zero sigma would make I/sigma infinite.
    # Select the usable rows first so the division never sees a zero
    # denominator (the ratio used to be computed for every row up front).
    usable = (intensity > 0) & (sigma != 0)
    intensity = intensity[usable]
    sigma = sigma[usable]

    x_values = np.log10(intensity)
    y_values = intensity / sigma

    # Further filtering: keep only points where log10(I) > -1
    in_range = x_values > -1
    x_filtered = x_values[in_range]
    y_filtered = y_values[in_range]

    # Plotting
    plt.scatter(x_filtered, y_filtered, s=1)

    # Axis labels (no title is drawn on the figure)
    plt.xlabel(r'log$_{10}$(I)')  # Subscripted 10 in the x-label
    plt.ylabel(r'I/$\sigma$(I)')  # I/sigma(I) with sigma as a Greek letter

    # Make axis numbers italic
    plt.xticks(fontstyle='italic')
    plt.yticks(fontstyle='italic')

    # Save first, then show
    plt.savefig(filename + '.png', dpi=400)
    plt.show()

# Main function
def main():
    """Command-line entry point: plot the reflection file named in argv[1].

    Expects exactly one command-line argument, the path of the data file.
    The output name passed to plot_scatter is the file's base name with a
    trailing '.hkl' removed.

    Exits with status 1 (after printing a message) when the argument count
    is wrong, the file does not exist, the file cannot be read or parsed, or
    it yields no table with at least five columns.
    """
    if len(sys.argv) != 2:
        print("Usage: python ddrch.py <data_file>")
        sys.exit(1)

    file_path = sys.argv[1]
    if not os.path.isfile(file_path):
        print("Error: data file '{}' not found".format(file_path),
              file=sys.stderr)
        sys.exit(1)

    # Extract filename without the .hkl extension
    filename = os.path.basename(file_path)
    if filename.endswith('.hkl'):
        filename = filename[:-4]  # Remove the '.hkl' extension

    try:
        data = read_data(file_path)
    except (OSError, ValueError) as err:
        print("Error: could not read '{}': {}".format(file_path, err),
              file=sys.stderr)
        sys.exit(1)

    # plot_scatter indexes columns 3 and 4, so it needs a 2-D table with at
    # least five columns; an empty file gives a 1-D empty array instead.
    if data.ndim != 2 or data.shape[1] < 5:
        print("Error: '{}' contains no usable reflection records".format(
            file_path), file=sys.stderr)
        sys.exit(1)

    plot_scatter(data, filename)

# Entry point: call main() only when this file is run as a script, so that
# importing it as a module does not start the command-line program.
if __name__ == '__main__':
    main()
