def _gauge_fmt_val_err(min_v, max_v):
    """Format the interval [min_v, max_v] as "(center ± half_width)e<exponent>".

    The shared exponent is derived from the larger of |center| and |half_width|,
    so a near-zero center does not inflate the half-width into a huge mantissa
    (e.g. center 1e-10 with half-width 0.1 no longer prints "± 1000000000.000").
    """
    c = (min_v + max_v) / 2
    h = (max_v - min_v) / 2
    ref = max(abs(c), abs(h))
    if ref == 0:
        return "(0.000 ± 0.000)e+0"
    exponent = int(np.floor(np.log10(ref)))
    scale = 10.0 ** (-exponent)
    return f"({c * scale:.3f} ± {h * scale:.3f})e{exponent:+d}"


def _gauge_sample_t_values(t_min, t_max, resolution, singular_t=0.5, gap=1e-6):
    """Sample [t_min, t_max] at about ``resolution`` points per unit t, avoiding ``singular_t``.

    No sample lies closer than ``gap`` to ``singular_t``. Every sampled segment
    receives at least two points, so the result is never empty.
    """
    # A boundary sitting on the singularity is pulled *into* the sampled range
    # rather than pushed across it; pushing it across would create a degenerate
    # segment of duplicated near-singular points on the far side.
    if abs(t_min - singular_t) < gap:
        t_min = singular_t + gap
    if abs(t_max - singular_t) < gap:
        t_max = singular_t - gap

    def _segment(lo, hi):
        # Point count proportional to segment length keeps the sampling density
        # equal on both sides of the singularity.
        return np.linspace(lo, hi, max(int(resolution * (hi - lo)), 2))

    if t_min < singular_t < t_max:
        return np.concatenate([
            _segment(t_min, singular_t - gap),
            _segment(singular_t + gap, t_max),
        ])
    return _segment(t_min, t_max)


def _gauge_shade_valid_regions(ax, regions, t_min, t_max):
    """Shade each (t_lo, t_hi) region on ``ax``, clipped to [t_min, t_max].

    Infinite bounds are clipped to the window. Only the first span actually
    drawn carries the legend label, so the legend entry survives even when the
    first region lies outside the window.
    """
    label = 'Valid Gauge Region (t)'
    for t_lo, t_hi in regions:
        lo = t_min if np.isinf(t_lo) else max(t_lo, t_min)
        hi = t_max if np.isinf(t_hi) else min(t_hi, t_max)
        if lo < hi:
            ax.axvspan(float(lo), float(hi), color='green', alpha=0.2, label=label)
            label = ""


def _gauge_draw_guides(ax, margin_tol):
    """Draw the t=0.5 singularity, the [0, 1] probability bounds and, if positive, the margin bounds."""
    ax.axvline(x=0.5, color='red', linestyle=':', linewidth=1.0, label='Singularity (t=0.5)')
    ax.axhline(y=0, color='k', linestyle='--', linewidth=1.0, label='Prob Boundary (0,1)')
    ax.axhline(y=1, color='k', linestyle='--', linewidth=1.0)
    if margin_tol > 0:
        ax.axhline(y=-margin_tol, color='gray', linestyle=':', linewidth=1.0, label=f'Margin (±{margin_tol})')
        ax.axhline(y=1 + margin_tol, color='gray', linestyle=':', linewidth=1.0)


def _gauge_range_row(key, name, values, ref_values):
    """Summarise ``values`` as one table row and compare each reference value to its range.

    Args:
        key: Column name for the row label ('Entry' or 'Quantity').
        name: Row label.
        values: Sampled values inside the analysed t-region (non-empty).
        ref_values: One scalar per reference instrument; becomes columns Ref1, Ref2, ...

    Returns:
        (row_dict, min_value, max_value, center_value)
    """
    min_val = np.min(values)
    max_val = np.max(values)
    center = (min_val + max_val) / 2
    row = {
        key: name,
        'Min': f"{min_val:.8f}",
        'Max': f"{max_val:.8f}",
        'Width': f"{max_val - min_val:.5e}",
        'Value': _gauge_fmt_val_err(min_val, max_val),
    }
    for ref_idx, ref_val in enumerate(ref_values, start=1):
        abs_disc = ref_val - center  # signed discrepancy from the interval center
        if min_val <= ref_val <= max_val:
            row[f'Ref{ref_idx}'] = f"IN ({abs_disc:+.5e})"
        else:
            direction = "above" if ref_val > max_val else "below"
            row[f'Ref{ref_idx}'] = f"OUT {direction} ({abs_disc:+.5e})"
    return row, min_val, max_val, center


