"""
Read a single 256×256 Sentinel-1/Space-LiDAR patch from a ZIP archive and
compute the 4×4 polarimetric–interferometric sample covariance matrix.


Assumptions
-----------
- The ZIP has the structure: <Region>/<SentinelGedi|SentinelIcesat2>/<lat_lon>/<*.mat>
- Each .mat contains:
    I (256×256×4 complex single)  -> S1 channels [VVm, VHm, VVs, VHs] (example order)
    IncidentAngle, Latitude, Longitude, Range (each 256×256 single)
    Lidar (struct)                -> fields like 'Lidar RH98', 'Lidar ground elevation'
    dataInfo (struct)             -> metadata (baselines, dates, etc.)

Outputs
-------
- `RH98` and `DTM`: LiDAR canopy height and ground elevation
- `covariance`: 4×4×256×256 complex array, the windowed sample covariance
"""



import zipfile
import os
import scipy.io
import io
import scipy.constants as const
import sys
# Machine epsilon of a Python float; added to the kz denominator further
# down so that the division never hits an exact zero.
eps = sys.float_info.epsilon
import numpy as np
from scipy.ndimage import uniform_filter
import matplotlib.pyplot as plt
import h5py
from pathlib import Path


# -------------------------------------------------------------------------
# Address of zip file (continent/region archive to read in-memory).
# Example: Europe → Finland_Sweden_Norway.zip
# -------------------------------------------------------------------------
# Archive to process. The dataset folder inside the archive is expected to
# carry the archive's base name (file name without the ".zip" extension).
zip_path = './Europe/Finland_Sweden_Norway.zip'
data_folder = Path(zip_path).stem

# Sensor folder name (second path component inside the ZIP) -> LiDAR mission.
mapping_lidar = dict(SentinelIcesat2="Icesat2", SentinelGedi="Gedi")

# Sentinel-1 wavelength in meters (~C-band, 5.6 cm).
lambdaSent = 0.056

# -------------------------------------------------------------------------
# Open the ZIP and enumerate all .mat patch files that belong to the dataset.
# 'z.namelist()' returns all members; we keep only those under 'data_folder'.
# -------------------------------------------------------------------------
kk = 0                                   # choose the index you want

# The archive is only needed while listing members and pulling one patch's
# bytes. The context managers close the archive (and the member stream) as
# soon as that is done, even if an exception is raised.
with zipfile.ZipFile(zip_path, 'r') as z:
    # Compare the first path component rather than a raw string prefix, so a
    # sibling folder such as '<data_folder>_old/' is not picked up by mistake.
    mat_files_list = [f for f in z.namelist()
                      if f.endswith('.mat') and Path(f).parts[:1] == (data_folder,)]

    # Fail with an explanatory message instead of a bare "list index out of
    # range" when the archive holds no matching patch (or kk is too large).
    if not -len(mat_files_list) <= kk < len(mat_files_list):
        raise IndexError(
            f"kk={kk} is out of range: {len(mat_files_list)} .mat file(s) found "
            f"under '{data_folder}/' in {zip_path}")

    # ---------------------------------------------------------------------
    # Read ONE .mat patch (for a specific kk)
    # ---------------------------------------------------------------------
    file_path = mat_files_list[kk]       # path inside the zip
    with z.open(file_path) as file:
        mat_bytes = file.read()

# Second path component is the sensor folder; map it to the LiDAR mission
# name (0 when the folder is not listed in mapping_lidar).
second_part = Path(file_path).parts[1]
lidar_name = mapping_lidar.get(second_part, 0)

# Load MATLAB data directly from the in-memory bytes buffer
mat_data = scipy.io.loadmat(io.BytesIO(mat_bytes))

# ---- Metadata (dataInfo struct) ---------------------------------------
# loadmat returns the struct as a structured array; element [0, 0] of each
# field holds that field's value.
dataInfo = mat_data['dataInfo']
dataInfo = {field: dataInfo[field][0, 0] for field in dataInfo.dtype.names}
# np.ravel(...)[0] yields a true scalar whatever the stored shape; calling
# float() directly on a 1-element array is deprecated in recent NumPy.
bn = float(np.ravel(dataInfo['Perp Baseline'])[0])   # perpendicular baseline [m]
mstDate = (dataInfo['SAR master date'][0])           # master date
slvDate = (dataInfo['SAR slave date'][0])            # slave date

# ---- Core arrays -------------------------------------------------------
I = mat_data['I'].astype(np.complex64)        # (256,256,4)
Latitude  = mat_data['Latitude'].astype(np.float32)
Longitude = mat_data['Longitude'].astype(np.float32)
theta     = mat_data['IncidentAngle'].astype(np.float32)  # degrees
Range     = mat_data['Range'].astype(np.float32)          # meters

# ---- LiDAR (struct) ----------------------------------------------------
Lidar = mat_data['Lidar']
Lidar = {field: Lidar[field][0, 0] for field in Lidar.dtype.names}
RH98 = np.asarray(Lidar['Lidar RH98']).astype(np.float32)             # canopy height
DTM  = np.asarray(Lidar['Lidar ground elevation']).astype(np.float32) # ground elevation

# ---- Vertical wavenumber (kz) -----------------------------------------
# eps keeps the denominator away from an exact zero.
# NOTE(review): a 2*pi factor is used here; repeat-pass (monostatic) InSAR
# formulations commonly use 4*pi -- confirm which convention the dataset
# expects before relying on absolute kz values.
theta_rad = np.deg2rad(theta)
kz = (2.0*np.pi*bn) / (eps + lambdaSent * Range * np.sin(theta_rad))  # (256,256)

# ---- 4×4 polarimetric–interferometric covariance ----------------------
# Mean of pairwise products over a 5×5 window. Compute all 16 products
# at once via broadcasting, then smooth only across spatial dims.
# prod shape: (H,W,4,4) -> windowed mean with size=(5,5,1,1).
# Real and imaginary parts are filtered separately and recombined.
prod = I[..., :, None] * np.conj(I[..., None, :])                     # (256,256,4,4)
cov_real = uniform_filter(np.real(prod), size=(5,5,1,1), mode="reflect")
cov_imag = uniform_filter(np.imag(prod), size=(5,5,1,1), mode="reflect")
covariance = (cov_real + 1j*cov_imag).transpose(2,3,0,1).astype(np.complex64)  # (4,4,256,256)

# Now you have: RH98, DTM, kz, and covariance for this single patch.