def plot_gauge_transformation_effects(
    MCM_to_transform,
    t_width_factor: float = 1.20,
    verbose: bool = False,
    MCM_reference: List[Instrument2x2] = None,
    resolution: float = 1e5,
    p00_min: float = 0.5,
    margin_tol: float = 0.0
):
    """
    Analyze and plot the effect of a gauge transformation on an instrument.

    The instrument is gauge-transformed over a sampled range of the gauge
    parameter t, and the evolution of its 8 matrix entries (and four derived
    readout / back-action quantities) is plotted. Valid t-regions come from
    ``allowed_t_regions_for_list``. Summary tables for the first valid region
    are printed and rendered with ``display`` (presumably IPython's; it must be
    in scope — confirm for non-notebook use).

    Note: t=0.5 is excluded from the sampling because the gauge transformation
    matrix is non-invertible there.

    Args:
        MCM_to_transform: An Instrument2x2 object to be transformed.
        t_width_factor: Factor to scale the plotting range around valid regions. Default 1.20.
            1.00 means plot exactly the valid regions, >1.00 adds padding.
        verbose: If True, print detailed information about valid intervals. Default False.
        MCM_reference: Reference Instrument2x2 objects to compare against (table
            columns, RMSE subplot). Default None, treated as an empty list.
        resolution: Approximate number of points sampled per unit of t-range
            (each side of t=0.5 always gets at least two points).
        p00_min: Minimum value for M0[0,0] when identifying focus regions. Default 0.5.
        margin_tol: Margin tolerance for allowed regions. Default 0.0.

    Returns:
        Tuple containing:
        - Instrument2x2 built from the center points of each entry's range over
          the first valid t-region, or None if that region contains no samples.
        - List of valid t-regions for that center instrument (relative to itself);
          empty if the center instrument is None.

    Raises:
        ValueError: If no valid gauge parameter region exists for the instrument.
    """
    if MCM_reference is None:
        MCM_reference = []
    if verbose:
        print("Original Instrument to be transformed:")
        MCM_to_transform.reveal()

    # ----- Analytically determined valid t-regions -----
    valid_t_regions = allowed_t_regions_for_list(
        [MCM_to_transform.M0, MCM_to_transform.M1],
        tol=1e-24,
        margin_tol=margin_tol
    )
    if not valid_t_regions:
        raise ValueError(f"No valid gauge parameter regions found for this instrument (margin_tol={margin_tol}).")
    if verbose:
        print(f"\nAnalytically determined valid t-regions (total: {len(valid_t_regions)}) with margin {margin_tol}:")
        for i, (t_lo, t_hi) in enumerate(valid_t_regions, 1):
            print(f" Region {i}: t ∈ [{t_lo:.6f}, {t_hi:.6f}] (width: {t_hi - t_lo:.6f})")

    # ----- Plot window: span every finite region bound, scaled about its center -----
    finite_los = [t_lo for t_lo, _ in valid_t_regions if not np.isinf(t_lo)]
    finite_his = [t_hi for _, t_hi in valid_t_regions if not np.isinf(t_hi)]
    if finite_los and finite_his:
        span_lo, span_hi = min(finite_los), max(finite_his)
        t_center = (span_lo + span_hi) / 2
        t_half_span = (span_hi - span_lo) / 2
        t_plot_min = t_center - t_half_span * t_width_factor
        t_plot_max = t_center + t_half_span * t_width_factor
    else:
        # Fallback if regions are unbounded
        t_plot_min, t_plot_max = -0.5, 1.5

    # ----- Sample t (skipping the singularity) and transform the instrument -----
    t_values = _gauge_sample_t_values(t_plot_min, t_plot_max, resolution)
    # One row per t: the 8 entries of (M0', M1'), each flattened row-major.
    transformed_rows = []
    for t_val in t_values:
        M0_prime, M1_prime = gauge_transform_instrument_numerically(
            MCM_to_transform.M0, MCM_to_transform.M1, t_val
        )
        transformed_rows.append(np.concatenate((M0_prime.flatten(), M1_prime.flatten())))
    transformed_entries = np.array(transformed_rows)

    # The first region whose rebased instrument satisfies the p00 bound drives the focus plot.
    rebased_results = rebase_and_anchor_instrument(
        MCM_to_transform.M0,
        MCM_to_transform.M1,
        valid_t_regions,
        p00_min=p00_min
    )
    focus_region_info = next((res for res in rebased_results if res['ok_p00']), None)

    # ----- Figure layout: only create the subplots that will actually be drawn -----
    has_refs = len(MCM_reference) > 0
    has_focus = focus_region_info is not None
    if has_refs or has_focus:
        height_ratios = [1.2, 1.2] + ([1.2] if has_refs else []) + ([2.4] if has_focus else [])
        # 28 inches for the full 4-row layout (ratio sum 6.0); scale for fewer rows.
        fig = plt.figure(figsize=(14, 28 * sum(height_ratios) / 6.0))
        gs = fig.add_gridspec(len(height_ratios), 1, height_ratios=height_ratios, hspace=0.15)
        ax1 = fig.add_subplot(gs[0])
        ax2 = fig.add_subplot(gs[1], sharex=ax1)
        ax3 = fig.add_subplot(gs[2], sharex=ax1) if has_refs else None
        ax4 = fig.add_subplot(gs[-1]) if has_focus else None
    else:
        _, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 14), sharex=True)
        ax3 = ax4 = None

    # "+ 0" turns -0.0 into 0.0 so the default margin is shown as [0.0, 1.0].
    interval_text = f"[{-margin_tol + 0}, {1 + margin_tol}]"

    # ===== First subplot: Individual matrix entries =====
    labels = [
        r"$M^0$[0,0] = $p_0^{(0,0)}$",
        r"$M^0$[0,1] = $p_1^{(0,0)}$",
        r"$M^0$[1,0] = $p_0^{(0,1)}$",
        r"$M^0$[1,1] = $p_1^{(0,1)}$",
        r"$M^1$[0,0] = $p_0^{(1,0)}$",
        r"$M^1$[0,1] = $p_1^{(1,0)}$",
        r"$M^1$[1,0] = $p_0^{(1,1)}$",
        r"$M^1$[1,1] = $p_1^{(1,1)}$"
    ]
    for i, entry_plot_label in enumerate(labels):
        ax1.plot(t_values, transformed_entries[:, i], linewidth=1.0, label=entry_plot_label)
    _gauge_draw_guides(ax1, margin_tol)
    _gauge_shade_valid_regions(ax1, valid_t_regions, t_plot_min, t_plot_max)
    ax1.set_ylabel("Value of Instrument Matrix Entry")
    ax1.set_title(f"Evolution of Instrument Entries (Valid Interval: {interval_text})")
    ax1.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    ax1.grid(True, linestyle=':', alpha=0.6)
    ax1.set_ylim(max(np.min(transformed_entries), -0.1 - margin_tol),
                 min(np.max(transformed_entries), 1.1 + margin_tol))

    # ===== Valid ranges of the matrix entries over the first valid region =====
    entry_labels = [
        "M^0[0,0] = p_0^(0,0)",
        "M^0[0,1] = p_1^(0,0)",
        "M^0[1,0] = p_0^(0,1)",
        "M^0[1,1] = p_1^(0,1)",
        "M^1[0,0] = p_0^(1,0)",
        "M^1[0,1] = p_1^(1,0)",
        "M^1[1,0] = p_0^(1,1)",
        "M^1[1,1] = p_1^(1,1)"
    ]
    # Reference instruments flattened once, in the same entry order as transformed_entries.
    ref_vectors = [np.concatenate((ref.M0.flatten(), ref.M1.flatten())) for ref in MCM_reference]
    first_lo, first_hi = valid_t_regions[0]
    in_first_region = (t_values >= first_lo) & (t_values <= first_hi)
    first_region_sampled = bool(np.any(in_first_region))

    entry_ranges_data = []
    center_values = []  # Center of each entry's range, used to build the returned instrument
    if first_region_sampled:
        for entry_idx, entry_label in enumerate(entry_labels):
            row, _, _, center = _gauge_range_row(
                'Entry',
                entry_label,
                transformed_entries[in_first_region, entry_idx],
                [ref[entry_idx] for ref in ref_vectors],
            )
            entry_ranges_data.append(row)
            center_values.append(center)

    center_instrument = None
    if len(center_values) == 8:
        M0_center = np.array(center_values[:4], dtype=float).reshape(2, 2)
        M1_center = np.array(center_values[4:], dtype=float).reshape(2, 2)
        center_instrument = Instrument2x2(M0=M0_center, M1=M1_center)

    # ===== Second subplot: Derived quantities =====
    # Each derived quantity is the sum of two flattened entries; the same index
    # pair is used for the sampled curve and for the reference comparison.
    derived_specs = (
        ("prep 0 meas 1", (4, 6), r"prep 0 meas 1: $p_0^{(1,0)} + p_0^{(1,1)}$"),
        ("prep 1 meas 0", (1, 3), r"prep 1 meas 0: $p_1^{(0,0)} + p_1^{(0,1)}$"),
        ("prep 0 excite to 1", (2, 6), r"prep 0 excite to 1: $p_0^{(0,1)} + p_0^{(1,1)}$"),
        ("prep 1 decay to 0", (1, 5), r"prep 1 decay to 0: $p_1^{(0,0)} + p_1^{(1,0)}$"),
    )
    derived_quantities = {}
    quantity_ranges_data = []
    for quantity_name, (idx_a, idx_b), base_label in derived_specs:
        quantity_values = transformed_entries[:, idx_a] + transformed_entries[:, idx_b]
        derived_quantities[quantity_name] = quantity_values
        curve_label = base_label
        if first_region_sampled:
            # All values inside the matrix-entry valid region are reported,
            # without a separate validity check on the derived quantity itself.
            row, q_min, q_max, _ = _gauge_range_row(
                'Quantity',
                quantity_name,
                quantity_values[in_first_region],
                [ref[idx_a] + ref[idx_b] for ref in ref_vectors],
            )
            quantity_ranges_data.append(row)
            curve_label = f"{base_label}\n∈ [{q_min:.6f}, {q_max:.6f}]"
        ax2.plot(t_values, quantity_values, label=curve_label, linewidth=1)
    _gauge_draw_guides(ax2, margin_tol)
    _gauge_shade_valid_regions(ax2, valid_t_regions, t_plot_min, t_plot_max)
    ax2.set_xlabel(r"Gauge Parameter $(t)$")
    ax2.set_ylabel("Derived Quantity Value")
    ax2.set_title(f"Derived Quantities (Valid Interval: {interval_text})")
    ax2.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=9,
               labelspacing=1.2, handlelength=2)
    ax2.grid(True, linestyle=':', alpha=0.6)
    ax2.set_ylim(-0.1 - margin_tol, 1.1 + margin_tol)

    # ===== Display DataFrames for valid intervals =====
    print("="*80)
    print("Valid Intervals for Derived Quantities (Readout & Back-action Errors)")
    print("="*80)
    if quantity_ranges_data:
        display(pd.DataFrame(quantity_ranges_data))
    else:
        print("No valid intervals found for derived quantities.")
    print("\n" + "="*80)
    print("Valid Intervals for Matrix Entries")
    print("="*80)
    if entry_ranges_data:
        display(pd.DataFrame(entry_ranges_data))
    else:
        print("No valid intervals found for matrix entries.")
    if has_refs:
        print("\nNote: Reference comparison format:")
        print(" 'IN (±X.XXXe±YY)' - value is within range, absolute discrepancy from center")
        print(" 'OUT above/below (±X.XXXe±YY)' - value is outside range, absolute discrepancy from center")
    print("="*80 + "\n")

    # ===== Third subplot: RMSE to reference instruments =====
    if ax3 is not None:
        for ref_idx, ref_vec in enumerate(ref_vectors, start=1):
            rmse_values = np.sqrt(np.mean((transformed_entries - ref_vec) ** 2, axis=1))
            best = np.argmin(rmse_values)
            min_rmse = rmse_values[best]
            t_min_rmse = t_values[best]
            ax3.plot(t_values, rmse_values,
                     label=f"Ref {ref_idx}: min RMSE={min_rmse:.6e} at t={t_min_rmse:.4f}",
                     linewidth=1.5)
            ax3.plot(t_min_rmse, min_rmse, 'o', markersize=8)
        ax3.axvline(x=0.5, color='red', linestyle=':', linewidth=1.0, label='Singularity (t=0.5)')
        _gauge_shade_valid_regions(ax3, valid_t_regions, t_plot_min, t_plot_max)
        ax3.set_ylabel("RMSE to Reference")
        ax3.set_title("RMSE Between Gauge-Transformed and Reference Instruments")
        ax3.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=9)
        ax3.grid(True, linestyle=':', alpha=0.6)
        ax3.set_yscale('log')
        if ax4 is None:
            # RMSE is the bottom panel, so it carries the shared t-axis label.
            ax3.set_xlabel(r"Gauge Parameter $(t)$")

    # ===== Fourth subplot: Focused view =====
    if ax4 is not None:
        # Adjusted region from rebase_and_anchor_instrument is in the local gauge s
        # relative to anchor_t; compose back to absolute t: t = t1 + s - 2*t1*s.
        anchor_t = focus_region_info['anchor_t']
        s_min, s_max = focus_region_info['adjusted_region']
        t_focus_exact_min, t_focus_exact_max = sorted(
            anchor_t + s - 2 * anchor_t * s for s in (s_min, s_max)
        )
        # Apply t_width_factor to the focus region
        t_focus_center = (t_focus_exact_min + t_focus_exact_max) / 2
        t_focus_half_span = (t_focus_exact_max - t_focus_exact_min) / 2
        t_focus_min = t_focus_center - t_focus_half_span * t_width_factor
        t_focus_max = t_focus_center + t_focus_half_span * t_width_factor
        if verbose:
            print(f"\nFocus region: anchor_t={anchor_t:.6f}, "
                  f"local s∈[{s_min:.6f}, {s_max:.6f}]")
            print(f" Exact valid t∈[{t_focus_exact_min:.6f}, {t_focus_exact_max:.6f}]")
            print(f" Plotted t∈[{t_focus_min:.6f}, {t_focus_max:.6f}] (with factor {t_width_factor:.2f})")

        focus_mask = (t_values >= t_focus_min) & (t_values <= t_focus_max)
        t_focus = t_values[focus_mask]
        entries_focus = transformed_entries[focus_mask]
        for i, entry_plot_label in enumerate(labels):
            ax4.plot(t_focus, entries_focus[:, i], linewidth=0.8, alpha=0.5, label=entry_plot_label)
        for quantity_name, quantity_values in derived_quantities.items():
            ax4.plot(t_focus, quantity_values[focus_mask], linewidth=1.5, label=quantity_name)
        # Normalising by the window maximum is undefined for an empty window.
        if t_focus.size:
            for ref_idx, ref_vec in enumerate(ref_vectors, start=1):
                rmse_focus = np.sqrt(np.mean((entries_focus - ref_vec) ** 2, axis=1))
                rmse_peak = rmse_focus.max()
                ax4.plot(t_focus, rmse_focus / (rmse_peak + 1e-18), linewidth=2, linestyle='--',
                         label=f"Ref {ref_idx} RMSE (normalized)\n original max={rmse_peak:.2e}")
        if t_focus_min <= 0.5 <= t_focus_max:
            ax4.axvline(x=0.5, color='red', linestyle=':', linewidth=1.0, label='Singularity (t=0.5)')
        ax4.axhline(y=0, color='k', linestyle='--', linewidth=1.0, alpha=0.5)
        ax4.axhline(y=1, color='k', linestyle='--', linewidth=1.0, alpha=0.5)
        # Only the parts of valid regions that overlap the focus window are shaded.
        _gauge_shade_valid_regions(ax4, valid_t_regions, t_focus_min, t_focus_max)
        ax4.set_xlabel(r"Gauge Parameter $(t)$ [Focused View]")
        ax4.set_ylabel("Quantity Values")
        ax4.set_title(f"Focused View: t ∈ [{t_focus_min:.4f}, {t_focus_max:.4f}]\n"
                      f"Valid region: [{t_focus_exact_min:.4f}, {t_focus_exact_max:.4f}] "
                      f"(width factor: {t_width_factor:.2f})")
        ax4.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=8, ncol=2)
        ax4.grid(True, linestyle=':', alpha=0.6)
        ax4.set_ylim(-0.1, 1.1)
        ax4.set_xlim(t_focus_min, t_focus_max)

    plt.tight_layout()
    plt.show()

    center_t_regions = []
    if center_instrument is not None:
        center_t_regions = allowed_t_regions_for_list(
            [center_instrument.M0, center_instrument.M1],
            tol=1e-24,
            margin_tol=margin_tol
        )
    return center_instrument, center_t_regions