Skip to content

Learning MCMs

Utility functions and Environment setup

  • Joint instrument per shot:

    \[ p_s^{(o,s')} :=\text{Prob}(\text{outcome}=o,\ \text{post-meas state}=s'\mid \text{pre-state}=s), \quad s,o,s'\in\{0,1\}, \]

    with column normalization \(\sum_{o,s'}p_s^{(o,s')}=1\) for each \(s\) (8 entries minus 2 normalization constraints ⇒ 6 degrees of freedom in total).

  • Outcome-indexed operators (“observable operators”)

    \[ M^{(o)} \in \mathbb{R}_{\ge 0}^{2\times 2},\quad [M^{(o)}]_{s',s}=p_s^{(o,s')}. \]

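    Note: column \(s\) of \(M^{(o)}\) sums to \(A^{(o)}_s=\sum_{s'}p_s^{(o,s')}\), the probability of outcome \(o\) given pre-state \(s\), which is not 1 in general; only the columns of \(M^{(0)}+M^{(1)}\) sum to 1. The two \(M\) matrices explicitly are: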

    \[ M^{(o=0)} = \begin{pmatrix} p_0^{(0, 0)} & p_1^{(0, 0)} \\ p_0^{(0, 1)} & p_1^{(0, 1)} \end{pmatrix} \text{ , } M^{(o=1)} = \begin{pmatrix} p_0^{(1, 0)} & p_1^{(1, 0)} \\ p_0^{(1, 1)} & p_1^{(1, 1)} \end{pmatrix} \]
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
import matplotlib.pyplot as plt

%matplotlib inline
%config InlineBackend.figure_format = 'svg'

import numpy as np
import numpy.linalg as LA
from numpy.linalg import eig, inv, solve
import pandas as pd
from scipy.optimize import minimize, Bounds, LinearConstraint, basinhopping, differential_evolution, NonlinearConstraint
import os
from typing import Callable, Dict, Any, Tuple, List
import pickle
from itertools import product
from dataclasses import dataclass

from collections import Counter
from tqdm.notebook import tqdm
from IPython.display import display

# Module-level fallback generator: random_instrument and
# monte_carlo_empirical_probabilities draw from it when called with seed=None.
rng = np.random.default_rng(7)


@dataclass
class Instrument2x2:
    """
    Holds the two outcome-indexed matrices M0, M1 of a single-qubit instrument (mid-circuit measurement gadget).

    Each matrix is 2x2 with nonnegative entries, and the columns of M0 + M1 sum to 1.

    M0 corresponds to measurement outcome 0, M1 to outcome 1.

    The entry M^o[s', s] = p_s^{(o,s')} is the probability of obtaining outcome o and
    post-measurement state s' given starting state s.
    """

    M0: np.ndarray  # shape (2,2), nonnegative
    M1: np.ndarray  # shape (2,2), nonnegative

    def check_valid(self, atol: float = 1e-8) -> bool:
        """
        Validate shape, nonnegativity and the column-sum constraint.

        Args:
            atol: Absolute tolerance used when comparing the column sums of M0 + M1 to 1.
        Returns:
            True when every check passes.
        Raises:
            ValueError: If either matrix is not 2x2, has a negative entry, or the
                column sums of M0 + M1 differ from 1 by more than `atol`.
        """
        # asarray lets nested lists be validated too (they have no `.shape`).
        M0 = np.asarray(self.M0)
        M1 = np.asarray(self.M1)
        # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
        if M0.shape != (2, 2) or M1.shape != (2, 2):
            raise ValueError(
                f"M0 and M1 must both have shape (2, 2), got {M0.shape} and {M1.shape}."
            )
        if (M0 < 0).any() or (M1 < 0).any():
            raise ValueError("Negative entries in M0/M1.")
        colsums = M0.sum(axis=0) + M1.sum(axis=0)
        if not np.allclose(colsums, np.ones(2), atol=atol):
            raise ValueError(f"Column sums of (M0+M1) must be 1, got {colsums}.")
        return True

    def reveal(self) -> int:
        """Print the instrument matrices (8 decimals) in a readable format; always returns 0."""
        print("\nMCM Instrument:")
        print("\nM0 matrix (outcome 0):")
        print(
            f"  [[p_0^(0,0), p_1^(0,0)]] = [[{self.M0[0, 0]:.8f}, {self.M0[0, 1]:.8f}]]"
        )
        print(
            f"  [[p_0^(0,1), p_1^(0,1)]]   [[{self.M0[1, 0]:.8f}, {self.M0[1, 1]:.8f}]]"
        )
        print("\nM1 matrix (outcome 1):")
        print(
            f"  [[p_0^(1,0), p_1^(1,0)]] = [[{self.M1[0, 0]:.8f}, {self.M1[0, 1]:.8f}]]"
        )
        print(
            f"  [[p_0^(1,1), p_1^(1,1)]]   [[{self.M1[1, 0]:.8f}, {self.M1[1, 1]:.8f}]]"
        )
        return 0

    def show_readout_errors(self) -> List[float]:
        """
        Print and return the readout error rates for starting states |0> and |1>.

        Returns:
            [P(outcome 1 | start |0>), P(outcome 0 | start |1>)], each summed over
            the post-measurement state.
        """
        error_rate_0 = self.M1[0, 0] + self.M1[1, 0]  # Probability of getting outcome 1 when starting from |0>
        error_rate_1 = self.M0[0, 1] + self.M0[1, 1]  # Probability of getting outcome 0 when starting from |1>

        print(f"\nPrep 0 meas 1: {error_rate_0:.3e}")
        print(f"Prep 1 meas 0: {error_rate_1:.3e}")

        return [error_rate_0, error_rate_1]

    def show_backaction_errors(self) -> List[float]:
        """
        Print and return the back-action error rates for starting states |0> and |1>.

        Returns:
            [P(post-state |1> | start |0>), P(post-state |0> | start |1>)], each
            summed over the measurement outcome.
        """
        backaction_error_0 = self.M0[1, 0] + self.M1[1, 0]  # Probability of ending in state |1> when starting from |0>
        backaction_error_1 = self.M0[0, 1] + self.M1[0, 1]  # Probability of ending in state |0> when starting from |1>

        print(f"\nPrep 0 ends in 1: {backaction_error_0:.3e}")
        print(f"Prep 1 ends in 0: {backaction_error_1:.3e}")

        return [backaction_error_0, backaction_error_1]


def make_instrument_from_columns(
    col0: np.ndarray,
    col1: np.ndarray
) -> Instrument2x2:
    """
    Assemble an Instrument2x2 from its two probability columns.

    Each argument is a length-4 vector for one pre-state s, ordered as
        [p_s^{(0,0)}, p_s^{(0,1)}, p_s^{(1,0)}, p_s^{(1,1)}],
    i.e. entry k holds outcome o = k // 2 and post-state s' = k % 2.

    Args:
        col0: Column for starting state |0>; must sum to 1 and be nonnegative.
        col1: Column for starting state |1>; must sum to 1 and be nonnegative.
    Returns:
        A validated Instrument2x2 instance.
    Raises:
        ValueError: If a column does not sum to 1, or has a negative entry.
    """
    first = np.asarray(col0, dtype=float).reshape(4)
    second = np.asarray(col1, dtype=float).reshape(4)

    both_normalized = np.isclose(first.sum(), 1.0) and np.isclose(second.sum(), 1.0)
    if not both_normalized:
        raise ValueError("Each column-vector must sum to 1.")
    if np.any(first < 0) or np.any(second < 0):
        raise ValueError("Probabilities must be nonnegative.")

    # Row k of `table` is the (o, s') pair k; its column index is the pre-state s.
    table = np.stack([first, second], axis=1)
    outcome0 = table[:2].copy()  # rows (0,0), (0,1) -> M0[s', s]
    outcome1 = table[2:].copy()  # rows (1,0), (1,1) -> M1[s', s]

    result = Instrument2x2(M0=outcome0, M1=outcome1)
    result.check_valid()
    return result


def random_instrument(
    corr_strength: float = 0.2, 
    seed: int = 1, 
    fidelity: float = 0.85
) -> Instrument2x2:
    """
    Draw a random but physical instrument in which the 'correct' outcome dominates.

    Column s is sampled from a Dirichlet distribution with concentration
    1e-3 + 20 * base_s ** (1 - corr_strength). The base weights put `fidelity` on
    (o, s') = (0, 0) for s = 0 and on (1, 1) for s = 1, and share the rest as
    (1 - fidelity) / 3 over the other three entries.

    Args:
        corr_strength: Exponent shift applied to the base weights before scaling.
        seed: Seed for a fresh generator; if None, the module-level `rng` is used.
        fidelity: Base weight of the 'correct' entry (e.g. p_0^(0,0)); must be in [0.25, 1.0].
    Returns:
        Instrument2x2 instance.
    Raises:
        ValueError: If fidelity is outside [0.25, 1.0].
    """
    # Written as a chained comparison so a NaN fidelity is rejected too.
    if not (0.25 <= fidelity <= 1.0):
        raise ValueError("Fidelity must be between 0.25 and 1.0.")

    if seed is None:
        generator = rng
    else:
        generator = np.random.default_rng(seed)

    leak = (1.0 - fidelity) / 3.0
    # Weight order within a column: [p_s(0,0), p_s(0,1), p_s(1,0), p_s(1,1)].
    base_weights = (
        np.array([fidelity, leak, leak, leak]),  # pre-state |0>: favour outcome 0, post-state |0>
        np.array([leak, leak, leak, fidelity]),  # pre-state |1>: favour outcome 1, post-state |1>
    )
    exponent = 1.0 - corr_strength
    # Column 0 is drawn before column 1, so the generator is consumed in a fixed order.
    columns = [generator.dirichlet(1e-3 + 20 * weights ** exponent) for weights in base_weights]
    return make_instrument_from_columns(*columns)


def get_all_binary_strings(
    max_len: int
) -> List[str]:
    """
    List every binary string of length 1..max_len, shortest first.

    Within one length, strings appear in lexicographic order.

    Args:
        max_len: Longest string length to include (values < 1 give an empty list).
    Returns:
        List of binary strings.
    Example:
        get_all_binary_strings(2) returns ['0', '1', '00', '01', '10', '11']
    """
    return [
        "".join(bits)
        for length in range(1, max_len + 1)
        for bits in product("01", repeat=length)
    ]


def calculate_exact_all_string_probabilities_from_v0_and_instrument(
    inst: Instrument2x2, 
    v0: np.ndarray, 
    max_len: int
) -> Dict[str, float]:
    """
    Compute exact outcome-string probabilities for an instrument and initial distribution.

    For every binary word w with 1 <= |w| <= max_len:
        Prob(w) = [1, 1] . M^{(w_L)} ... M^{(w_1)} v0,
    so the first symbol of w acts first (rightmost factor). v0 is used as given;
    no twirl is applied here.

    Args:
        inst: Instrument2x2 instance representing the MCM instrument.
        v0: Initial state distribution, a length-2 vector.
        max_len: Maximum word length to include.
    Returns:
        Dictionary mapping each word (in get_all_binary_strings order) to its probability.
    """

    def _word_probability(word: str) -> float:
        transfer = np.eye(2)
        for symbol in word:
            operator = inst.M0 if symbol == "0" else inst.M1
            transfer = operator @ transfer  # left-multiply: later symbols act after earlier ones
        return np.sum(transfer @ v0)

    return {word: _word_probability(word) for word in get_all_binary_strings(max_len)}


def monte_carlo_empirical_probabilities(
    inst: Instrument2x2, 
    v0: np.ndarray, 
    L: int, 
    shots: int, 
    seed: int = 1234
) -> Dict[str, float]:
    """
    Estimate outcome-string probabilities by sampling repeated mid-circuit measurements.

    Each shot:
      1. With probability 1/2, swap the two entries of v0 (a twirl). Averaged over
         shots, the effective initial distribution is (v0 + v0[::-1]) / 2.
      2. Sample the initial state s from that distribution.
      3. Apply the instrument L times. Each step samples (o, s') from column s of
         [M0; M1], records o, and sets s = s'.
    Every prefix of the recorded outcome string is counted once.

    Args:
        inst: Instrument2x2 instance representing the MCM instrument.
        v0: Initial state distribution, a length-2 sequence supporting [::-1].
        L: Number of measurements per shot (maximum string length).
        shots: Number of Monte Carlo shots.
        seed: Seed for a fresh generator; if None, the module-level `rng` is used.
    Returns:
        Dictionary mapping each observed string to (count / shots). Strings that
        never occurred are absent from the dictionary.
    """
    if seed is None:
        generator = rng
    else:
        generator = np.random.default_rng(seed)

    # For each pre-state s: cumulative probabilities over (o, s') ordered
    # (0,0), (0,1), (1,0), (1,1).
    cdf_by_state = {
        state: np.cumsum([inst.M0[0, state], inst.M0[1, state], inst.M1[0, state], inst.M1[1, state]])
        for state in (0, 1)
    }

    word_counts = Counter()
    shots_per_length = Counter({length: 0 for length in range(1, L + 1)})

    for _ in tqdm(range(shots), desc="Monte Carlo on strings"):
        initial = v0 if generator.random() < 0.5 else v0[::-1]
        state = 0 if generator.random() < initial[0] else 1

        word = ""
        for length in range(1, L + 1):
            slot = int(np.searchsorted(cdf_by_state[state], generator.random(), side="right"))
            # Slot k encodes (o, s') = divmod(k, 2). A slot of 4 (rounding left the
            # CDF just below 1) falls back to the last entry (1, 1).
            outcome, state = divmod(min(slot, 3), 2)
            word += str(outcome)
            word_counts[word] += 1
            shots_per_length[length] += 1

    return {word: count / shots_per_length[len(word)] for word, count in word_counts.items()}


def generate_and_cache_simulated_data(
    GT: Instrument2x2,
    num_seeds: int,
    shots_per_seed: int = 100_000,
    L: int = 4,
    use_same_v0_for_all_seeds: None | np.ndarray = np.array([0.5, 0.5]),
    seed_for_reproduce: int = 123,
    chosen_dir: str = "simulated_data"
) -> Tuple[Dict[str, Any], List[np.ndarray], List[Dict[str, float]]]:
    """
    Generate Monte Carlo data for several seeds, or load it from an on-disk cache.

    The cache is two pickle files inside `chosen_dir`. When both exist they are
    loaded and returned, and the cached experiment info overrides the arguments of
    this call. Otherwise the data is simulated and written to `chosen_dir`. A
    missing, empty or partially written directory is regenerated.

    Args:
        GT: Ground truth Instrument2x2 instance.
        num_seeds: Number of independent Monte Carlo runs.
        shots_per_seed: Number of Monte Carlo shots per run.
        L: Maximum length of binary strings to simulate.
        use_same_v0_for_all_seeds: Initial state distribution shared by every run.
            If None, a fresh v0 is drawn from a flat Dirichlet(1, 1) for each run.
        seed_for_reproduce: Seed of the generator that draws the per-run seeds (and v0s).
        chosen_dir: Directory holding the cache.
    Returns:
        Tuple of:
            - Experiment info dictionary.
            - List of initial state distributions (v0), one per run.
            - List of empirical probability dictionaries, one per run.
    """
    path_to_info = os.path.join(chosen_dir, "experiment_info.pkl")
    path_to_v0_and_emp_probs = os.path.join(chosen_dir, "v0_list_and_emp_probs_list.pkl")

    # Use the cache only when BOTH files are present. Checking just the directory
    # crashed on an empty or partial directory left by an interrupted run.
    if os.path.isfile(path_to_info) and os.path.isfile(path_to_v0_and_emp_probs):
        print(f"Loading cached simulated data from {chosen_dir}...")
        print("Override experiment info with cached version.")
        # NOTE: pickle.load can execute arbitrary code; only load caches you created.
        with open(path_to_info, "rb") as f:
            experiment_info = pickle.load(f)
        with open(path_to_v0_and_emp_probs, "rb") as f:
            v0_list, empirical_probs_list = pickle.load(f)
        return experiment_info, v0_list, empirical_probs_list

    experiment_info = {
        "GT": GT,
        "number_of_seeds": num_seeds,
        "shots_per_seed": shots_per_seed,
        "max_string_length_L": L,
        "use_same_v0_for_all_seeds": use_same_v0_for_all_seeds,
        "seed_for_reproduce": seed_for_reproduce
    }

    os.makedirs(chosen_dir, exist_ok=True)
    main_rng = np.random.default_rng(seed_for_reproduce)
    shared_v0 = (
        None if use_same_v0_for_all_seeds is None
        else np.asarray(use_same_v0_for_all_seeds, dtype=float)
    )

    v0_list = []
    empirical_probs_list = []
    for _ in tqdm(range(num_seeds), desc="Generating data for seeds"):
        # Draw v0 before the Monte Carlo seed. For a fixed v0, main_rng is consumed
        # exactly as before (one integer per run). Copy the shared v0 so list
        # entries never alias the mutable default argument.
        v0 = main_rng.dirichlet([1.0, 1.0]) if shared_v0 is None else shared_v0.copy()
        monte_carlo_seed = main_rng.integers(0, 2**32 - 1)

        v0_list.append(v0)
        emp_probs = monte_carlo_empirical_probabilities(
            inst=GT,
            v0=v0,
            L=L,
            shots=shots_per_seed,
            seed=monte_carlo_seed
        )
        empirical_probs_list.append(emp_probs)

    with open(path_to_info, "wb") as f:
        pickle.dump(experiment_info, f)

    with open(path_to_v0_and_emp_probs, "wb") as f:
        pickle.dump((v0_list, empirical_probs_list), f)

    return experiment_info, v0_list, empirical_probs_list


def calculate_average_probs_np(
    empirical_probs_list
) -> Dict[str, float]:
    """
    Average the probability of every key across a list of dictionaries.

    Monte Carlo dictionaries only contain strings that were observed, so different
    runs can have different key sets. A key missing from a dictionary counts as
    probability 0 in that dictionary. Every key that appears anywhere is kept.

    Args:
        empirical_probs_list (list[dict]):
            A list of dictionaries mapping string keys to probability floats.

    Returns:
        dict:
            A dictionary mapping each key (in sorted order) to its mean probability
            across all input dictionaries. Returns {} for an empty input list.
    """
    if not empirical_probs_list:
        return {}

    # Union of keys. Reading only the first dict raised KeyError (or dropped keys)
    # when runs observed different strings.
    ordered_keys = sorted(set().union(*(prob_dict.keys() for prob_dict in empirical_probs_list)))

    # Rows = input dictionaries, columns = keys; unobserved strings contribute 0.
    data_array = np.array(
        [[prob_dict.get(key, 0.0) for key in ordered_keys] for prob_dict in empirical_probs_list],
        dtype=float,
    )

    average_values = data_array.mean(axis=0)
    return dict(zip(ordered_keys, average_values))


def derived_constraints_from_empirical_probs(
    emp_probs: Dict[str, float]
) -> Tuple[float, float, float, float, None | float, None | float, float, float]:
    """
    Derive the constraint values from empirical probabilities.
    Args:
        emp_probs: Dictionary of empirical probabilities for binary strings.
    Returns:
        Tuple of derived constraint values:
            - trace_M0
            - det_M0
            - trace_M1
            - det_M1
            - trace_M0M1 and trace_M1M0 (they are the same but we can use both to check consistency)
            - prob_string_0 (need assumption of v0 being in maximally-mixed state)
            - prob_string_1 (need assumption of v0 being in maximally-mixed state)
    """

    # trace_M0 = (Prob('0') * Prob('00') - Prob('000')) / (Prob('0')^2 - Prob('00'))
    trace_M0 = (emp_probs.get('0', 0) * emp_probs.get('00', 0) - emp_probs.get('000', 0)) / \
                (emp_probs.get('0', 0)**2 - emp_probs.get('00', 0))

    # det_M0 = trace_M0 * Prob('0') - Prob('00')
    det_M0 = trace_M0 * emp_probs.get('0', 0) - emp_probs.get('00', 0)

    # trace_M1 = (Prob('1') * Prob('11') - Prob('111')) / (Prob('1')^2 - Prob('11'))
    trace_M1 = (emp_probs.get('1', 0) * emp_probs.get('11', 0) - emp_probs.get('111', 0)) / \
                (emp_probs.get('1', 0)**2 - emp_probs.get('11', 0))

    # det_M1 = trace_M1 * Prob('1') - Prob('11')
    det_M1 = trace_M1 * emp_probs.get('1', 0) - emp_probs.get('11', 0)

    trace_M0M1 = None
    trace_M1M0 = None

    # trace_M0M1 = (Prob('0101') + det_M1 * det_M0) / Prob('01')
    if '0101' in emp_probs and '01' in emp_probs and emp_probs['01'] != 0:
        trace_M0M1 = (emp_probs['0101'] + det_M1 * det_M0) / emp_probs['01']

    # trace_M1M0 = (Prob('1010') + det_M0 * det_M1) / Prob('10')
    if '1010' in emp_probs and '10' in emp_probs and emp_probs['10'] != 0:
        trace_M1M0 = (emp_probs['1010'] + det_M0 * det_M1) / emp_probs['10']

    return (trace_M0, 
            det_M0, 
            trace_M1, 
            det_M1, 
            trace_M0M1,
            trace_M1M0,
            emp_probs.get('0', 0),
            emp_probs.get('1', 0))

🟢 Gauge Continuum Characterization

Symbolic and Numeric Gauge Transformation

import sympy as sp

def construct_instrument_symbolically(
    a: sp.Symbol,
    b: sp.Symbol,
    c: sp.Symbol,
    e: sp.Symbol,
    f: sp.Symbol,
    g: sp.Symbol
) -> Tuple[sp.Matrix, sp.Matrix]:
    """
    Build symbolic instrument matrices M0 and M1 from six free parameters.

    Column s of the instrument is the probability vector
        [p_s^{(0,0)}, p_s^{(0,1)}, p_s^{(1,0)}, p_s^{(1,1)}],
    whose last entry is fixed by normalization:
        column 0 = [a, b, c, 1 - a - b - c]
        column 1 = [e, f, g, 1 - e - f - g]
    The matrices are laid out as M^o[s', s] = p_s^{(o,s')}.

    Args:
        a, b, c: Free parameters of column 0 (pre-state |0>).
        e, f, g: Free parameters of column 1 (pre-state |1>).

    Returns:
        A tuple of symbolic matrices (M0, M1).
    """
    column_0 = (a, b, c, 1 - a - b - c)
    column_1 = (e, f, g, 1 - e - f - g)

    # Entry k of a column belongs to outcome o = k // 2 and post-state s' = k % 2.
    M0 = sp.Matrix([[column_0[0], column_1[0]],
                    [column_0[1], column_1[1]]])
    M1 = sp.Matrix([[column_0[2], column_1[2]],
                    [column_0[3], column_1[3]]])
    return M0, M1


def gauge_transform_instrument_symbolically(
    M0: sp.Matrix,
    M1: sp.Matrix,
    t: sp.Symbol,
    ) -> Tuple[sp.Matrix, sp.Matrix]:
    """
    Conjugate both instrument matrices by the symbolic gauge matrix R(t).

    Each matrix is mapped to M' = R^{-1} M R with R = [[1 - t, t], [t, 1 - t]].
    R is singular at t = 1/2, so the transformed entries carry a (2t - 1) denominator.

    Args:
        M0: Symbolic matrix for outcome 0.
        M1: Symbolic matrix for outcome 1.
        t: Symbolic gauge parameter.
    Returns:
        Tuple (M0', M1') of transformed symbolic matrices.
    """
    gauge = sp.Matrix([[1 - t, t], [t, 1 - t]])
    gauge_inverse = gauge.inv()
    return tuple(gauge_inverse * matrix * gauge for matrix in (M0, M1))


def gauge_transform_instrument_numerically(
    M0: np.ndarray,
    M1: np.ndarray,
    t: float
    ) -> Tuple[np.ndarray, np.ndarray]:
    """
    Apply a numerical gauge transformation to the instrument matrices M0 and M1.

        M' = R^{-1} @ M @ R,   R = [[1 - t, t], [t, 1 - t]]

    The columns of R sum to 1, so [1, 1] @ R^{-1} = [1, 1]. The column sums of
    M0' + M1' therefore equal those of M0 + M1. Traces, determinants and
    tr(M0 M1) are unchanged because this is a similarity transform.

    Args:
        M0: 2x2 array for outcome 0.
        M1: 2x2 array for outcome 1.
        t: Gauge parameter, nominally in [0, 1]. t = 0 is the identity and t = 1
            swaps the basis labels. R is singular at t = 0.5.
    Returns:
        Transformed matrices (M0', M1').
    Raises:
        numpy.linalg.LinAlgError: If R is singular (t == 0.5).
    """
    R = np.array([[1 - t, t], [t, 1 - t]], dtype=float)
    # Solve R X = M R rather than forming R^{-1} explicitly; numerically preferable
    # and raises LinAlgError for a singular R just as inv() does.
    M0_hat = LA.solve(R, np.asarray(M0) @ R)
    M1_hat = LA.solve(R, np.asarray(M1) @ R)
    return M0_hat, M1_hat
# This cell checks analytically that the gauge transformation doesn't change the
# invariants det(M), trace(M) and trace(M0 M1).

# Define real-valued symbolic variables.
a = sp.Symbol('a', real=True)
b = sp.Symbol('b', real=True)
c = sp.Symbol('c', real=True)
e = sp.Symbol('e', real=True)
f = sp.Symbol('f', real=True)
g = sp.Symbol('g', real=True)
t = sp.Symbol('t', real=True)

# NOTE: the physical constraints a + b + c <= 1, e + f + g <= 1 and 0 <= t <= 1 are
# not encoded as SymPy assumptions; the check below is purely algebraic. The gauge
# matrix is singular at t = 1/2, hence the (2*t - 1) denominators in M0' and M1'.

# Construct the symbolic instrument
test_M0, test_M1 = construct_instrument_symbolically(a, b, c, e, f, g)

# Apply the gauge transformation
test_m0_prime, test_m1_prime = gauge_transform_instrument_symbolically(
    test_M0, test_M1, t
)

# Simplify the expressions for the transformed matrices
simplified_m0_prime = sp.simplify(test_m0_prime)
simplified_m1_prime = sp.simplify(test_m1_prime)

print("Original M0 and M1:")
display(test_M0, test_M1)

print("Original 2 matrices' det and trace:")
print(f"det(M0) = {sp.simplify(test_M0.det())}, trace(M0) = {sp.simplify(test_M0.trace())}")
print(f"det(M1) = {sp.simplify(test_M1.det())}, trace(M1) = {sp.simplify(test_M1.trace())}")

print("The original trace of M0M1 and M1M0:")
print(f"trace(M0M1) = {sp.simplify((test_M0 * test_M1).trace())}")
print(f"trace(M1M0) = {sp.simplify((test_M1 * test_M0).trace())}")

print("\nTransformed M0' and M1' (before simplification):")
display(test_m0_prime, test_m1_prime)

print("\nTransformed M0' and M1' (after simplification):")
display(simplified_m0_prime, simplified_m1_prime)

print("\nTransformed 2 matrices' det and trace:")
print(f"det(M0') = {sp.simplify(simplified_m0_prime.det())}, trace(M0') = {sp.simplify(simplified_m0_prime.trace())}")
print(f"det(M1') = {sp.simplify(simplified_m1_prime.det())}, trace(M1') = {sp.simplify(simplified_m1_prime.trace())}")

print("The transformed trace of M0'M1' and M1'M0':")
print(f"trace(M0'M1') = {sp.simplify((simplified_m0_prime * simplified_m1_prime).trace())}")
print(f"trace(M1'M0') = {sp.simplify((simplified_m1_prime * simplified_m0_prime).trace())}")

# Explicit confirmation: each invariant is preserved iff (original - transformed) simplifies to 0.
invariant_pairs = {
    "det(M0)": (test_M0.det(), simplified_m0_prime.det()),
    "trace(M0)": (test_M0.trace(), simplified_m0_prime.trace()),
    "det(M1)": (test_M1.det(), simplified_m1_prime.det()),
    "trace(M1)": (test_M1.trace(), simplified_m1_prime.trace()),
    "trace(M0M1)": ((test_M0 * test_M1).trace(), (simplified_m0_prime * simplified_m1_prime).trace()),
}
print("\nInvariant check (original vs. transformed):")
for label, (before, after) in invariant_pairs.items():
    print(f"{label} preserved: {sp.simplify(before - after) == 0}")

# Check the detailed forms of each element
print("\nDetailed forms of each element in M0' and M1' after transformation:")
for i in range(2):
    for j in range(2):
        elem_m0_prime = sp.simplify(simplified_m0_prime[i, j])
        elem_m1_prime = sp.simplify(simplified_m1_prime[i, j])
        print(f"M0'[{i},{j}]: {elem_m0_prime}")
        print(f"M1'[{i},{j}]: {elem_m1_prime}")
        print("-"*40)
1
Original M0 and M1:

\(\displaystyle \left[\begin{matrix}a & e\\b & f\end{matrix}\right]\)

\(\displaystyle \left[\begin{matrix}c & g\\- a - b - c + 1 & - e - f - g + 1\end{matrix}\right]\)

1
2
3
4
5
6
7
8
Original 2 matrices' det and trace:
det(M0) = a*f - b*e, trace(M0) = a + f
det(M1) = a*g + b*g - c*e - c*f + c - g, trace(M1) = c - e - f - g + 1
The original trace of M0M1 and M1M0:
trace(M0M1) = a*c + b*g - e*(a + b + c - 1) - f*(e + f + g - 1)
trace(M1M0) = a*c + b*g - e*(a + b + c - 1) - f*(e + f + g - 1)

Transformed M0' and M1' (before simplification):

\(\displaystyle \left[\begin{matrix}t \left(\frac{e \left(t - 1\right)}{2 t - 1} + \frac{f t}{2 t - 1}\right) + \left(1 - t\right) \left(\frac{a \left(t - 1\right)}{2 t - 1} + \frac{b t}{2 t - 1}\right) & t \left(\frac{a \left(t - 1\right)}{2 t - 1} + \frac{b t}{2 t - 1}\right) + \left(1 - t\right) \left(\frac{e \left(t - 1\right)}{2 t - 1} + \frac{f t}{2 t - 1}\right)\\t \left(\frac{e t}{2 t - 1} + \frac{f \left(t - 1\right)}{2 t - 1}\right) + \left(1 - t\right) \left(\frac{a t}{2 t - 1} + \frac{b \left(t - 1\right)}{2 t - 1}\right) & t \left(\frac{a t}{2 t - 1} + \frac{b \left(t - 1\right)}{2 t - 1}\right) + \left(1 - t\right) \left(\frac{e t}{2 t - 1} + \frac{f \left(t - 1\right)}{2 t - 1}\right)\end{matrix}\right]\)

\(\displaystyle \left[\begin{matrix}t \left(\frac{g \left(t - 1\right)}{2 t - 1} + \frac{t \left(- e - f - g + 1\right)}{2 t - 1}\right) + \left(1 - t\right) \left(\frac{c \left(t - 1\right)}{2 t - 1} + \frac{t \left(- a - b - c + 1\right)}{2 t - 1}\right) & t \left(\frac{c \left(t - 1\right)}{2 t - 1} + \frac{t \left(- a - b - c + 1\right)}{2 t - 1}\right) + \left(1 - t\right) \left(\frac{g \left(t - 1\right)}{2 t - 1} + \frac{t \left(- e - f - g + 1\right)}{2 t - 1}\right)\\t \left(\frac{g t}{2 t - 1} + \frac{\left(t - 1\right) \left(- e - f - g + 1\right)}{2 t - 1}\right) + \left(1 - t\right) \left(\frac{c t}{2 t - 1} + \frac{\left(t - 1\right) \left(- a - b - c + 1\right)}{2 t - 1}\right) & t \left(\frac{c t}{2 t - 1} + \frac{\left(t - 1\right) \left(- a - b - c + 1\right)}{2 t - 1}\right) + \left(1 - t\right) \left(\frac{g t}{2 t - 1} + \frac{\left(t - 1\right) \left(- e - f - g + 1\right)}{2 t - 1}\right)\end{matrix}\right]\)

1
Transformed M0' and M1' (after simplification):

\(\displaystyle \left[\begin{matrix}\frac{t \left(e \left(t - 1\right) + f t\right) - \left(t - 1\right) \left(a \left(t - 1\right) + b t\right)}{2 t - 1} & \frac{t \left(a \left(t - 1\right) + b t\right) - \left(t - 1\right) \left(e \left(t - 1\right) + f t\right)}{2 t - 1}\\\frac{t \left(e t + f \left(t - 1\right)\right) - \left(t - 1\right) \left(a t + b \left(t - 1\right)\right)}{2 t - 1} & \frac{t \left(a t + b \left(t - 1\right)\right) - \left(t - 1\right) \left(e t + f \left(t - 1\right)\right)}{2 t - 1}\end{matrix}\right]\)

\(\displaystyle \left[\begin{matrix}\frac{t \left(g \left(t - 1\right) - t \left(e + f + g - 1\right)\right) - \left(t - 1\right) \left(c \left(t - 1\right) - t \left(a + b + c - 1\right)\right)}{2 t - 1} & \frac{t \left(c \left(t - 1\right) - t \left(a + b + c - 1\right)\right) - \left(t - 1\right) \left(g \left(t - 1\right) - t \left(e + f + g - 1\right)\right)}{2 t - 1}\\\frac{t \left(g t - \left(t - 1\right) \left(e + f + g - 1\right)\right) - \left(t - 1\right) \left(c t - \left(t - 1\right) \left(a + b + c - 1\right)\right)}{2 t - 1} & \frac{t \left(c t - \left(t - 1\right) \left(a + b + c - 1\right)\right) - \left(t - 1\right) \left(g t - \left(t - 1\right) \left(e + f + g - 1\right)\right)}{2 t - 1}\end{matrix}\right]\)

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
Transformed 2 matrices' det and trace:
det(M0') = a*f - b*e, trace(M0') = a + f
det(M1') = a*g + b*g - c*e - c*f + c - g, trace(M1') = c - e - f - g + 1
The transformed trace of M0'P1' and P1'M0':
trace(M0'P1') = a*c - a*e - b*e + b*g - c*e - e*f + e - f**2 - f*g + f
trace(M1'P0') = a*c - a*e - b*e + b*g - c*e - e*f + e - f**2 - f*g + f

Detailed forms of each element in M0' and M1' after transformation:
M0'[0,0]: (t*(e*(t - 1) + f*t) - (t - 1)*(a*(t - 1) + b*t))/(2*t - 1)
M1'[0,0]: (t*(g*(t - 1) - t*(e + f + g - 1)) - (t - 1)*(c*(t - 1) - t*(a + b + c - 1)))/(2*t - 1)
----------------------------------------
M0'[0,1]: (t*(a*(t - 1) + b*t) - (t - 1)*(e*(t - 1) + f*t))/(2*t - 1)
M1'[0,1]: (t*(c*(t - 1) - t*(a + b + c - 1)) - (t - 1)*(g*(t - 1) - t*(e + f + g - 1)))/(2*t - 1)
----------------------------------------
M0'[1,0]: (t*(e*t + f*(t - 1)) - (t - 1)*(a*t + b*(t - 1)))/(2*t - 1)
M1'[1,0]: (t*(g*t - (t - 1)*(e + f + g - 1)) - (t - 1)*(c*t - (t - 1)*(a + b + c - 1)))/(2*t - 1)
----------------------------------------
M0'[1,1]: (t*(a*t + b*(t - 1)) - (t - 1)*(e*t + f*(t - 1)))/(2*t - 1)
M1'[1,1]: (t*(c*t - (t - 1)*(a + b + c - 1)) - (t - 1)*(g*t - (t - 1)*(e + f + g - 1)))/(2*t - 1)
----------------------------------------

Simulated instruments close to ibm_pittsburgh's MCMs

1
2
3
4
5
6
# Instruments whose parameters were hand-picked to resemble ibm_pittsburgh's MCMs.
# Each tuple is (corr_strength, seed, fidelity).
ibm_pittsburgh_mcm = [
    random_instrument(corr_strength=strength, seed=seed_value, fidelity=fid)
    for strength, seed_value, fid in (
        (0.005, 3, 0.98),
        (0.08, 37, 0.96),   # may be slightly off
        (0.2, 1, 0.99),     # good demo of a near-perfect instrument
        (0.1, 98, 0.97),    # good demo of ibm_pittsburgh on the morning of Nov 6, 2025
    )
]
# 😃 Find seeds that produce GT instruments with readout errors in the desired range
import contextlib
import io

target_min = 5e-4
target_max = 4e-3

# Settings of the candidate instruments being scanned.
scan_corr_strength = 0.1
scan_fidelity = 0.97
num_candidate_seeds = 100

good_seeds = []

for seed in range(num_candidate_seeds):
    GT_test = random_instrument(corr_strength=scan_corr_strength, seed=seed, fidelity=scan_fidelity)
    # show_readout_errors() prints two lines per call; discard that output so the
    # scan reports only the seeds that pass the filter.
    with contextlib.redirect_stdout(io.StringIO()):
        readout_errors = GT_test.show_readout_errors()

    # Keep the seed only if both readout errors are within the target range
    if all(target_min <= err <= target_max for err in readout_errors):
        good_seeds.append({
            'seed': seed,
            'prep_0_meas_1': readout_errors[0],
            'prep_1_meas_0': readout_errors[1]
        })

print(f"\nFound {len(good_seeds)} seeds with readout errors in range [{target_min:.1e}, {target_max:.1e}]:\n")
print("="*70)

for item in good_seeds:
    print(f"Seed {item['seed']:3d}: prep_0_meas_1 = {item['prep_0_meas_1']:.6f}, "
          f"prep_1_meas_0 = {item['prep_1_meas_0']:.6f}")

# Create a summary DataFrame
if good_seeds:
    df_good_seeds = pd.DataFrame(good_seeds)
    print("\n" + "="*70)
    print("Summary Statistics:")
    print("="*70)
    display(df_good_seeds.describe())
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
Prep 0 meas 1: 3.954e-02
Prep 1 meas 0: 5.291e-02

Prep 0 meas 1: 3.216e-02
Prep 1 meas 0: 2.313e-02

Prep 0 meas 1: 9.910e-03
Prep 1 meas 0: 8.432e-03

Prep 0 meas 1: 3.295e-03
Prep 1 meas 0: 5.632e-03

Prep 0 meas 1: 8.134e-02
Prep 1 meas 0: 5.731e-02

Prep 0 meas 1: 3.745e-03
Prep 1 meas 0: 3.655e-03

Prep 0 meas 1: 1.185e-02
Prep 1 meas 0: 2.531e-02

Prep 0 meas 1: 1.157e-03
Prep 1 meas 0: 2.906e-02

Prep 0 meas 1: 6.697e-02
Prep 1 meas 0: 8.373e-04

Prep 0 meas 1: 6.891e-02
Prep 1 meas 0: 4.451e-03

Prep 0 meas 1: 4.473e-02
Prep 1 meas 0: 4.349e-02

Prep 0 meas 1: 9.686e-02
Prep 1 meas 0: 1.365e-02

Prep 0 meas 1: 1.654e-02
Prep 1 meas 0: 6.799e-02

Prep 0 meas 1: 4.431e-02
Prep 1 meas 0: 3.722e-02

Prep 0 meas 1: 2.976e-02
Prep 1 meas 0: 1.988e-02

Prep 0 meas 1: 1.738e-02
Prep 1 meas 0: 2.324e-03

Prep 0 meas 1: 5.305e-02
Prep 1 meas 0: 6.418e-04

Prep 0 meas 1: 3.145e-03
Prep 1 meas 0: 7.141e-04

Prep 0 meas 1: 6.877e-02
Prep 1 meas 0: 3.752e-02

Prep 0 meas 1: 2.885e-02
Prep 1 meas 0: 5.163e-02

Prep 0 meas 1: 3.447e-03
Prep 1 meas 0: 6.137e-03

Prep 0 meas 1: 1.067e-01
Prep 1 meas 0: 5.064e-02

Prep 0 meas 1: 6.313e-03
Prep 1 meas 0: 7.208e-03

Prep 0 meas 1: 1.226e-02
Prep 1 meas 0: 2.436e-02

Prep 0 meas 1: 4.061e-02
Prep 1 meas 0: 3.393e-02

Prep 0 meas 1: 6.340e-05
Prep 1 meas 0: 1.964e-02

Prep 0 meas 1: 6.610e-02
Prep 1 meas 0: 1.379e-02

Prep 0 meas 1: 6.097e-02
Prep 1 meas 0: 3.602e-03

Prep 0 meas 1: 7.018e-02
Prep 1 meas 0: 1.482e-03

Prep 0 meas 1: 3.115e-03
Prep 1 meas 0: 3.088e-02

Prep 0 meas 1: 1.939e-02
Prep 1 meas 0: 2.923e-03

Prep 0 meas 1: 1.063e-02
Prep 1 meas 0: 2.176e-02

Prep 0 meas 1: 2.897e-02
Prep 1 meas 0: 1.569e-02

Prep 0 meas 1: 2.730e-02
Prep 1 meas 0: 8.904e-04

Prep 0 meas 1: 5.401e-02
Prep 1 meas 0: 8.231e-02

Prep 0 meas 1: 7.207e-02
Prep 1 meas 0: 7.186e-02

Prep 0 meas 1: 2.724e-02
Prep 1 meas 0: 1.543e-03

Prep 0 meas 1: 1.384e-03
Prep 1 meas 0: 3.894e-02

Prep 0 meas 1: 8.157e-02
Prep 1 meas 0: 4.408e-03

Prep 0 meas 1: 9.377e-03
Prep 1 meas 0: 7.330e-02

Prep 0 meas 1: 1.973e-02
Prep 1 meas 0: 6.033e-05

Prep 0 meas 1: 1.075e-02
Prep 1 meas 0: 1.736e-02

Prep 0 meas 1: 2.048e-02
Prep 1 meas 0: 2.122e-03

Prep 0 meas 1: 1.235e-02
Prep 1 meas 0: 8.039e-02

Prep 0 meas 1: 1.381e-03
Prep 1 meas 0: 1.287e-02

Prep 0 meas 1: 3.622e-02
Prep 1 meas 0: 1.845e-02

Prep 0 meas 1: 2.039e-02
Prep 1 meas 0: 1.941e-02

Prep 0 meas 1: 1.496e-01
Prep 1 meas 0: 1.872e-02

Prep 0 meas 1: 2.454e-02
Prep 1 meas 0: 2.842e-02

Prep 0 meas 1: 2.178e-02
Prep 1 meas 0: 1.982e-02

Prep 0 meas 1: 5.102e-04
Prep 1 meas 0: 4.857e-03

Prep 0 meas 1: 6.100e-02
Prep 1 meas 0: 6.834e-02

Prep 0 meas 1: 8.164e-04
Prep 1 meas 0: 1.268e-02

Prep 0 meas 1: 3.096e-02
Prep 1 meas 0: 1.077e-02

Prep 0 meas 1: 9.210e-02
Prep 1 meas 0: 7.279e-03

Prep 0 meas 1: 1.877e-02
Prep 1 meas 0: 1.567e-03

Prep 0 meas 1: 1.298e-02
Prep 1 meas 0: 9.034e-02

Prep 0 meas 1: 5.746e-03
Prep 1 meas 0: 1.529e-03

Prep 0 meas 1: 4.415e-02
Prep 1 meas 0: 1.281e-02

Prep 0 meas 1: 3.364e-04
Prep 1 meas 0: 4.238e-03

Prep 0 meas 1: 1.677e-01
Prep 1 meas 0: 2.087e-02

Prep 0 meas 1: 2.944e-02
Prep 1 meas 0: 1.529e-02

Prep 0 meas 1: 2.757e-02
Prep 1 meas 0: 3.409e-02

Prep 0 meas 1: 2.196e-03
Prep 1 meas 0: 1.356e-02

Prep 0 meas 1: 4.640e-02
Prep 1 meas 0: 1.075e-01

Prep 0 meas 1: 1.141e-01
Prep 1 meas 0: 9.404e-02

Prep 0 meas 1: 4.400e-03
Prep 1 meas 0: 3.722e-02

Prep 0 meas 1: 1.308e-02
Prep 1 meas 0: 6.963e-04

Prep 0 meas 1: 7.421e-03
Prep 1 meas 0: 8.472e-02

Prep 0 meas 1: 1.042e-02
Prep 1 meas 0: 2.376e-02

Prep 0 meas 1: 4.594e-02
Prep 1 meas 0: 1.437e-04

Prep 0 meas 1: 6.883e-03
Prep 1 meas 0: 2.595e-02

Prep 0 meas 1: 2.412e-02
Prep 1 meas 0: 7.091e-02

Prep 0 meas 1: 5.190e-02
Prep 1 meas 0: 2.750e-02

Prep 0 meas 1: 8.021e-03
Prep 1 meas 0: 2.017e-01

Prep 0 meas 1: 1.773e-01
Prep 1 meas 0: 3.356e-02

Prep 0 meas 1: 9.497e-03
Prep 1 meas 0: 5.684e-04

Prep 0 meas 1: 2.711e-02
Prep 1 meas 0: 3.854e-02

Prep 0 meas 1: 5.583e-03
Prep 1 meas 0: 3.309e-02

Prep 0 meas 1: 4.179e-03
Prep 1 meas 0: 7.091e-03

Prep 0 meas 1: 1.783e-02
Prep 1 meas 0: 4.603e-02

Prep 0 meas 1: 7.258e-03
Prep 1 meas 0: 6.195e-02

Prep 0 meas 1: 2.481e-03
Prep 1 meas 0: 4.681e-02

Prep 0 meas 1: 7.330e-02
Prep 1 meas 0: 6.523e-03

Prep 0 meas 1: 2.089e-03
Prep 1 meas 0: 4.543e-03

Prep 0 meas 1: 3.731e-02
Prep 1 meas 0: 3.813e-04

Prep 0 meas 1: 7.714e-03
Prep 1 meas 0: 1.226e-03

Prep 0 meas 1: 9.860e-03
Prep 1 meas 0: 1.639e-02

Prep 0 meas 1: 3.073e-03
Prep 1 meas 0: 1.018e-01

Prep 0 meas 1: 2.898e-02
Prep 1 meas 0: 8.594e-02

Prep 0 meas 1: 1.144e-02
Prep 1 meas 0: 2.055e-02

Prep 0 meas 1: 3.395e-02
Prep 1 meas 0: 3.625e-03

Prep 0 meas 1: 5.848e-06
Prep 1 meas 0: 6.457e-03

Prep 0 meas 1: 1.315e-01
Prep 1 meas 0: 4.339e-02

Prep 0 meas 1: 7.851e-03
Prep 1 meas 0: 5.244e-02

Prep 0 meas 1: 3.493e-03
Prep 1 meas 0: 5.407e-02

Prep 0 meas 1: 8.774e-03
Prep 1 meas 0: 9.827e-02

Prep 0 meas 1: 2.518e-04
Prep 1 meas 0: 8.628e-02

Prep 0 meas 1: 1.262e-03
Prep 1 meas 0: 1.506e-03

Prep 0 meas 1: 1.191e-02
Prep 1 meas 0: 7.654e-02

Found 3 seeds with readout errors in range [5.0e-04, 4.0e-03]:

======================================================================
Seed   5: prep_0_meas_1 = 0.003745, prep_1_meas_0 = 0.003655
Seed  17: prep_0_meas_1 = 0.003145, prep_1_meas_0 = 0.000714
Seed  98: prep_0_meas_1 = 0.001262, prep_1_meas_0 = 0.001506

======================================================================
Summary Statistics:
======================================================================
seed prep_0_meas_1 prep_1_meas_0
count 3.000000 3.000000 3.000000
mean 40.000000 0.002717 0.001958
std 50.586559 0.001295 0.001522
min 5.000000 0.001262 0.000714
25% 11.000000 0.002204 0.001110
50% 17.000000 0.003145 0.001506
75% 57.500000 0.003445 0.002580
max 98.000000 0.003745 0.003655
1
2
3
4
5
# Ground-truth instrument for the rest of the notebook. Seed 98 is one of the seeds the
# scan above reported with both readout errors inside [5e-4, 4e-3].
GT = random_instrument(corr_strength=0.1, seed=98, fidelity=0.97)
GT.reveal()  # prints the M0 / M1 matrices (see "MCM Instrument:" in the cell output)

GT.show_readout_errors()
# Last expression of the cell: the notebook echoes its return value, which appears to be
# the [prep 0 ends in 1, prep 1 ends in 0] back-action error pair.
GT.show_backaction_errors()
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
MCM Instrument:

M0 matrix (outcome 0):
  [[p_0^(0,0), p_1^(0,0)]] = [[0.99872032, 0.00132291]]
  [[p_0^(0,1), p_1^(0,1)]]   [[0.00001741, 0.00018317]]

M1 matrix (outcome 1):
  [[p_0^(1,0), p_1^(1,0)]] = [[0.00003101, 0.00078360]]
  [[p_0^(1,1), p_1^(1,1)]]   [[0.00123126, 0.99771031]]

Prep 0 meas 1: 1.262e-03
Prep 1 meas 0: 1.506e-03

Prep 0 ends in 1: 1.249e-03
Prep 1 ends in 0: 2.107e-03





[np.float64(0.0012486678967334145), np.float64(0.0021065159600563614)]

Parametrizing the MCM matrices based on learnable values & Plotting Gauge Transformations

def reconstruct_instrument_from_invariants_mixed_det(
    trM0: float,
    detM0: float,
    trM1: float,
    detM1: float,
    S0: float,               # sum of all four entries of M0 (= 2 * 1^T M0 (1,1)/2)
    gauge_p00: float = 1.0,  # gauge choice: value pinned for M0[0,0]
    noise_tol: float = 1e-12
) -> List[Tuple[np.ndarray, np.ndarray]]:
    """
    Rebuild candidate (M0, M1) instrument pairs from gauge-invariant data.

    Writing M0 = [[a, b], [c, d]] and M1 = [[e, f], [g, h]] with a = gauge_p00:

    M0 follows from
        d     = trM0 - a
        b + c = S0 - trM0
        b * c = a*d - detM0
    so c is a root of a quadratic; up to two distinct roots are tried.

    M1 follows from the column normalisation of M0 + M1 and the invariants
        e + g = 1 - a - c   (=: u)
        f + h = 1 - b - d   (=: v)
        e + h = trM1
        e*h - f*g = detM1
    Substituting the linear equations into the determinant gives
        h = (detM1 + v*(u - trM1)) / (u - v)
    whenever |u - v| > noise_tol.

    Otherwise a fourth linear equation tr(M0 @ M1) =
    detM0 + detM1 + (trM0 - 1)(trM1 - 1) is used and the 4x4 system is
    solved. If the solver reports a singular matrix, least squares is used
    instead and the candidate is dropped when its residual exceeds 1e-9.
    NOTE(review): when u == v exactly that 4x4 matrix has the null vector
    (1, 1, -1, -1), so it is rank-deficient and h is not uniquely fixed.

    Returns
    -------
    list of (M0, M1)
        2x2 float arrays, one pair per accepted root c.

    Raises
    ------
    ValueError
        If the discriminant of the quadratic for c is below -noise_tol.
    """
    p00 = float(gauge_p00)
    tr0 = float(trM0)
    det0 = float(detM0)
    tr1 = float(trM1)
    det1 = float(detM1)
    p11 = tr0 - p00                      # d

    off_sum = float(S0) - tr0            # b + c
    off_prod = p00 * p11 - det0          # b * c
    disc = off_sum * off_sum - 4.0 * off_prod
    if disc < -noise_tol:
        raise ValueError(f"Negative discriminant {disc:.3e}. Check inputs.")
    root = np.sqrt(max(disc, 0.0))

    def _solve_degenerate(b, c, u, v):
        # Fourth equation: tr(M0 @ M1) = a*e + c*f + b*g + d*h. Its value comes
        # from det(M0 + M1) = tr(M0 + M1) - 1 (columns of M0 + M1 sum to 1).
        tr_product = det1 + det0 + (tr0 - 1.0)*(tr1 - 1.0)
        system = np.array([[1.0, 0.0, 1.0, 0.0],   # e + g = u
                           [0.0, 1.0, 0.0, 1.0],   # f + h = v
                           [1.0, 0.0, 0.0, 1.0],   # e + h = trM1
                           [p00, c,   b,   p11]],  # tr(M0 @ M1)
                          dtype=float)
        target = np.array([u, v, tr1, tr_product], dtype=float)
        try:
            return tuple(np.linalg.solve(system, target))
        except np.linalg.LinAlgError:
            pass
        lsq = np.linalg.lstsq(system, target, rcond=None)[0]
        if np.linalg.norm(system @ lsq - target, ord=np.inf) > 1e-9:
            return None  # inconsistent system: reject this root
        return tuple(lsq)

    solutions: List[Tuple[np.ndarray, np.ndarray]] = []
    used_c: List[float] = []
    for c in ((off_sum + root)/2.0, (off_sum - root)/2.0):
        # A double root would otherwise produce the same pair twice.
        if any(abs(c - prev) < noise_tol for prev in used_c):
            continue
        used_c.append(c)
        b = off_sum - c

        u = 1.0 - p00 - c
        v = 1.0 - b - p11

        if abs(u - v) > noise_tol:
            h = (det1 + v*(u - tr1)) / (u - v)
            e = tr1 - h
            g = u - e
            f = v - h
        else:
            efgh = _solve_degenerate(b, c, u, v)
            if efgh is None:
                continue
            e, f, g, h = efgh

        solutions.append((
            np.array([[p00, b], [c, p11]], dtype=float),
            np.array([[e, f], [g, h]], dtype=float),
        ))

    return solutions


def summarize_instrument(M0: np.ndarray, M1: np.ndarray) -> Dict[str, Any]:
    """
    Return invariant and normalisation diagnostics for an instrument pair.

    Keys (all Python floats, in this order):
        trM0, detM0, trM1, detM1 : trace / determinant of each outcome matrix
        trM0M1                   : trace of the product M0 @ M1
        colsum0, colsum1         : column sums of M0 + M1 (1 for a normalised
                                   instrument)
        S0                       : sum of all entries of M0
    """
    total = M0 + M1
    column_totals = total.sum(axis=0)
    return {
        "trM0": float(np.trace(M0)),
        "detM0": float(np.linalg.det(M0)),
        "trM1": float(np.trace(M1)),
        "detM1": float(np.linalg.det(M1)),
        "trM0M1": float(np.trace(M0 @ M1)),
        "colsum0": float(column_totals[0]),
        "colsum1": float(column_totals[1]),
        "S0": float(M0.sum()),
    }
# Analytically compute the allowed ranges for $t$ starting from one of the solution pairs

def _abgd_from_M(M):
    """
    Return (alpha, beta, gamma, delta) for M = [[a, b], [c, d]].

    These are the entries of M expressed in the (1, 1) / (1, -1) basis,
    i.e. 0.5 * P M P with P = [[1, 1], [1, -1]]:
        alpha = (a + b + c + d) / 2     beta  = (a + c - b - d) / 2
        gamma = (a + b - c - d) / 2     delta = (a - b - c + d) / 2
    """
    top, bottom = M[0], M[1]
    a, b = top[0], top[1]
    c, d = bottom[0], bottom[1]
    return (
        0.5*(a+b+c+d),
        0.5*(a+c - b - d),
        0.5*(a+b - c - d),
        0.5*(a - b - c + d),
    )

def _intervals_quad(a, b, c, rel, tol=1e-24):
    """
    Solve { x : a x^2 + b x + c (rel) 0 } for rel in {'ge', 'le'}.

    If |a| < tol the polynomial is treated as linear (b x + c). If |b| < tol
    as well, it is treated as the constant c.

    Parameters
    ----------
    a, b, c : float
        Polynomial coefficients.
    rel : {'ge', 'le'}
        'ge' selects a x^2 + b x + c >= 0; any other value is treated as 'le'.
    tol : float
        Absolute threshold below which a, b and the discriminant count as zero.

    Returns
    -------
    list of (lo, hi)
        Closed intervals (lo / hi may be +/-inf) on which the relation holds.
        The list is empty when the relation holds nowhere.
    """
    inf = float('inf')
    want_ge = (rel == 'ge')

    # Linear or constant polynomial.
    if abs(a) < tol:
        if abs(b) < tol:
            ok = (c >= -tol) if want_ge else (c <= tol)
            return [(-inf, inf)] if ok else []
        x0 = -c/b
        # b > 0: the line is >= 0 right of x0; b < 0: left of x0.
        if b > 0:
            return [(x0, inf)] if want_ge else [(-inf, x0)]
        return [(-inf, x0)] if want_ge else [(x0, inf)]

    # Quadratic.
    D = b*b - 4*a*c
    if D < -tol:
        # No real roots: the polynomial has the sign of a everywhere.
        if a > 0:
            return [(-inf, inf)] if want_ge else []
        return [] if want_ge else [(-inf, inf)]
    if D < 0:  # treat tiny negatives as a double root
        D = 0.0
    sqrtD = np.sqrt(D)

    # Numerically stable roots. The textbook form (-b ± sqrtD) / (2a) cancels
    # catastrophically for the smaller root when b^2 >> |4ac|. Instead, compute
    # the larger-magnitude root from q and recover the other from Vieta's
    # relation r1 * r2 = c / a.
    q = -0.5*(b + np.copysign(sqrtD, b))
    if q == 0.0:
        # Only when b == 0 and D == 0: a double root at -b / (2a).
        r1 = r2 = -b/(2*a)
    else:
        r1 = q/a
        r2 = c/q
    if r1 > r2:
        r1, r2 = r2, r1

    # Upward parabola: >= 0 outside the roots, <= 0 between them.
    # Downward parabola: the reverse.
    outside = [(-inf, r1), (r2, inf)]
    between = [(r1, r2)]
    if a > 0:
        return outside if want_ge else between
    return between if want_ge else outside

def _intersect_interval_lists(A, B, tol=1e-24):
    """
    Intersect two unions of closed intervals.

    Every pair (one interval from A, one from B) is intersected. Pairs whose
    overlap is non-empty up to `tol` are kept. The survivors are sorted by
    their left end and overlapping / touching pieces (within `tol`) are merged.
    Returns a list of (lo, hi) tuples, or [] when nothing overlaps.
    """
    overlaps = [
        (max(lo_a, lo_b), min(hi_a, hi_b))
        for lo_a, hi_a in A
        for lo_b, hi_b in B
        if max(lo_a, lo_b) <= min(hi_a, hi_b) + tol
    ]
    if not overlaps:
        return []
    overlaps.sort(key=lambda iv: iv[0])

    merged = [overlaps[0]]
    for lo, hi in overlaps[1:]:
        last_lo, last_hi = merged[-1]
        if lo <= last_hi + tol:
            merged[-1] = (last_lo, max(last_hi, hi))
        else:
            merged.append((lo, hi))
    return merged

def _intersect_many(list_of_interval_lists, tol=1e-24):
    """
    Intersect a sequence of interval unions, left to right.

    Returns [] for an empty sequence. Returns the first union unchanged when
    only one is given. Stops as soon as the running intersection is empty.
    """
    if not list_of_interval_lists:
        return []
    head, *tail = list_of_interval_lists
    acc = head
    for nxt in tail:
        acc = _intersect_interval_lists(acc, nxt, tol=tol)
        if not acc:
            return acc
    return acc

def _D_intervals_for_M(M, D_positive=True, tol=1e-24, margin_tol=0.0):
    """
    Return the D-intervals on which R(t)^-1 M R(t) stays entrywise inside
    [-margin_tol, 1 + margin_tol], where D = 1 - 2t.

    With (alpha, beta, gamma, delta) = _abgd_from_M(M), each transformed
    entry equals Q(D) / (2D) for the quadratic Q below:
        a' :  beta D^2 + (alpha+delta) D + gamma
        b' : -beta D^2 + (alpha-delta) D + gamma
        c' :  beta D^2 + (alpha-delta) D - gamma
        d' : -beta D^2 + (alpha+delta) D - gamma
    For D > 0 the bounds -eps <= Q/(2D) <= 1 + eps become
        Q + 2 eps D >= 0     and     Q - 2(1 + eps) D <= 0.
    For D < 0 both inequality directions flip, since 2D < 0.

    Parameters
    ----------
    M : 2x2 array-like
    D_positive : bool
        Search D >= 0 if truthy, otherwise D <= 0.
    tol, margin_tol : float
        Numerical tolerance, and the allowed overshoot outside [0, 1].

    Returns
    -------
    list of (lo, hi)
        Feasible D-intervals on the chosen half-line (possibly empty).
    """
    alpha, beta, gamma, delta = _abgd_from_M(M)
    shift = 2.0 * margin_tol

    # (quadratic coeff, linear coeff before margin, constant) for a', b', c', d'.
    entry_coeffs = (
        ( beta, alpha+delta,  gamma),
        (-beta, alpha-delta,  gamma),
        ( beta, alpha-delta, -gamma),
        (-beta, alpha+delta, -gamma),
    )
    lower_rel, upper_rel = ('ge', 'le') if D_positive else ('le', 'ge')

    constraint_sets = []
    for quad, lin, const in entry_coeffs:
        # Lower bound first, then upper bound, for each entry.
        for lin_coef, rel in ((lin+shift, lower_rel), (lin-2-shift, upper_rel)):
            feasible = _intervals_quad(quad, lin_coef, const, rel, tol=tol)
            if not feasible:
                return []
            constraint_sets.append(feasible)

    intersection = _intersect_many(constraint_sets, tol=tol)
    if not intersection:
        return []
    half_line = (0.0, float('inf')) if D_positive else (-float('inf'), 0.0)
    return _intersect_interval_lists(intersection, [half_line], tol=tol)

def _merge_intervals(intervals, tol=1e-24):
    """
    Sort intervals by left end and fuse any that overlap or touch (within tol).

    Returns a new list of (lo, hi) tuples. An empty / falsy input gives [].
    """
    if not intervals:
        return []
    ordered = sorted(intervals, key=lambda iv: iv[0])
    first_lo, first_hi = ordered[0]
    merged = [(first_lo, first_hi)]
    for start, end in ordered[1:]:
        last_lo, last_hi = merged[-1]
        if start <= last_hi + tol:
            merged[-1] = (last_lo, max(last_hi, end))
        else:
            merged.append((start, end))
    return merged

def _map_D_to_t(D_intervals, tol=1e-24):
    """
    Convert D-intervals to t-intervals via t = (1 - D) / 2.

    The map is decreasing, so [D_lo, D_hi] becomes [(1 - D_hi)/2, (1 - D_lo)/2].
    The resulting intervals are merged before being returned. D = 0 corresponds to
    the singular point t = 1/2; each input interval is expected to lie on
    one side of it.
    """
    t_intervals = [((1 - D_hi)/2, (1 - D_lo)/2) for D_lo, D_hi in D_intervals]
    return _merge_intervals(t_intervals, tol=tol)

def allowed_t_regions_for_M(M, tol=1e-24, margin_tol=0.0):
    """
    Find the gauge parameters t for which R(t)^{-1} @ M @ R(t) has every
    entry in [-margin_tol, 1 + margin_tol].

    The D > 0 and D < 0 half-lines (D = 1 - 2t) are solved separately and
    mapped back to t, which keeps the singular point t = 1/2 out of the
    interior of any interval. The results are merged.

    Returns a list of (t_min, t_max) intervals; endpoints may be +/-inf.
    """
    t_intervals = []
    for positive in (True, False):
        D_set = _D_intervals_for_M(M, D_positive=positive, tol=tol, margin_tol=margin_tol)
        if D_set:
            t_intervals.extend(_map_D_to_t(D_set, tol=tol))
    return _merge_intervals(t_intervals, tol=tol)

# Optional helper to intersect across multiple matrices at once:
def allowed_t_regions_for_list(M_list, tol=1e-24, margin_tol=0.0):
    """
    Return the t-intervals that are allowed simultaneously for every matrix
    in M_list (see allowed_t_regions_for_M).

    Returns [] as soon as the running intersection becomes empty, and also
    when M_list itself is empty.
    """
    regions = None
    for M in M_list:
        current = allowed_t_regions_for_M(M, tol=tol, margin_tol=margin_tol)
        regions = current if regions is None else _intersect_interval_lists(regions, current, tol=tol)
        if not regions:
            return []
    return _merge_intervals(regions, tol=tol)


def _R(t):
    """Gauge matrix R(t) = [[1-t, t], [t, 1-t]]; singular at t = 1/2 (det = 1 - 2t)."""
    stay = 1.0 - t
    return np.array([[stay, t], [t, stay]], dtype=float)

def _R_inv(t):
    """
    Closed-form inverse R(t)^-1 = [[1-t, -t], [-t, 1-t]] / (1 - 2t).

    Raises ValueError when |1 - 2t| < 1e-16, i.e. t is numerically 1/2.
    """
    det_R = 1.0 - 2.0*t
    if abs(det_R) < 1e-16:
        raise ValueError("t is too close to 1/2; R(t) is nearly singular.")
    diag = 1.0 - t
    off = -t
    return (1.0/det_R) * np.array([[diag, off], [off, diag]], dtype=float)

def _gauge_transform_pair(M0, M1, t):
    """
    Apply the same similarity transform to both outcome matrices.

    Returns the tuple (R^-1 M0 R, R^-1 M1 R) with R = R(t). Raises ValueError
    (from _R_inv) when t is numerically 1/2.
    """
    left = _R_inv(t)
    right = _R(t)
    return tuple((left @ M) @ right for M in (M0, M1))

def rebase_and_anchor_instrument(M0, M1, t_regions, p00_min=0.5, tol=1e-24):
    """
    Re-anchor an instrument at the left end of each allowed gauge interval.

    For every interval [t_lo, t_hi] the pair (M0, M1) is transformed to
    (R(t_lo)^-1 M0 R(t_lo), R(t_lo)^-1 M1 R(t_lo)). Since R(t1) R(s) =
    R(t1 + s - 2 t1 s), the freedom left inside the interval is a local
    parameter s in [0, s_max] with
        s_max = (t_hi - t_lo) / (1 - 2 t_lo),
    which is negative when t_lo > 1/2.

    Parameters
    ----------
    M0, M1 : (2, 2) np.ndarray
        Trial instrument matrices.
    t_regions : list of (t_lo, t_hi)
        Allowed t-intervals, e.g. from allowed_t_regions_for_list([M0, M1]).
    p00_min : float
        Required lower bound for the anchored [M0']_{00}.
    tol : float
        Slack applied to the p00_min comparison.

    Returns
    -------
    list of dict, one per interval, with keys
        'anchor_t'        : float, the anchor t_lo
        'anchored_M0'     : (2, 2) np.ndarray
        'anchored_M1'     : (2, 2) np.ndarray
        'adjusted_region' : (0.0, s_max)
        'ok_p00'          : bool, [anchored_M0]_{00} >= p00_min - tol
        'p00'             : float, [anchored_M0]_{00}

    Raises
    ------
    ValueError
        From _R_inv, when some t_lo is numerically 1/2.
    """
    anchored_list = []
    for t_lo, t_hi in t_regions:
        anchor = float(t_lo)
        M0_anchor, M1_anchor = _gauge_transform_pair(M0, M1, anchor)
        corner = float(M0_anchor[0, 0])

        scale = 1.0 - 2.0*anchor
        if abs(scale) < 1e-16:
            # NOTE(review): effectively unreachable. _R_inv raises ValueError
            # for this same threshold before we get here.
            s_max = np.sign(t_hi - t_lo) * np.inf
        else:
            s_max = (float(t_hi) - anchor) / scale

        anchored_list.append({
            'anchor_t': anchor,
            'anchored_M0': M0_anchor,
            'anchored_M1': M1_anchor,
            'adjusted_region': (0.0, s_max),
            'ok_p00': corner >= p00_min - tol,
            'p00': corner,
        })
    return anchored_list
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
def plot_gauge_transformation_effects(
    MCM_to_transform, 
    t_width_factor: float = 1.20, 
    verbose: bool = False, 
    MCM_reference: List[Instrument2x2] = None,
    resolution: float = 1e5,
    p00_min: float = 0.5,
    margin_tol: float = 0.0,
    render_plots: bool = True,
):
    """
    Analyzes and plots the effect of a gauge transformation on an instrument.

    This function takes an instrument, applies a gauge transformation over a range
    of the gauge parameter 't', and plots how each of the 8 matrix entries evolves.
    It uses analytical methods to determine valid gauge parameter regions.

    Note: t=0.5 is excluded from the analysis as the gauge transformation matrix
    is non-invertible at that point.

    Args:
        MCM_to_transform: An Instrument2x2 object to be transformed.
        t_width_factor: Factor to scale the plotting range around valid regions. Default 1.20.
                        1.00 means plot exactly the valid regions, >1.00 adds padding.
        verbose: If True, print detailed information about valid intervals. Default False.
        MCM_reference: List of reference Instrument2x2 objects to compare RMSE against. Default [].
        resolution: Number of points to sample per unit t-range.
        p00_min: Minimum value for M0[0,0] when identifying focus regions. Default 0.5.
        margin_tol: Margin tolerance for allowed regions. Default 0.0.
        render_plots: If True, generate and show plots. If False, skip plotting. Default True.

    Returns:
        Tuple containing:
            - Instrument2x2 object constructed from center points of valid entry ranges, or None if no valid ranges.
            - List of valid t-regions for the center_instrument (relative to itself).
    """
    if MCM_reference is None:
        MCM_reference = []

    # Helper for formatting value column
    def fmt_val_err(min_v, max_v):
        c = (min_v + max_v) / 2
        h = (max_v - min_v) / 2

        # Determine exponent from the larger of abs(c) or abs(h) to avoid tiny numbers if c is near zero
        ref = abs(c) if abs(c) > 0 else abs(h)
        if ref == 0:
            return "(0.000 ± 0.000)e+0"

        exponent = int(np.floor(np.log10(ref)))
        scale = 10.0 ** (-exponent)

        c_s = c * scale
        h_s = h * scale

        return f"({c_s:.3f} ± {h_s:.3f})e{exponent:+d}"

    if verbose:
        print("Original Instrument to be transformed:")
        MCM_to_transform.reveal()

    # Use analytical method to find valid t-regions
    valid_t_regions = allowed_t_regions_for_list(
        [MCM_to_transform.M0, MCM_to_transform.M1], 
        tol=1e-24,
        margin_tol=margin_tol
    )

    if not valid_t_regions:
        raise ValueError(f"No valid gauge parameter regions found for this instrument (margin_tol={margin_tol}).")

    if verbose:
        print(f"\nAnalytically determined valid t-regions (total: {len(valid_t_regions)}) with margin {margin_tol}:")
        for i, (t_lo, t_hi) in enumerate(valid_t_regions, 1):
            print(f"  Region {i}: t ∈ [{t_lo:.6f}, {t_hi:.6f}] (width: {t_hi - t_lo:.6f})")

    # Determine plotting range based on valid regions and width factor
    all_t_mins = [r[0] for r in valid_t_regions if not np.isinf(r[0])]
    all_t_maxs = [r[1] for r in valid_t_regions if not np.isinf(r[1])]

    if all_t_mins and all_t_maxs:
        t_plot_min = min(all_t_mins)
        t_plot_max = max(all_t_maxs)
        t_center = (t_plot_min + t_plot_max) / 2
        t_half_span = (t_plot_max - t_plot_min) / 2

        # Apply width factor
        t_plot_min = t_center - t_half_span * t_width_factor
        t_plot_max = t_center + t_half_span * t_width_factor
    else:
        # Fallback if regions are unbounded
        t_plot_min = -0.5
        t_plot_max = 1.5

    # Ensure we don't include t=0.5 in our sampling
    if abs(t_plot_min - 0.5) < 1e-6:
        t_plot_min = 0.5 - 1e-6
    if abs(t_plot_max - 0.5) < 1e-6:
        t_plot_max = 0.5 + 1e-6

    # Generate t values for plotting, excluding t=0.5
    n_points = int(resolution * (t_plot_max - t_plot_min))
    if t_plot_min < 0.5 < t_plot_max:
        t_values_left = np.linspace(t_plot_min, 0.5 - 1e-6, n_points // 2)
        t_values_right = np.linspace(0.5 + 1e-6, t_plot_max, n_points // 2)
        t_values = np.concatenate([t_values_left, t_values_right])
    elif t_plot_max < 0.5:
        t_values = np.linspace(t_plot_min, t_plot_max, n_points)
    else:
        t_values = np.linspace(t_plot_min, t_plot_max, n_points)

    # Store the 8 entries of the transformed instrument for each value of t
    transformed_entries = []

    for t_val in t_values:
        M0_prime, M1_prime = gauge_transform_instrument_numerically(
            MCM_to_transform.M0, MCM_to_transform.M1, t_val
        )
        entries = np.concatenate((M0_prime.flatten(), M1_prime.flatten()))
        transformed_entries.append(entries)

    transformed_entries = np.array(transformed_entries)

    # --- Generate dedicated statistics samples from the first valid region ---
    # This ensures table values are independent of plotting width factor.
    stats_entries = None
    if len(valid_t_regions) > 0:
        t_lo_stat, t_hi_stat = valid_t_regions[0]

        # Handle potential infinite bounds for stats (clip to reasonable range if needed)
        t_start_stat = t_lo_stat if not np.isinf(t_lo_stat) else -5.0
        t_end_stat = t_hi_stat if not np.isinf(t_hi_stat) else 5.0

        width_stat = t_end_stat - t_start_stat
        # Ensure sufficient points for statistics
        n_stats = max(200, int(resolution * width_stat))
        t_stats = np.linspace(t_start_stat, t_end_stat, n_stats)

        stats_entries_list = []
        for t_val in t_stats:
            M0_p, M1_p = gauge_transform_instrument_numerically(
                MCM_to_transform.M0, MCM_to_transform.M1, t_val
            )
            stats_entries_list.append(np.concatenate((M0_p.flatten(), M1_p.flatten())))
        stats_entries = np.array(stats_entries_list)
    # -----------------------------------------------------------------------

    # Use rebase_and_anchor_instrument to find focus regions with ok_p00
    rebased_results = rebase_and_anchor_instrument(
        MCM_to_transform.M0, 
        MCM_to_transform.M1, 
        valid_t_regions, 
        p00_min=p00_min
    )

    # Find first region with ok_p00 for focus plot
    focus_region_info = None
    for res in rebased_results:
        if res['ok_p00']:
            focus_region_info = res
            break

    # Determine number of subplots
    has_focus_plot = focus_region_info is not None or len(MCM_reference) > 0

    if render_plots:
        if has_focus_plot:
            fig = plt.figure(figsize=(14, 28))
            gs = fig.add_gridspec(4, 1, height_ratios=[1.2, 1.2, 1.2, 2.4], hspace=0.15)

            ax1 = fig.add_subplot(gs[0])
            ax2 = fig.add_subplot(gs[1], sharex=ax1)
            axes = [ax1, ax2]

            if len(MCM_reference) > 0:
                ax3 = fig.add_subplot(gs[2], sharex=ax1)
                axes.append(ax3)

            ax4 = fig.add_subplot(gs[3])
            axes.append(ax4)
        else:
            fig, axes = plt.subplots(2, 1, figsize=(14, 14), sharex=True)
            if isinstance(axes, plt.Axes):
                axes = [axes]

        ax1 = axes[0]
        ax2 = axes[1]

        # ===== First subplot: Individual matrix entries =====
        labels = [
            r"$M^0$[0,0] = $p_0^{(0,0)}$", 
            r"$M^0$[0,1] = $p_1^{(0,0)}$", 
            r"$M^0$[1,0] = $p_0^{(0,1)}$",
            r"$M^0$[1,1] = $p_1^{(0,1)}$",
            r"$M^1$[0,0] = $p_0^{(1,0)}$",
            r"$M^1$[0,1] = $p_1^{(1,0)}$",
            r"$M^1$[1,0] = $p_0^{(1,1)}$",
            r"$M^1$[1,1] = $p_1^{(1,1)}$"
        ]

        for i in range(8):
            ax1.plot(t_values, transformed_entries[:, i], linewidth=1.0, label=labels[i])

        ax1.axvline(x=0.5, color='red', linestyle=':', linewidth=1.0, label='Singularity (t=0.5)')
        ax1.axhline(y=0, color='k', linestyle='--', linewidth=1.0, label='Prob Boundary (0,1)')
        ax1.axhline(y=1, color='k', linestyle='--', linewidth=1.0)

        if margin_tol > 0:
            ax1.axhline(y=-margin_tol, color='gray', linestyle=':', linewidth=1.0, label=f'Margin (±{margin_tol})')
            ax1.axhline(y=1+margin_tol, color='gray', linestyle=':', linewidth=1.0)

        # Highlight valid regions using analytical results
        for i, (t_lo, t_hi) in enumerate(valid_t_regions):
            # Clip to plotting range
            t_lo_plot = max(t_lo, t_plot_min) if not np.isinf(t_lo) else t_plot_min
            t_hi_plot = min(t_hi, t_plot_max) if not np.isinf(t_hi) else t_plot_max

            if t_lo_plot < t_hi_plot:
                label = 'Valid Gauge Region (t)' if i == 0 else ""
                ax1.axvspan(t_lo_plot, t_hi_plot, color='green', alpha=0.2, label=label)

        ax1.set_ylabel("Value of Instrument Matrix Entry")
        ax1.set_title(f"Evolution of Instrument Entries (Valid Interval: [{-margin_tol}, {1+margin_tol}])")
        ax1.legend(loc='center left', bbox_to_anchor=(1, 0.5))
        ax1.grid(True, linestyle=':', alpha=0.6)
        ax1.set_ylim(max(np.min(transformed_entries), -0.1 - margin_tol), min(np.max(transformed_entries), 1.1 + margin_tol))

    # ===== Compute valid ranges for matrix entries =====
    entry_labels = [
        "M^0[0,0] = p_0^(0,0)",
        "M^0[0,1] = p_1^(0,0)",
        "M^0[1,0] = p_0^(0,1)",
        "M^0[1,1] = p_1^(0,1)",
        "M^1[0,0] = p_0^(1,0)",
        "M^1[0,1] = p_1^(1,0)",
        "M^1[1,0] = p_0^(1,1)",
        "M^1[1,1] = p_1^(1,1)"
    ]

    entry_ranges_data = []
    center_values = []  # Store center values for constructing return instrument

    for entry_idx, entry_label in enumerate(entry_labels):
        if stats_entries is not None:
            block_values = stats_entries[:, entry_idx]
            min_val = np.min(block_values)
            max_val = np.max(block_values)
            width = max_val - min_val
            center = (min_val + max_val) / 2

            center_values.append(center)

            row_data = {
                'Entry': entry_label,
                'Min': f"{min_val:.8f}",
                'Max': f"{max_val:.8f}",
                'Width': f"{width:.5e}",
                'Value': fmt_val_err(min_val, max_val)
            }

            # Add comparison columns for each reference instrument
            for ref_idx, ref_inst in enumerate(MCM_reference):
                ref_entries = np.concatenate((ref_inst.M0.flatten(), ref_inst.M1.flatten()))
                ref_val = ref_entries[entry_idx]

                # Calculate absolute discrepancy from center
                abs_disc = ref_val - center

                # Check if reference value is within range
                if min_val <= ref_val <= max_val:
                    row_data[f'Ref{ref_idx+1}'] = f"IN ({abs_disc:+.5e})"
                else:
                    direction = "above" if ref_val > max_val else "below"
                    row_data[f'Ref{ref_idx+1}'] = f"OUT {direction} ({abs_disc:+.5e})"

            entry_ranges_data.append(row_data)

    # Construct Instrument2x2 from center values
    center_instrument = None
    if len(center_values) == 8:
        M0_center = np.array([[center_values[0], center_values[1]],
                              [center_values[2], center_values[3]]], dtype=float)
        M1_center = np.array([[center_values[4], center_values[5]],
                              [center_values[6], center_values[7]]], dtype=float)
        center_instrument = Instrument2x2(M0=M0_center, M1=M1_center)

    # ===== Second subplot: Derived quantities =====
    prep0_meas1 = transformed_entries[:, 4] + transformed_entries[:, 6]
    prep1_meas0 = transformed_entries[:, 1] + transformed_entries[:, 3]
    prep0_excite = transformed_entries[:, 2] + transformed_entries[:, 6]
    prep1_decay = transformed_entries[:, 1] + transformed_entries[:, 5]

    derived_quantities = {
        "prep 0 meas 1": prep0_meas1,
        "prep 1 meas 0": prep1_meas0,
        "prep 0 excite to 1": prep0_excite,
        "prep 1 decay to 0": prep1_decay
    }

    # Calculate derived quantities for stats
    stats_derived = {}
    if stats_entries is not None:
        s_prep0_meas1 = stats_entries[:, 4] + stats_entries[:, 6]
        s_prep1_meas0 = stats_entries[:, 1] + stats_entries[:, 3]
        s_prep0_excite = stats_entries[:, 2] + stats_entries[:, 6]
        s_prep1_decay = stats_entries[:, 1] + stats_entries[:, 5]

        stats_derived = {
            "prep 0 meas 1": s_prep0_meas1,
            "prep 1 meas 0": s_prep1_meas0,
            "prep 0 excite to 1": s_prep0_excite,
            "prep 1 decay to 0": s_prep1_decay
        }
    # print(derived_quantities)
    # print(stats_derived)

    quantity_valid_ranges = {}
    quantity_ranges_data = []
    for quantity_name, quantity_values in derived_quantities.items():
        if quantity_name in stats_derived:
            block_values = stats_derived[quantity_name]

            min_val = np.min(block_values)
            max_val = np.max(block_values)
            width = max_val - min_val
            center = (min_val + max_val) / 2
            quantity_valid_ranges[quantity_name] = (min_val, max_val)

            row_data = {
                'Quantity': quantity_name,
                'Min': f"{min_val:.8f}",
                'Max': f"{max_val:.8f}",
                'Width': f"{width:.5e}",
                'Value': fmt_val_err(min_val, max_val)
            }

            # Add comparison columns for each reference instrument
            for ref_idx, ref_inst in enumerate(MCM_reference):
                # Compute reference quantity value
                if quantity_name == "prep 0 meas 1":
                    ref_val = ref_inst.M1[0, 0] + ref_inst.M1[1, 0]
                elif quantity_name == "prep 1 meas 0":
                    ref_val = ref_inst.M0[0, 1] + ref_inst.M0[1, 1]
                elif quantity_name == "prep 0 excite to 1":
                    ref_val = ref_inst.M0[1, 0] + ref_inst.M1[1, 0]
                elif quantity_name == "prep 1 decay to 0":
                    ref_val = ref_inst.M0[0, 1] + ref_inst.M1[0, 1]
                else:
                    ref_val = 0.0

                # Calculate absolute discrepancy from center
                abs_disc = ref_val - center

                # Check if reference value is within range
                if min_val <= ref_val <= max_val:
                    row_data[f'Ref{ref_idx+1}'] = f"IN ({abs_disc:+.5e})"
                else:
                    direction = "above" if ref_val > max_val else "below"
                    row_data[f'Ref{ref_idx+1}'] = f"OUT {direction} ({abs_disc:+.5e})"

            quantity_ranges_data.append(row_data)

    if render_plots:
        # Plot the derived quantities
        labels_with_ranges = [
            (r"prep 0 meas 1: $p_0^{(1,0)} + p_0^{(1,1)}$", "prep 0 meas 1"),
            (r"prep 1 meas 0: $p_1^{(0,0)} + p_1^{(0,1)}$", "prep 1 meas 0"),
            (r"prep 0 excite to 1: $p_0^{(0,1)} + p_0^{(1,1)}$", "prep 0 excite to 1"),
            (r"prep 1 decay to 0: $p_1^{(0,0)} + p_1^{(1,0)}$", "prep 1 decay to 0")
        ]

        quantity_list = list(derived_quantities.items())
        for idx, ((base_label, quantity_key), (quantity_name, quantity_values)) in enumerate(zip(labels_with_ranges, quantity_list)):
            if quantity_key in quantity_valid_ranges:
                min_val, max_val = quantity_valid_ranges[quantity_key]
                label = f"{base_label}\n∈ [{min_val:.6f}, {max_val:.6f}]"
            else:
                label = base_label
            ax2.plot(t_values, quantity_values, label=label, linewidth=1)

        ax2.axvline(x=0.5, color='red', linestyle=':', linewidth=1.0, label='Singularity (t=0.5)')
        ax2.axhline(y=0, color='k', linestyle='--', linewidth=1.0, label='Prob Boundary (0,1)')
        ax2.axhline(y=1, color='k', linestyle='--', linewidth=1.0)

        if margin_tol > 0:
            ax2.axhline(y=-margin_tol, color='gray', linestyle=':', linewidth=1.0, label=f'Margin (±{margin_tol})')
            ax2.axhline(y=1+margin_tol, color='gray', linestyle=':', linewidth=1.0)

        for i, (t_lo, t_hi) in enumerate(valid_t_regions):
            t_lo_plot = max(t_lo, t_plot_min) if not np.isinf(t_lo) else t_plot_min
            t_hi_plot = min(t_hi, t_plot_max) if not np.isinf(t_hi) else t_plot_max

            if t_lo_plot < t_hi_plot:
                label = 'Valid Gauge Region (t)' if i == 0 else ""
                ax2.axvspan(t_lo_plot, t_hi_plot, color='green', alpha=0.2, label=label)

        ax2.set_xlabel(r"Gauge Parameter $(t)$")
        ax2.set_ylabel("Derived Quantity Value")
        ax2.set_title(f"Derived Quantities (Valid Interval: [{-margin_tol}, {1+margin_tol}])")

        legend = ax2.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=9, 
                            labelspacing=1.2, handlelength=2)
        ax2.grid(True, linestyle=':', alpha=0.6)
        ax2.set_ylim(-0.1 - margin_tol, 1.1 + margin_tol)

    # ===== Display DataFrames for valid intervals =====
    print("="*80)
    print("Valid Intervals for Derived Quantities (Readout & Back-action Errors)")
    print("="*80)
    if quantity_ranges_data:
        df_quantities = pd.DataFrame(quantity_ranges_data)
        display(df_quantities)
    else:
        print("No valid intervals found for derived quantities.")

    print("\n" + "="*80)
    print("Valid Intervals for Matrix Entries")
    print("="*80)
    if entry_ranges_data:
        df_entries = pd.DataFrame(entry_ranges_data)
        display(df_entries)
    else:
        print("No valid intervals found for matrix entries.")

    if len(MCM_reference) > 0:
        print("\nNote: Reference comparison format:")
        print("  'IN (±X.XXXe±YY)' - value is within range, absolute discrepancy from center")
        print("  'OUT above/below (±X.XXXe±YY)' - value is outside range, absolute discrepancy from center")
    print("="*80 + "\n")

    # ===== Third subplot: RMSE to reference instruments =====
    best_match_instruments = []

    if len(MCM_reference) > 0:
        if render_plots:
            ax3 = axes[2]

        for ref_idx, ref_inst in enumerate(MCM_reference):
            ref_entries = np.concatenate((ref_inst.M0.flatten(), ref_inst.M1.flatten()))
            rmse_values = np.sqrt(np.mean((transformed_entries - ref_entries)**2, axis=1))

            min_rmse_idx = np.argmin(rmse_values)
            min_rmse = rmse_values[min_rmse_idx]
            t_min_rmse = t_values[min_rmse_idx]

            if render_plots:
                label = f"Ref {ref_idx+1}: min RMSE={min_rmse:.6e} at t={t_min_rmse:.4f}"
                ax3.plot(t_values, rmse_values, label=label, linewidth=1.5)

                ax3.plot(t_min_rmse, min_rmse, 'o', markersize=8)

            M0_best, M1_best = gauge_transform_instrument_numerically(
                MCM_to_transform.M0, MCM_to_transform.M1, t_min_rmse
            )
            best_match_instruments.append(Instrument2x2(M0=M0_best, M1=M1_best))

        if render_plots:
            ax3.axvline(x=0.5, color='red', linestyle=':', linewidth=1.0, label='Singularity (t=0.5)')

            for i, (t_lo, t_hi) in enumerate(valid_t_regions):
                t_lo_plot = max(t_lo, t_plot_min) if not np.isinf(t_lo) else t_plot_min
                t_hi_plot = min(t_hi, t_plot_max) if not np.isinf(t_hi) else t_plot_max

                if t_lo_plot < t_hi_plot:
                    label = 'Valid Gauge Region (t)' if i == 0 else ""
                    ax3.axvspan(t_lo_plot, t_hi_plot, color='green', alpha=0.2, label=label)

            ax3.set_ylabel("RMSE to Reference")
            ax3.set_title("RMSE Between Gauge-Transformed and Reference Instruments")
            ax3.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=9)
            ax3.grid(True, linestyle=':', alpha=0.6)
            ax3.set_yscale('log')

    # ===== Fourth subplot: Focused view =====
    if has_focus_plot and focus_region_info is not None:
        if render_plots:
            ax4 = axes[-1]

        # Use the adjusted region from rebase_and_anchor_instrument
        anchor_t = focus_region_info['anchor_t']
        s_min, s_max = focus_region_info['adjusted_region']

        # Convert back to absolute t values
        # t3 = t1 + s - 2*t1*s => given t1=anchor_t, s in [s_min, s_max]
        t_focus_exact_min = anchor_t + s_min - 2*anchor_t*s_min
        t_focus_exact_max = anchor_t + s_max - 2*anchor_t*s_max

        if t_focus_exact_min > t_focus_exact_max:
            t_focus_exact_min, t_focus_exact_max = t_focus_exact_max, t_focus_exact_min

        # Apply t_width_factor to the focus region
        t_focus_center = (t_focus_exact_min + t_focus_exact_max) / 2
        t_focus_half_span = (t_focus_exact_max - t_focus_exact_min) / 2

        t_focus_min = t_focus_center - t_focus_half_span * t_width_factor
        t_focus_max = t_focus_center + t_focus_half_span * t_width_factor

        if verbose:
            print(f"\nFocus region: anchor_t={anchor_t:.6f}, "
                  f"local s∈[{s_min:.6f}, {s_max:.6f}]")
            print(f"  Exact valid t∈[{t_focus_exact_min:.6f}, {t_focus_exact_max:.6f}]")
            print(f"  Plotted t∈[{t_focus_min:.6f}, {t_focus_max:.6f}] (with factor {t_width_factor:.2f})")

        if render_plots:
            focus_mask = (t_values >= t_focus_min) & (t_values <= t_focus_max)
            t_focus = t_values[focus_mask]
            entries_focus = transformed_entries[focus_mask]

            for i in range(8):
                ax4.plot(t_focus, entries_focus[:, i], linewidth=0.8, alpha=0.5, label=labels[i])

            for quantity_name, quantity_values in derived_quantities.items():
                ax4.plot(t_focus, quantity_values[focus_mask], linewidth=1.5, label=quantity_name)

            if len(MCM_reference) > 0:
                for ref_idx, ref_inst in enumerate(MCM_reference):
                    ref_entries = np.concatenate((ref_inst.M0.flatten(), ref_inst.M1.flatten()))
                    rmse_focus = np.sqrt(np.mean((entries_focus - ref_entries)**2, axis=1))
                    rmse_normalized = rmse_focus / (rmse_focus.max() + 1e-18)
                    ax4.plot(t_focus, rmse_normalized, linewidth=2, linestyle='--', 
                            label=f"Ref {ref_idx+1} RMSE (normalized)\n original max={rmse_focus.max():.2e}")

            if 0.5 >= t_focus_min and 0.5 <= t_focus_max:
                ax4.axvline(x=0.5, color='red', linestyle=':', linewidth=1.0, label='Singularity (t=0.5)')

            ax4.axhline(y=0, color='k', linestyle='--', linewidth=1.0, alpha=0.5)
            ax4.axhline(y=1, color='k', linestyle='--', linewidth=1.0, alpha=0.5)

            # Plot valid regions - only those overlapping with focus window
            for i, (t_lo, t_hi) in enumerate(valid_t_regions):
                if t_hi >= t_focus_min and t_lo <= t_focus_max:
                    plot_start = max(t_lo, t_focus_min)
                    plot_end = min(t_hi, t_focus_max)
                    label = 'Valid Gauge Region (t)' if i == 0 else ""
                    ax4.axvspan(float(plot_start), float(plot_end), color='green', alpha=0.2, label=label)

            ax4.set_xlabel(r"Gauge Parameter $(t)$ [Focused View]")
            ax4.set_ylabel("Quantity Values")
            ax4.set_title(f"Focused View: t ∈ [{t_focus_min:.4f}, {t_focus_max:.4f}]\n"
                        f"Valid region: [{t_focus_exact_min:.4f}, {t_focus_exact_max:.4f}] "
                        f"(width factor: {t_width_factor:.2f})")
            ax4.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=8, ncol=2)
            ax4.grid(True, linestyle=':', alpha=0.6)
            ax4.set_ylim(-0.1, 1.1)
            ax4.set_xlim(t_focus_min, t_focus_max)
    else:
        if render_plots:
            axes[-1].set_xlabel(r"Gauge Parameter $(t)$")

    if render_plots:
        plt.tight_layout()
        plt.show()

    center_t_regions = []
    if center_instrument is not None:
        center_t_regions = allowed_t_regions_for_list(
            [center_instrument.M0, center_instrument.M1], 
            tol=1e-24,
            margin_tol=margin_tol
        )
    # filter the center_t_regions such that the regions that has the smallest sum of absolute values for 2 end points is chosen
    if center_t_regions:
        center_t_regions = sorted(center_t_regions, key=lambda x: abs(x[0]) + abs(x[1]))
        center_t_regions = center_t_regions[0]

    return center_instrument, center_t_regions
def refine_bounds_with_sp_error(
    reconstructed_MCM: Instrument2x2 | Any,
    t_bound: Tuple[float, float] | Any,
    prob_meas_0_val: float | Any,
    margin_tol: float = 1e-4,
    verbose: bool = False
) -> Tuple[Tuple[float, float], Tuple[float, float]]:
    """
    Refine the gauge-parameter interval ``t_bound`` using the state-preparation-error constraint.

    Requiring the state-preparation error ``epsilon`` to be physical
    (``-margin_tol <= epsilon <= 1 + margin_tol``) shrinks the admissible gauge range.

    Model:
        P(outcome=1) = P(1|0_ideal)(t) * (1 - epsilon) + P(1|1_ideal)(t) * epsilon

    with ``P(1|0_ideal)(t) = c10 + t*delta_c`` and ``P(1|1_ideal)(t) = c11 - t*delta_c``.
    Here ``c10``/``c11`` are the column sums of ``M1`` at ``t = 0`` and
    ``delta_c = c11 - c10``. Solving for epsilon gives

        epsilon(t) = (P_obs - c10 - t*delta_c) / (delta_c * (1 - 2t)).

    This is a linear-fractional map. It is monotonic on each side of its pole ``t = 0.5``
    and tends to ``1/2`` as ``t -> +/-inf``. Its inverse ``t(epsilon)`` has the same form,
    with ``t`` and ``epsilon`` swapped.

    Args:
        reconstructed_MCM: Instrument at the gauge origin (t=0); only ``.M1`` is read.
        t_bound: Current valid interval (t_min, t_max), e.g. from MCM positivity.
            Endpoints may be given in either order and may be +/-inf.
        prob_meas_0_val: Empirical probability of measuring '0'.
        margin_tol: Tolerance for the physical bounds of epsilon (allows -tol to 1+tol).
        verbose: If True, prints detailed intermediate values.

    Returns:
        Tuple containing:
            - refined_t_bound: (new_t_min, new_t_max)
            - epsilon_bound: (eps_min, eps_max) over the refined t interval.
        Special cases:
            - |delta_c| < 1e-9: epsilon is not identifiable; returns (t_bound, (0.0, 0.0)).
            - t_bound strictly straddles t = 0.5: no refinement; returns t_bound and
              epsilon at its endpoints.
            - No admissible t: prints an error and returns (t_bound, (nan, nan)).
    """

    # 1. Baseline (t = 0) readout probabilities for outcome 1: column sums of M1.
    M1 = reconstructed_MCM.M1
    c10 = M1[0, 0] + M1[1, 0]  # P(1|0) baseline
    c11 = M1[0, 1] + M1[1, 1]  # P(1|1) baseline

    # We use P(outcome=1) for the calculation
    P_obs = 1.0 - prob_meas_0_val

    # Delta c = P(1|1) - P(1|0). Usually close to 1 for a good measurement.
    delta_c = c11 - c10

    if abs(delta_c) < 1e-9:
        if verbose:
            print("Warning: Instrument has negligible distinguishing power (delta_c ~ 0). Cannot bound epsilon.")
        return t_bound, (0.0, 0.0)  # Return original bounds

    # Numerator of epsilon(t) evaluated at the pole t = 0.5. Its sign fixes the sign of the
    # one-sided limits there. If it vanishes, epsilon(t) == 1/2 for every t.
    num_at_half = P_obs - c10 - 0.5 * delta_c

    # 2. Define the mapping functions
    def get_epsilon(t, side=0):
        """epsilon(t); ``side`` = -1 / +1 selects the limit t -> 0.5 from below / above."""
        if np.isinf(t):
            # Leading terms: (-t*delta_c) / (-2t*delta_c) -> 1/2.
            return 0.5
        denom = delta_c * (1.0 - 2.0*t)
        if abs(denom) < 1e-15:
            if abs(num_at_half) < 1e-15:
                return 0.5  # Removable singularity: epsilon(t) is identically 1/2.
            if side == 0:
                return np.inf  # Singularity at t=0.5, approach direction unknown
            # Below the pole the denominator has the sign of delta_c; above it, the opposite.
            return -side * np.sign(num_at_half * delta_c) * np.inf
        num = P_obs - c10 - t * delta_c
        return num / denom

    # Inverse mapping: t(epsilon) (same form with t and epsilon swapped)
    def get_t(eps):
        denom = delta_c * (1.0 - 2.0*eps)
        if abs(denom) < 1e-15:
            return np.inf
        num = P_obs - c10 - eps * delta_c
        return num / denom

    # 3. Analyze bounds (accept endpoints in either order)
    t_min_in, t_max_in = sorted((float(t_bound[0]), float(t_bound[1])))

    # Check for singularity crossing (unlikely for valid physical regions)
    if t_min_in < 0.5 < t_max_in:
        if verbose:
            print("Warning: Input t_bound crosses singularity t=0.5. Skipping refinement.")
        return t_bound, (get_epsilon(t_min_in), get_epsilon(t_max_in))

    # Epsilon at the current t boundaries. The interval lies on one side of the pole, so a
    # lower endpoint equal to 0.5 is approached from above and an upper one from below.
    eps_at_min = get_epsilon(t_min_in, side=+1)
    eps_at_max = get_epsilon(t_max_in, side=-1)

    # Determine current epsilon range (order matters for min/max)
    curr_eps_min = min(eps_at_min, eps_at_max)
    curr_eps_max = max(eps_at_min, eps_at_max)

    # Define target physical range for epsilon
    target_eps_min = 0.0 - margin_tol
    target_eps_max = 1.0 + margin_tol

    if verbose:
        print("--- Refine Bounds Debug ---")
        print(f"P(outcome=1): {P_obs:.6f}")
        print(f"Baseline P(1|0): {c10:.6f}, P(1|1): {c11:.6f}")
        print(f"Input t_bound: [{t_min_in:.6f}, {t_max_in:.6f}]")
        print(f"Epsilon range on input t: [{curr_eps_min:.6f}, {curr_eps_max:.6f}]")
        print(f"Target epsilon range: [{target_eps_min:.6f}, {target_eps_max:.6f}]")

    # 4. Intersect ranges: clip the current epsilon range to the target physical range
    refined_eps_min = max(curr_eps_min, target_eps_min)
    refined_eps_max = min(curr_eps_max, target_eps_max)

    # Check for empty intersection
    if refined_eps_min > refined_eps_max:
        print("Error: No valid solution found. State preparation error constraints are incompatible with MCM constraints.")
        return t_bound, (np.nan, np.nan)

    # 5. Map back to t. epsilon(t) is monotonic on this branch, so each t endpoint either:
    #   - keeps its original value, because its epsilon already lies in the target range; or
    #   - moves to t(target bound), the point where epsilon reaches the clipping bound.
    # Keeping unclipped endpoints exactly avoids inverting at t = +/-inf or t = 0.5,
    # where get_t is singular.
    if eps_at_max >= eps_at_min:  # epsilon non-decreasing in t on this branch
        new_t_min = t_min_in if eps_at_min >= target_eps_min else get_t(target_eps_min)
        new_t_max = t_max_in if eps_at_max <= target_eps_max else get_t(target_eps_max)
    else:  # epsilon decreasing in t on this branch
        new_t_min = t_min_in if eps_at_min <= target_eps_max else get_t(target_eps_max)
        new_t_max = t_max_in if eps_at_max >= target_eps_min else get_t(target_eps_min)

    # Numerical safety clip to original bounds
    new_t_min = max(new_t_min, t_min_in)
    new_t_max = min(new_t_max, t_max_in)

    if new_t_min > new_t_max:
        # Only reachable through round-off on a (near-)degenerate interval.
        print("Error: No valid solution found. State preparation error constraints are incompatible with MCM constraints.")
        return t_bound, (np.nan, np.nan)

    # Re-evaluate epsilon at exact new t bounds for consistency
    eps_new_at_min = get_epsilon(new_t_min, side=+1)
    eps_new_at_max = get_epsilon(new_t_max, side=-1)
    final_eps_min = min(eps_new_at_min, eps_new_at_max)
    final_eps_max = max(eps_new_at_min, eps_new_at_max)

    if verbose:
        print(f"Refined t_bound: [{new_t_min:.6f}, {new_t_max:.6f}]")
        print(f"Refined epsilon: [{final_eps_min:.6f}, {final_eps_max:.6f}]")
        print("---------------------------")

    return (new_t_min, new_t_max), (final_eps_min, final_eps_max)
# Draw a synthetic ground-truth instrument with the arguments below and print its M0/M1 entries.
GT = random_instrument(corr_strength=0.1, seed=59, fidelity=0.8)
GT.reveal()
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
MCM Instrument:

M0 matrix (outcome 0):
  [[p_0^(0,0), p_1^(0,0)]] = [[0.63686009, 0.14318407]]
  [[p_0^(0,1), p_1^(0,1)]]   [[0.26631059, 0.15142864]]

M1 matrix (outcome 1):
  [[p_0^(1,0), p_1^(1,0)]] = [[0.02935295, 0.03553478]]
  [[p_0^(1,1), p_1^(1,1)]]   [[0.06747637, 0.66985250]]





0
# NOTE(review): `ibm_pittsburgh_mcm` is defined by a later cell of this notebook
# (`ibm_pittsburgh_mcm = [...]`); this cell only works if that cell was executed first.
ibm_pittsburgh_mcm[3].reveal()
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
MCM Instrument:

M0 matrix (outcome 0):
  [[p_0^(0,0), p_1^(0,0)]] = [[0.99872032, 0.00132291]]
  [[p_0^(0,1), p_1^(0,1)]]   [[0.00001741, 0.00018317]]

M1 matrix (outcome 1):
  [[p_0^(1,0), p_1^(1,0)]] = [[0.00003101, 0.00078360]]
  [[p_0^(1,1), p_1^(1,1)]]   [[0.00123126, 0.99771031]]





0
# Index into `sols` (the candidate reconstructions computed below) to analyse.
chosen_sol_index = 0

GT = ibm_pittsburgh_mcm[3]
GT.reveal()

# Exact probabilities of all outcome strings up to length 3 from v0 = (0.5, 0.5).
# NOTE(review): prob_dict_GT is not used in this cell; presumably consumed by a later one.
prob_dict_GT = calculate_exact_all_string_probabilities_from_v0_and_instrument(
    inst = GT,
    v0=np.array([0.5, 0.5]),
    max_len=3
)

# Reconstruct candidate instruments from GT's exact invariants (traces, determinants, S0),
# with the gauge fixed by gauge_p00=1.0.
M0t, M1t = GT.M0, GT.M1
invrnts = summarize_instrument(M0t, M1t)
sols = reconstruct_instrument_from_invariants_mixed_det(
    trM0=invrnts["trM0"], detM0=invrnts["detM0"],
    trM1=invrnts["trM1"], detM1=invrnts["detM1"],
    S0=invrnts["S0"], gauge_p00=1.0
)

# Renders the literal '' in the notebook output (presumably used as a visual spacer).
display('')

# Scan the gauge family of the chosen solution against GT. The first return value is the
# instrument built from the midpoints of each entry's valid range (or None); the second is
# the selected valid t-region, discarded here.
reconstruct_instrument, _ = plot_gauge_transformation_effects(
    Instrument2x2(M0=sols[chosen_sol_index][0], M1=sols[chosen_sol_index][1]), 
    t_width_factor=1.1, 
    verbose=False, 
    MCM_reference=[GT]
    )

if reconstruct_instrument is not None:
    print("Reconstructed Instrument from center points of valid entry ranges:")
    reconstruct_instrument.reveal()

print("\nGround Truth Instrument:")
GT.reveal()
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
MCM Instrument:

M0 matrix (outcome 0):
  [[p_0^(0,0), p_1^(0,0)]] = [[0.99872032, 0.00132291]]
  [[p_0^(0,1), p_1^(0,1)]]   [[0.00001741, 0.00018317]]

M1 matrix (outcome 1):
  [[p_0^(1,0), p_1^(1,0)]] = [[0.00003101, 0.00078360]]
  [[p_0^(1,1), p_1^(1,1)]]   [[0.00123126, 0.99771031]]



''


================================================================================
Valid Intervals for Derived Quantities (Readout & Back-action Errors)
================================================================================
Quantity Min Max Width Value Ref1
0 prep 0 meas 1 0.00003005 0.00127966 1.24961e-03 (6.549 ± 6.248)e-4 IN (+6.07420e-04)
1 prep 1 meas 0 0.00027386 0.00152347 1.24961e-03 (8.987 ± 6.248)e-4 IN (+6.07420e-04)
2 prep 0 excite to 1 0.00124865 0.00124973 1.07234e-06 (1.249 ± 0.001)e-3 IN (-5.21213e-07)
3 prep 1 decay to 0 0.00210546 0.00210653 1.07234e-06 (2.106 ± 0.001)e-3 IN (+5.21213e-07)
1
2
3
================================================================================
Valid Intervals for Matrix Entries
================================================================================
Entry Min Max Width Value Ref1
0 M^0[0,0] = p_0^(0,0) 0.99871989 0.99872034 4.49770e-07 (9.987 ± 0.000)e-1 IN (+2.01820e-07)
1 M^0[0,1] = p_1^(0,0) 0.00009060 0.00134032 1.24973e-03 (7.155 ± 6.249)e-4 IN (+6.07453e-04)
2 M^0[1,0] = p_0^(0,1) 0.00000000 0.00124973 1.24973e-03 (6.249 ± 6.249)e-4 IN (-6.07453e-04)
3 M^0[1,1] = p_1^(0,1) 0.00018315 0.00018360 4.49770e-07 (1.834 ± 0.002)e-4 IN (-2.01820e-07)
4 M^1[0,0] = p_0^(1,0) 0.00003005 0.00003106 1.01727e-06 (3.056 ± 0.051)e-5 IN (+4.58425e-07)
5 M^1[0,1] = p_1^(1,0) 0.00076621 0.00201486 1.24865e-03 (1.391 ± 0.624)e-3 IN (-6.06932e-04)
6 M^1[1,0] = p_0^(1,1) -0.00000000 0.00124865 1.24865e-03 (6.243 ± 6.243)e-4 IN (+6.06932e-04)
7 M^1[1,1] = p_1^(1,1) 0.99771026 0.99771128 1.01727e-06 (9.977 ± 0.000)e-1 IN (-4.58425e-07)
1
2
3
4
5
6
7
8
9
Note: Reference comparison format:
  'IN (±X.XXXe±YY)' - value is within range, absolute discrepancy from center
  'OUT above/below (±X.XXXe±YY)' - value is outside range, absolute discrepancy from center
================================================================================



/var/folders/dk/_dd6ng8n3yv_crq6vrdt8qzc0000gn/T/ipykernel_57129/1006529938.py:551: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.
  plt.tight_layout()

svg

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
Reconstructed Instrument from center points of valid entry ranges:

MCM Instrument:

M0 matrix (outcome 0):
  [[p_0^(0,0), p_1^(0,0)]] = [[0.99872012, 0.00071546]]
  [[p_0^(0,1), p_1^(0,1)]]   [[0.00062486, 0.00018337]]

M1 matrix (outcome 1):
  [[p_0^(1,0), p_1^(1,0)]] = [[0.00003056, 0.00139053]]
  [[p_0^(1,1), p_1^(1,1)]]   [[0.00062433, 0.99771077]]

Ground Truth Instrument:

MCM Instrument:

M0 matrix (outcome 0):
  [[p_0^(0,0), p_1^(0,0)]] = [[0.99872032, 0.00132291]]
  [[p_0^(0,1), p_1^(0,1)]]   [[0.00001741, 0.00018317]]

M1 matrix (outcome 1):
  [[p_0^(1,0), p_1^(1,0)]] = [[0.00003101, 0.00078360]]
  [[p_0^(1,1), p_1^(1,1)]]   [[0.00123126, 0.99771031]]





0
1
2
3
4
5
6
# Synthetic instruments generated with random_instrument(). Per the inline notes, entry 3
# (seed=98) is the one labeled as resembling ibm_pittsburgh. Other cells use it via
# ibm_pittsburgh_mcm[3].
ibm_pittsburgh_mcm = [
    random_instrument(corr_strength=0.005, seed=3, fidelity=0.98),
    random_instrument(corr_strength=0.08, seed=37, fidelity=0.96), # this one might be slightly off
    random_instrument(corr_strength=0.2, seed=1, fidelity=0.99), # good demo of an ideally near-perfect instrument
    random_instrument(corr_strength=0.1, seed=98, fidelity=0.97), # good demo for ibm_pittsburgh on the morning of Nov 6, 2025
]
1
2
3
4
5
# Same arguments as ibm_pittsburgh_mcm[3]. With the fixed seed, this rebinds GT to the same instrument.
GT = random_instrument(corr_strength=0.1, seed=98, fidelity=0.97)
GT.reveal()

# Print the readout-error ("Prep s meas o") and back-action ("Prep s ends in s'") summaries.
GT.show_readout_errors()
GT.show_backaction_errors()
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
MCM Instrument:

M0 matrix (outcome 0):
  [[p_0^(0,0), p_1^(0,0)]] = [[0.99872032, 0.00132291]]
  [[p_0^(0,1), p_1^(0,1)]]   [[0.00001741, 0.00018317]]

M1 matrix (outcome 1):
  [[p_0^(1,0), p_1^(1,0)]] = [[0.00003101, 0.00078360]]
  [[p_0^(1,1), p_1^(1,1)]]   [[0.00123126, 0.99771031]]

Prep 0 meas 1: 1.262e-03
Prep 1 meas 0: 1.506e-03

Prep 0 ends in 1: 1.249e-03
Prep 1 ends in 0: 2.107e-03





[np.float64(0.0012486678967334145), np.float64(0.0021065159600563614)]
# Index into `sols`. NOTE(review): this assumes the earlier cell that computes `sols` from GT's
# invariants has been run in this session.
chosen_sol_index = 0

# Scan the gauge family of the chosen solution against GT.
# resolution=1e5 is forwarded to plot_gauge_transformation_effects; it is presumably the
# t-grid density, so confirm in that function.
reconstructed_MCM, _ = plot_gauge_transformation_effects(
    Instrument2x2(M0=sols[chosen_sol_index][0], M1=sols[chosen_sol_index][1]), 
    t_width_factor=1.1, 
    verbose=False, 
    MCM_reference=[GT],
    resolution=1e5,
    )

print("Original Instrument:")
GT.reveal()

print("\nReconstructed Instrument from center points of valid entry ranges:")
if reconstructed_MCM is not None:
    reconstructed_MCM.reveal()
1
2
3
================================================================================
Valid Intervals for Derived Quantities (Readout & Back-action Errors)
================================================================================
Quantity Min Max Width Value Ref1
0 prep 0 meas 1 0.00003005 0.00127966 1.24961e-03 (6.549 ± 6.248)e-4 IN (+6.07420e-04)
1 prep 1 meas 0 0.00027386 0.00152347 1.24961e-03 (8.987 ± 6.248)e-4 IN (+6.07420e-04)
2 prep 0 excite to 1 0.00124865 0.00124973 1.07234e-06 (1.249 ± 0.001)e-3 IN (-5.21213e-07)
3 prep 1 decay to 0 0.00210546 0.00210653 1.07234e-06 (2.106 ± 0.001)e-3 IN (+5.21213e-07)
1
2
3
================================================================================
Valid Intervals for Matrix Entries
================================================================================
Entry Min Max Width Value Ref1
0 M^0[0,0] = p_0^(0,0) 0.99871989 0.99872034 4.49770e-07 (9.987 ± 0.000)e-1 IN (+2.01820e-07)
1 M^0[0,1] = p_1^(0,0) 0.00009060 0.00134032 1.24973e-03 (7.155 ± 6.249)e-4 IN (+6.07453e-04)
2 M^0[1,0] = p_0^(0,1) 0.00000000 0.00124973 1.24973e-03 (6.249 ± 6.249)e-4 IN (-6.07453e-04)
3 M^0[1,1] = p_1^(0,1) 0.00018315 0.00018360 4.49770e-07 (1.834 ± 0.002)e-4 IN (-2.01820e-07)
4 M^1[0,0] = p_0^(1,0) 0.00003005 0.00003106 1.01727e-06 (3.056 ± 0.051)e-5 IN (+4.58425e-07)
5 M^1[0,1] = p_1^(1,0) 0.00076621 0.00201486 1.24865e-03 (1.391 ± 0.624)e-3 IN (-6.06932e-04)
6 M^1[1,0] = p_0^(1,1) -0.00000000 0.00124865 1.24865e-03 (6.243 ± 6.243)e-4 IN (+6.06932e-04)
7 M^1[1,1] = p_1^(1,1) 0.99771026 0.99771128 1.01727e-06 (9.977 ± 0.000)e-1 IN (-4.58425e-07)
1
2
3
4
5
6
7
8
9
Note: Reference comparison format:
  'IN (±X.XXXe±YY)' - value is within range, absolute discrepancy from center
  'OUT above/below (±X.XXXe±YY)' - value is outside range, absolute discrepancy from center
================================================================================



/var/folders/dk/_dd6ng8n3yv_crq6vrdt8qzc0000gn/T/ipykernel_57129/1006529938.py:551: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.
  plt.tight_layout()

svg

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
Original Instrument:

MCM Instrument:

M0 matrix (outcome 0):
  [[p_0^(0,0), p_1^(0,0)]] = [[0.99872032, 0.00132291]]
  [[p_0^(0,1), p_1^(0,1)]]   [[0.00001741, 0.00018317]]

M1 matrix (outcome 1):
  [[p_0^(1,0), p_1^(1,0)]] = [[0.00003101, 0.00078360]]
  [[p_0^(1,1), p_1^(1,1)]]   [[0.00123126, 0.99771031]]

Reconstructed Instrument from center points of valid entry ranges:

MCM Instrument:

M0 matrix (outcome 0):
  [[p_0^(0,0), p_1^(0,0)]] = [[0.99872012, 0.00071546]]
  [[p_0^(0,1), p_1^(0,1)]]   [[0.00062486, 0.00018337]]

M1 matrix (outcome 1):
  [[p_0^(1,0), p_1^(1,0)]] = [[0.00003056, 0.00139053]]
  [[p_0^(1,1), p_1^(1,1)]]   [[0.00062433, 0.99771077]]

Monte Carlo

# Monte Carlo check: reconstruct the instrument from (cached) simulated data and compare
# the invariants it implies with those of the ground truth.

# Only the second assignment takes effect. The first is presumably kept as a quick toggle
# between cached datasets.
selected_dir = "simulated_data_406_twirl"
selected_dir = "simulated_data_405"

GT_instrument = random_instrument(corr_strength=0.08, seed=2, fidelity=0.97)
# Generate and cache simulated data
# (shots_per_seed: int(10_000_000 / 0.1) -> about 1e8 shots per seed)
experiment_info, v0_list, empirical_probs_list = generate_and_cache_simulated_data(
    GT=GT_instrument,
    num_seeds=50,
    shots_per_seed=int(10_000_000 / 0.1),
    L=4,
    use_same_v0_for_all_seeds=np.array([0.2, 0.8]),
    seed_for_reproduce=123,
    chosen_dir=selected_dir
)
for key in experiment_info.keys():
    print(key)
    # print(f"{key}: {experiment_info[key]}")

# When loading from cache, experiment_info can be overridden by the cached version (the run
# above prints "Override experiment info with cached version."). Rebuild GT from the cached
# P0/P1 so the comparison uses the instrument that actually generated the data.
# GT_instrument = experiment_info["GT"]
GT_instrument = Instrument2x2(M0=experiment_info["GT"].P0, M1=experiment_info["GT"].P1)
GT_instrument.reveal()


GT_inv = summarize_instrument(GT_instrument.M0, GT_instrument.M1)
# print(GT_inv)
# emp_inv is indexed positionally below: [0] trM0, [1] detM0, [2] trM1, [3] detM1, [6] -> S0 / 2.
emp_inv = derived_constraints_from_empirical_probs(calculate_average_probs_np(empirical_probs_list))
# print(emp_inv)


# NOTE(review): the data above were generated with v0 = (0.2, 0.8). Confirm that emp_inv[6]
# is the v0=(0.5,0.5) quantity assumed by the S0 comment below.
sols = reconstruct_instrument_from_invariants_mixed_det(
    trM0 = emp_inv[0],
    detM0 = emp_inv[1],
    trM1 = emp_inv[2],
    detM1 = emp_inv[3],
    S0 = emp_inv[6] * 2, # S0 = sum of all elements of M0 = 2 * probability of observing 0 from v0=(0.5,0.5)
    gauge_p00 = 1.0
)

# # cheating: reconstruct from the exact GT invariants instead of the empirical ones
# sols = reconstruct_instrument_from_invariants_mixed_det(
#     trM0 = GT_instrument.M0.trace(),
#     detM0 = np.linalg.det(GT_instrument.M0),
#     trM1 = GT_instrument.M1.trace(),
#     detM1 = np.linalg.det(GT_instrument.M1),
#     S0 = GT_instrument.M0.sum(),
#     gauge_p00 = 1.0
# )

print(f"\nthere are {len(sols)} solutions from empirical data.")

# display a table to compare the args used for reconstruction from both GT and empirical data
comparison_table = pd.DataFrame({
    "Parameter": ["trM0", "detM0", "trM1", "detM1", "S0"],
    "From GT Instrument": [GT_instrument.M0.trace(), np.linalg.det(GT_instrument.M0), GT_instrument.M1.trace(), np.linalg.det(GT_instrument.M1), GT_instrument.M0.sum()],
    "From Empirical Data": [emp_inv[0], emp_inv[1], emp_inv[2], emp_inv[3], emp_inv[6] * 2],
})
print("\nComparison of parameters used for reconstruction:")
print(comparison_table)
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
Loading cached simulated data from simulated_data_405...
Override experiment info with cached version.
GT
number_of_seeds
shots_per_seed
max_string_length_L
use_same_v0_for_all_seeds
seed_for_reproduce

MCM Instrument:

M0 matrix (outcome 0):
  [[p_0^(0,0), p_1^(0,0)]] = [[0.96510057, 0.00054589]]
  [[p_0^(0,1), p_1^(0,1)]]   [[0.02642910, 0.00642704]]

M1 matrix (outcome 1):
  [[p_0^(1,0), p_1^(1,0)]] = [[0.00831841, 0.00260395]]
  [[p_0^(1,1), p_1^(1,1)]]   [[0.00015192, 0.99042312]]

there are 2 solutions from empirical data.

Comparison of parameters used for reconstruction:
  Parameter  From GT Instrument  From Empirical Data
0      trM0            0.971528             0.971565
1     detM0            0.006188             0.006208
2      trM1            0.998742             0.998695
3     detM1            0.008238             0.008211
4        S0            0.998503             0.998559
chosen_sol_index = 0  # Choose which solution to analyze

# Each entry of sols supplies M0 at index 0 and M1 at index 1.
chosen_M0 = sols[chosen_sol_index][0]
chosen_M1 = sols[chosen_sol_index][1]
candidate_instrument = Instrument2x2(M0=chosen_M0, M1=chosen_M1)

# Scan the gauge family around the candidate and compare against the GT reference.
reconstructed_MCM, _ = plot_gauge_transformation_effects(
    candidate_instrument,
    t_width_factor=1.1,
    verbose=False,
    MCM_reference=[GT_instrument],
    resolution=1e5,
)

print("Original Instrument:")
GT_instrument.reveal()

print("\nReconstructed Instrument from invariants (empirical data):")
if reconstructed_MCM is not None:
    reconstructed_MCM.reveal()
1
2
3
================================================================================
Valid Intervals for Derived Quantities (Readout & Back-action Errors)
================================================================================
Quantity Min Max Width Value Ref1
0 prep 0 meas 1 0.00829104 0.01103694 2.74591e-03 (9.664 ± 1.373)e-3 IN (-1.19366e-03)
1 prep 1 meas 0 0.00684990 0.00959581 2.74591e-03 (8.223 ± 1.373)e-3 IN (-1.24993e-03)
2 prep 0 excite to 1 0.02658707 0.02665277 6.57042e-05 (2.662 ± 0.003)e-2 OUT below (-3.88947e-05)
3 prep 1 decay to 0 0.00308699 0.00315270 6.57042e-05 (3.120 ± 0.033)e-3 IN (+2.99946e-05)
1
2
3
================================================================================
Valid Intervals for Matrix Entries
================================================================================
Entry Min Max Width Value Ref1
0 M^0[0,0] = p_0^(0,0) 0.96505619 0.96512190 6.57042e-05 (9.651 ± 0.000)e-1 IN (+1.15256e-05)
1 M^0[0,1] = p_1^(0,0) 0.00040679 0.00308699 2.68020e-03 (1.747 ± 1.340)e-3 IN (-1.20100e-03)
2 M^0[1,0] = p_0^(0,1) 0.02390687 0.02658707 2.68020e-03 (2.525 ± 0.134)e-2 IN (+1.18214e-03)
3 M^0[1,1] = p_1^(0,1) 0.00644311 0.00650882 6.57042e-05 (6.476 ± 0.033)e-3 OUT below (-4.89266e-05)
4 M^1[0,0] = p_0^(1,0) 0.00829104 0.00829296 1.91931e-06 (8.292 ± 0.001)e-3 OUT above (+2.64095e-05)
5 M^1[0,1] = p_1^(1,0) 0.00000000 0.00274591 2.74591e-03 (1.373 ± 1.373)e-3 IN (+1.23100e-03)
6 M^1[1,0] = p_0^(1,1) -0.00000000 0.00274591 2.74591e-03 (1.373 ± 1.373)e-3 IN (-1.22103e-03)
7 M^1[1,1] = p_1^(1,1) 0.99040227 0.99040419 1.91931e-06 (9.904 ± 0.000)e-1 OUT above (+1.98916e-05)
1
2
3
4
5
6
7
8
9
Note: Reference comparison format:
  'IN (±X.XXXe±YY)' - value is within range, absolute discrepancy from center
  'OUT above/below (±X.XXXe±YY)' - value is outside range, absolute discrepancy from center
================================================================================



/var/folders/dk/_dd6ng8n3yv_crq6vrdt8qzc0000gn/T/ipykernel_57129/1006529938.py:551: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.
  plt.tight_layout()

svg

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
Original Instrument:

MCM Instrument:

M0 matrix (outcome 0):
  [[p_0^(0,0), p_1^(0,0)]] = [[0.96510057, 0.00054589]]
  [[p_0^(0,1), p_1^(0,1)]]   [[0.02642910, 0.00642704]]

M1 matrix (outcome 1):
  [[p_0^(1,0), p_1^(1,0)]] = [[0.00831841, 0.00260395]]
  [[p_0^(1,1), p_1^(1,1)]]   [[0.00015192, 0.99042312]]

Reconstructed Instrument from invariants (empirical data):

MCM Instrument:

M0 matrix (outcome 0):
  [[p_0^(0,0), p_1^(0,0)]] = [[0.96508904, 0.00174689]]
  [[p_0^(0,1), p_1^(0,1)]]   [[0.02524697, 0.00647597]]

M1 matrix (outcome 1):
  [[p_0^(1,0), p_1^(1,0)]] = [[0.00829200, 0.00137295]]
  [[p_0^(1,1), p_1^(1,1)]]   [[0.00137295, 0.99040323]]
# Gauge-parameter (t) intervals returned by allowed_t_regions_for_list for the
# first reconstructed solution's (M0, M1) pair; presumably the t values for which
# the gauge-transformed instrument stays valid — confirm in the helper.
# NOTE(review): this uses sols[0] rather than chosen_sol_index.
regions = allowed_t_regions_for_list([sols[0][0], sols[0][1]], tol=1e-24)
regions
1
2
[(np.float64(0.14868186307103115), np.float64(0.15064089805685127)),
 (np.float64(0.8493591019431488), np.float64(0.8513181369289688))]
# Demo: start from the "cheated" solution, i.e. the ground-truth instrument itself.
# The RMSE against the reference is then expected to be very small, with its minimum at t=0.
GT_instrument = Instrument2x2(M0=experiment_info["GT"].P0, M1=experiment_info["GT"].P1)

gauge_plot_options = {
    "t_width_factor": 1.1,
    "verbose": False,
    "MCM_reference": [GT_instrument],
    "resolution": 1e5,
}
reconstructed_MCM, _ = plot_gauge_transformation_effects(GT_instrument, **gauge_plot_options)

print("Original Instrument:")
GT_instrument.reveal()

print("Reconstructed Instrument from invariants (best match to original):")
if reconstructed_MCM is not None:
    reconstructed_MCM.reveal()
1
2
3
================================================================================
Valid Intervals for Derived Quantities (Readout & Back-action Errors)
================================================================================
Quantity Min Max Width Value Ref1
0 prep 0 meas 1 0.00831800 0.01107388 2.75587e-03 (9.696 ± 1.378)e-3 IN (-1.22561e-03)
1 prep 1 meas 0 0.00682060 0.00957648 2.75587e-03 (8.199 ± 1.378)e-3 IN (-1.22561e-03)
2 prep 0 excite to 1 0.02657740 0.02664332 6.59146e-05 (2.661 ± 0.003)e-2 IN (-2.93333e-05)
3 prep 1 decay to 0 0.00308755 0.00315346 6.59146e-05 (3.121 ± 0.033)e-3 IN (+2.93333e-05)
1
2
3
================================================================================
Valid Intervals for Matrix Entries
================================================================================
Entry Min Max Width Value Ref1
0 M^0[0,0] = p_0^(0,0) 0.96503868 0.96510460 6.59146e-05 (9.651 ± 0.000)e-1 IN (+2.89305e-05)
1 M^0[0,1] = p_1^(0,0) 0.00039759 0.00308755 2.68996e-03 (1.743 ± 1.345)e-3 IN (-1.19668e-03)
2 M^0[1,0] = p_0^(0,1) 0.02388744 0.02657740 2.68996e-03 (2.523 ± 0.134)e-2 IN (+1.19668e-03)
3 M^0[1,1] = p_1^(0,1) 0.00642301 0.00648893 6.59146e-05 (6.456 ± 0.033)e-3 IN (-2.89305e-05)
4 M^1[0,0] = p_0^(1,0) 0.00831800 0.00831994 1.93330e-06 (8.319 ± 0.001)e-3 IN (-5.63847e-07)
5 M^1[0,1] = p_1^(1,0) 0.00000000 0.00275587 2.75587e-03 (1.378 ± 1.378)e-3 IN (+1.22601e-03)
6 M^1[1,0] = p_0^(1,1) 0.00000000 0.00275587 2.75587e-03 (1.378 ± 1.378)e-3 IN (-1.22601e-03)
7 M^1[1,1] = p_1^(1,1) 0.99042159 0.99042352 1.93330e-06 (9.904 ± 0.000)e-1 IN (+5.63847e-07)
1
2
3
4
5
6
7
8
9
Note: Reference comparison format:
  'IN (±X.XXXe±YY)' - value is within range, absolute discrepancy from center
  'OUT above/below (±X.XXXe±YY)' - value is outside range, absolute discrepancy from center
================================================================================



/var/folders/dk/_dd6ng8n3yv_crq6vrdt8qzc0000gn/T/ipykernel_57129/1006529938.py:551: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.
  plt.tight_layout()

svg

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
Original Instrument:

MCM Instrument:

M0 matrix (outcome 0):
  [[p_0^(0,0), p_1^(0,0)]] = [[0.96510057, 0.00054589]]
  [[p_0^(0,1), p_1^(0,1)]]   [[0.02642910, 0.00642704]]

M1 matrix (outcome 1):
  [[p_0^(1,0), p_1^(1,0)]] = [[0.00831841, 0.00260395]]
  [[p_0^(1,1), p_1^(1,1)]]   [[0.00015192, 0.99042312]]
Reconstructed Instrument from invariants (best match to original):

MCM Instrument:

M0 matrix (outcome 0):
  [[p_0^(0,0), p_1^(0,0)]] = [[0.96507164, 0.00174257]]
  [[p_0^(0,1), p_1^(0,1)]]   [[0.02523242, 0.00645597]]

M1 matrix (outcome 1):
  [[p_0^(1,0), p_1^(1,0)]] = [[0.00831897, 0.00137794]]
  [[p_0^(1,1), p_1^(1,1)]]   [[0.00137794, 0.99042256]]
import sympy as sp

# Generic real 2x2 matrix entries a..d and the real gauge parameter t.
a, b, c, d, t = sp.symbols('a b c d t', real=True)

# M is the matrix being transformed; R(t) is symmetric with rows (1-t, t) and (t, 1-t).
M = sp.Matrix([[a, b], [c, d]])
R = sp.Matrix([[1 - t, t], [t, 1 - t]])

# det R = 1 - 2t, so R(t) is invertible for every t except t = 1/2.
R_inv = R.inv()

# Gauge-transformed matrix M' = R^(-1) M R, kept both raw and simplified.
M_prime = R_inv @ M @ R
M_prime_simplified = sp.simplify(M_prime)

# Headings and the matrix rendered under each one, in display order.
sections = [
    ("Original matrix M:", M),
    ("\nGauge transformation matrix R(t):", R),
    ("\nInverse R^(-1):", sp.simplify(R_inv)),
    ("\nGauge-transformed matrix M' = R^(-1) M R (before simplification):", M_prime),
    ("\nGauge-transformed matrix M' = R^(-1) M R (after simplification):", M_prime_simplified),
]
for heading, shown_matrix in sections:
    print(heading)
    display(shown_matrix)

# Fully expanded form of every entry of the simplified M'.
print("\nElement-wise simplified expressions:")
for row in range(2):
    for col in range(2):
        expanded_entry = sp.expand(sp.simplify(M_prime_simplified[row, col]))
        print(f"M'[{row},{col}] = {expanded_entry}")
1
Original matrix M:

\(\displaystyle \left[\begin{matrix}a & b\\c & d\end{matrix}\right]\)

1
Gauge transformation matrix R(t):

\(\displaystyle \left[\begin{matrix}1 - t & t\\t & 1 - t\end{matrix}\right]\)

1
Inverse R^(-1):

\(\displaystyle \left[\begin{matrix}\frac{t - 1}{2 t - 1} & \frac{t}{2 t - 1}\\\frac{t}{2 t - 1} & \frac{t - 1}{2 t - 1}\end{matrix}\right]\)

1
Gauge-transformed matrix M' = R^(-1) M R (before simplification):

\(\displaystyle \left[\begin{matrix}t \left(\frac{b \left(t - 1\right)}{2 t - 1} + \frac{d t}{2 t - 1}\right) + \left(1 - t\right) \left(\frac{a \left(t - 1\right)}{2 t - 1} + \frac{c t}{2 t - 1}\right) & t \left(\frac{a \left(t - 1\right)}{2 t - 1} + \frac{c t}{2 t - 1}\right) + \left(1 - t\right) \left(\frac{b \left(t - 1\right)}{2 t - 1} + \frac{d t}{2 t - 1}\right)\\t \left(\frac{b t}{2 t - 1} + \frac{d \left(t - 1\right)}{2 t - 1}\right) + \left(1 - t\right) \left(\frac{a t}{2 t - 1} + \frac{c \left(t - 1\right)}{2 t - 1}\right) & t \left(\frac{a t}{2 t - 1} + \frac{c \left(t - 1\right)}{2 t - 1}\right) + \left(1 - t\right) \left(\frac{b t}{2 t - 1} + \frac{d \left(t - 1\right)}{2 t - 1}\right)\end{matrix}\right]\)

1
Gauge-transformed matrix M' = R^(-1) M R (after simplification):

\(\displaystyle \left[\begin{matrix}\frac{t \left(b \left(t - 1\right) + d t\right) - \left(t - 1\right) \left(a \left(t - 1\right) + c t\right)}{2 t - 1} & \frac{t \left(a \left(t - 1\right) + c t\right) - \left(t - 1\right) \left(b \left(t - 1\right) + d t\right)}{2 t - 1}\\\frac{t \left(b t + d \left(t - 1\right)\right) - \left(t - 1\right) \left(a t + c \left(t - 1\right)\right)}{2 t - 1} & \frac{t \left(a t + c \left(t - 1\right)\right) - \left(t - 1\right) \left(b t + d \left(t - 1\right)\right)}{2 t - 1}\end{matrix}\right]\)

1
2
3
4
5
Element-wise simplified expressions:
M'[0,0] = -a*t**2/(2*t - 1) + 2*a*t/(2*t - 1) - a/(2*t - 1) + b*t**2/(2*t - 1) - b*t/(2*t - 1) - c*t**2/(2*t - 1) + c*t/(2*t - 1) + d*t**2/(2*t - 1)
M'[0,1] = a*t**2/(2*t - 1) - a*t/(2*t - 1) - b*t**2/(2*t - 1) + 2*b*t/(2*t - 1) - b/(2*t - 1) + c*t**2/(2*t - 1) - d*t**2/(2*t - 1) + d*t/(2*t - 1)
M'[1,0] = -a*t**2/(2*t - 1) + a*t/(2*t - 1) + b*t**2/(2*t - 1) - c*t**2/(2*t - 1) + 2*c*t/(2*t - 1) - c/(2*t - 1) + d*t**2/(2*t - 1) - d*t/(2*t - 1)
M'[1,1] = a*t**2/(2*t - 1) - b*t**2/(2*t - 1) + b*t/(2*t - 1) + c*t**2/(2*t - 1) - c*t/(2*t - 1) - d*t**2/(2*t - 1) + 2*d*t/(2*t - 1) - d/(2*t - 1)
# Express the numerators as polynomials in t of degree 2


def _print_numerator_quadratics(label, matrix):
    """Print the numerator of each entry of a 2x2 sympy matrix as a polynomial in ``t``.

    Args:
        label: Name used in each entry heading, e.g. "M0'".
        matrix: 2x2 sympy matrix whose entries are rational functions of the
            global symbol ``t``.
    """
    for i in range(2):
        for j in range(2):
            # Expand the numerator and group its terms by powers of t.
            numerator_poly = sp.collect(sp.expand(sp.numer(matrix[i, j])), t)
            coeffs = sp.Poly(numerator_poly, t).all_coeffs()

            # all_coeffs() lists the HIGHEST power first. A numerator whose t**2
            # (or t) term cancels returns fewer than 3 coefficients. It must be
            # zero-padded at the FRONT; padding at the end would shift every
            # coefficient onto the wrong power.
            coeffs = [sp.sympify(0)] * (3 - len(coeffs)) + coeffs

            # Label each coefficient with its own power (descending to t^0).
            degree = len(coeffs) - 1
            terms = []
            for power, coeff in zip(range(degree, -1, -1), coeffs):
                if power == 0:
                    terms.append(f"({coeff})")
                elif power == 1:
                    terms.append(f"({coeff})*t")
                else:
                    terms.append(f"({coeff})*t^{power}")

            print(f"\n{label}[{i},{j}] numerator:")
            print(f"  Polynomial: {' + '.join(terms)}")
            print(f"  Expanded form: {numerator_poly}")


print("="*70)
print("Numerators of M0' elements as polynomials in t (degree 2):")
print("="*70)
_print_numerator_quadratics("M0'", simplified_m0_prime)

print("\n" + "="*70)
print("Numerators of M1' elements as polynomials in t (degree 2):")
print("="*70)
_print_numerator_quadratics("M1'", simplified_m1_prime)

# NOTE: the shared denominator is stated from inspection; this cell does not check it.
print("\n" + "="*70)
print("Note: All elements have the common denominator (2*t - 1)")
print("="*70)
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
======================================================================
Numerators of M0' elements as polynomials in t (degree 2):
======================================================================

M0'[0,0] numerator:
  Polynomial: (-a - b + e + f)*t^2 + (2*a + b - e)*t + (-a)
  Expanded form: -a + t**2*(-a - b + e + f) + t*(2*a + b - e)

M0'[0,1] numerator:
  Polynomial: (a + b - e - f)*t^2 + (-a + 2*e + f)*t + (-e)
  Expanded form: -e + t**2*(a + b - e - f) + t*(-a + 2*e + f)

M0'[1,0] numerator:
  Polynomial: (-a - b + e + f)*t^2 + (a + 2*b - f)*t + (-b)
  Expanded form: -b + t**2*(-a - b + e + f) + t*(a + 2*b - f)

M0'[1,1] numerator:
  Polynomial: (a + b - e - f)*t^2 + (-b + e + 2*f)*t + (-f)
  Expanded form: -f + t**2*(a + b - e - f) + t*(-b + e + 2*f)

======================================================================
Numerators of M1' elements as polynomials in t (degree 2):
======================================================================

M1'[0,0] numerator:
  Polynomial: (a + b - e - f)*t^2 + (-a - b + c - g + 1)*t + (-c)
  Expanded form: -c + t**2*(a + b - e - f) + t*(-a - b + c - g + 1)

M1'[0,1] numerator:
  Polynomial: (-a - b + e + f)*t^2 + (-c - e - f + g + 1)*t + (-g)
  Expanded form: -g + t**2*(-a - b + e + f) + t*(-c - e - f + g + 1)

M1'[1,0] numerator:
  Polynomial: (a + b - e - f)*t^2 + (-2*a - 2*b - c + e + f + g + 1)*t + (a + b + c - 1)
  Expanded form: a + b + c + t**2*(a + b - e - f) + t*(-2*a - 2*b - c + e + f + g + 1) - 1

M1'[1,1] numerator:
  Polynomial: (-a - b + e + f)*t^2 + (a + b + c - 2*e - 2*f - g + 1)*t + (e + f + g - 1)
  Expanded form: e + f + g + t**2*(-a - b + e + f) + t*(a + b + c - 2*e - 2*f - g + 1) - 1

======================================================================
Note: All elements have the common denominator (2*t - 1)
======================================================================
# Extract numerators and denominators for each element of M'


def _print_rational_entries(label, matrix):
    """Print numerator, denominator and t-polynomial form of each entry of a 2x2 sympy matrix.

    Args:
        label: Name used in each entry heading, e.g. "M0'".
        matrix: 2x2 sympy matrix whose entries are rational functions of the
            global symbol ``t``.
    """
    index_pairs = [(row, col) for row in range(2) for col in range(2)]
    for row, col in index_pairs:
        entry = matrix[row, col]
        numerator_poly = sp.collect(sp.expand(sp.numer(entry)), t)
        denominator = sp.denom(entry)

        # all_coeffs() runs from the highest power of t down to the constant term.
        coeffs = sp.Poly(numerator_poly, t).all_coeffs()
        highest = len(coeffs) - 1
        terms = []
        for power, coeff in zip(range(highest, -1, -1), coeffs):
            if power == 0:
                terms.append(f"({coeff})")
            elif power == 1:
                terms.append(f"({coeff})*t")
            else:
                terms.append(f"({coeff})*t^{power}")

        print(f"\n{label}[{row},{col}]:")
        print(f"  Numerator: {numerator_poly}")
        print(f"  Denominator: {denominator}")
        print("  Polynomial form: " + " + ".join(terms))


print("Numerators of M'_simplified elements as polynomials in t:\n")
print("="*70)
_print_rational_entries("M0'", simplified_m0_prime)

print("\n" + "="*70)
print("\nM1 matrix elements:\n")
print("="*70)
_print_rational_entries("M1'", simplified_m1_prime)

# Summary: All elements have the common denominator (2*t - 1)
print("\n" + "="*70)
print("Note: All elements share the common denominator (2*t - 1)")
print("="*70)
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
Numerators of M'_simplified elements as polynomials in t:

======================================================================

M0'[0,0]:
  Numerator: -a + t**2*(-a - b + e + f) + t*(2*a + b - e)
  Denominator: 2*t - 1
  Polynomial form: (-a - b + e + f)*t^2 + (2*a + b - e)*t + (-a)

M0'[0,1]:
  Numerator: -e + t**2*(a + b - e - f) + t*(-a + 2*e + f)
  Denominator: 2*t - 1
  Polynomial form: (a + b - e - f)*t^2 + (-a + 2*e + f)*t + (-e)

M0'[1,0]:
  Numerator: -b + t**2*(-a - b + e + f) + t*(a + 2*b - f)
  Denominator: 2*t - 1
  Polynomial form: (-a - b + e + f)*t^2 + (a + 2*b - f)*t + (-b)

M0'[1,1]:
  Numerator: -f + t**2*(a + b - e - f) + t*(-b + e + 2*f)
  Denominator: 2*t - 1
  Polynomial form: (a + b - e - f)*t^2 + (-b + e + 2*f)*t + (-f)

======================================================================

M1 matrix elements:

======================================================================

M1'[0,0]:
  Numerator: -c + t**2*(a + b - e - f) + t*(-a - b + c - g + 1)
  Denominator: 2*t - 1
  Polynomial form: (a + b - e - f)*t^2 + (-a - b + c - g + 1)*t + (-c)

M1'[0,1]:
  Numerator: -g + t**2*(-a - b + e + f) + t*(-c - e - f + g + 1)
  Denominator: 2*t - 1
  Polynomial form: (-a - b + e + f)*t^2 + (-c - e - f + g + 1)*t + (-g)

M1'[1,0]:
  Numerator: a + b + c + t**2*(a + b - e - f) + t*(-2*a - 2*b - c + e + f + g + 1) - 1
  Denominator: 2*t - 1
  Polynomial form: (a + b - e - f)*t^2 + (-2*a - 2*b - c + e + f + g + 1)*t + (a + b + c - 1)

M1'[1,1]:
  Numerator: e + f + g + t**2*(-a - b + e + f) + t*(a + b + c - 2*e - 2*f - g + 1) - 1
  Denominator: 2*t - 1
  Polynomial form: (-a - b + e + f)*t^2 + (a + b + c - 2*e - 2*f - g + 1)*t + (e + f + g - 1)

======================================================================
Note: All elements share the common denominator (2*t - 1)
======================================================================

🖥️ Qiskit Experiments

Utilities for Qiskit Experiments

1
2
3
# Print the installed Qiskit version so results can be tied to a specific release.
from qiskit import __version__

print(__version__)
1
2.2.3
from qiskit_ibm_runtime import QiskitRuntimeService

# Save credentials locally (can be run once).
# SECURITY: never hard-code a real API key in a notebook. Read it from an
# environment variable or a secrets manager instead. A key that was previously
# committed here must be treated as leaked: revoke and regenerate it.
# QiskitRuntimeService.save_account(
#     token="<YOUR_IBM_QUANTUM_API_KEY>", # Use the 44-character API_KEY you created and saved from the IBM Quantum Platform Home dashboard
#     instance="<YOUR_INSTANCE_CRN>", # Optional
# )

# Run every time you need the service (loads the locally saved account).
service = QiskitRuntimeService()

# One row per non-simulator backend: name, processor family and supported
# instructions ("measure_2" marks mid-circuit measurement support).
data = []
for backend in service.backends():
    config = backend.configuration()
    if "simulator" in config.backend_name:
        continue
    data.append({
        "Backend": config.backend_name,
        "Processor Type": config.processor_type,
        "Supported Instructions (measure_2 is mid-circ meas.)": ", ".join(config.supported_instructions)
    })

df_backends = pd.DataFrame(data)

# Set pandas option to display full column width (global pandas setting; affects later cells too).
pd.set_option('display.max_colwidth', None)

display(df_backends)
Backend Processor Type Supported Instructions (measure_2 is mid-circ meas.)
0 ibm_pittsburgh {'family': 'Heron', 'revision': '3'} cz, id, delay, measure, measure_2, reset, rz, sx, x, if_else
1 ibm_fez {'family': 'Heron', 'revision': '2'} cz, id, delay, measure, reset, rz, sx, x, if_else
2 ibm_marrakesh {'family': 'Heron', 'revision': '2'} cz, id, delay, measure, reset, rz, sx, x, if_else
3 ibm_kingston {'family': 'Heron', 'revision': '2'} cz, id, delay, measure, measure_2, reset, rz, sx, x, if_else
4 ibm_torino {'family': 'Heron', 'revision': '1'} cz, id, delay, measure, reset, rz, sx, x, if_else

Check remaining usage (quota)

1
2
3
4
5
6
7
8
# Optional: list every available backend, or only those whose supported
# instructions include "measure_2" (mid-circuit measurement).
# print("Available to run on:")
# display(service.backends())

# print("\n✅ But the ones equipped with mid-circuit measurement are:")
# display(service.backends(filters=lambda b: "measure_2" in b.supported_instructions))

# Show the usage record returned by service.usage() for this account/instance.
print("Service Usage:")
display(service.usage())
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
Service Usage:



{'instance_id': 'crn:v1:bluemix:public:quantum-computing:us-east:a/ed5d7d2fb3b249c6baea6864058116cb:e46b4031-d4ee-48a6-88cd-1ae20a49ec21::',
 'plan_id': '7f666d17-7893-47d8-bf9d-2b2389fc4dfc',
 'usage_consumed_seconds': 16654,
 'usage_period': {'start_time': '2025-11-11T07:53:34.526Z',
  'end_time': '2025-12-09T07:53:34.526Z'},
 'usage_allocation_seconds': 48000,
 'usage_limit_reached': False,
 'usage_remaining_seconds': 31346}
from qiskit import QuantumCircuit
from qiskit import transpile
from qiskit.visualization import plot_histogram
from qiskit import QuantumRegister, ClassicalRegister, QuantumCircuit
from qiskit_ibm_runtime.circuit import MidCircuitMeasure
from qiskit.circuit import Measure
from qiskit.transpiler import generate_preset_pass_manager
from qiskit_ibm_runtime import SamplerV2 as Sampler
from qiskit_ibm_runtime.fake_provider import FakeManilaV2, FakeFez
# from qiskit_aer import AerSimulator
# from qiskit_aer.noise import (
#     NoiseModel,
#     QuantumError,
#     ReadoutError,
#     depolarizing_error,
#     pauli_error,
#     thermal_relaxation_error,
# )
from qiskit.quantum_info import Clifford, random_clifford
from qiskit.synthesis import OneQubitEulerDecomposer
import copy

# === Define registers and circuit ===

def create_list_of_circuits(max_word_length, num_qubits=2, visualize_first=4):
    """Build paired local/QPU circuits: one gate from I, X, Y, Z on every qubit, then repeated measurement.

    For each gate label in the order I, X, Y, Z, one circuit is built for local
    testing and one for the QPU. The gate is applied to every qubit ('I' applies
    nothing). Each qubit ``q`` is then measured ``max_word_length`` times into its
    own classical register ``c{q}``. Measurements fill bits from index
    ``max_word_length - 1`` down to 0, so the first measurement lands in the
    highest-index bit.

    Starting from the default |0> state, X and Y both give |1> (up to phase) and Z
    leaves |0>. The four circuits are therefore not four distinct basis preparations.

    Args:
        max_word_length: Number of measurements per qubit (size of each ``c{q}``).
        num_qubits: Number of qubits, each with its own classical register.
        visualize_first: How many circuits of each list to draw.

    Returns:
        (circuits_for_local_test, circuits_for_QPU): four circuits each. The local
        list uses ``Measure`` and the QPU list uses ``MidCircuitMeasure``.
    """
    qreg_q = QuantumRegister(num_qubits, 'q')
    cregs = [ClassicalRegister(max_word_length, f'c{k}') for k in range(num_qubits)]

    # QuantumCircuit method applied for each gate label; 'I' applies no gate.
    gate_methods = {'I': None, 'X': 'x', 'Y': 'y', 'Z': 'z'}

    circuits_for_local_test = []
    circuits_for_QPU = []
    for method_name in gate_methods.values():
        qc_local = QuantumCircuit(qreg_q, *cregs)
        qc_qpu = QuantumCircuit(qreg_q, *cregs)

        for qubit, creg in enumerate(cregs):
            if method_name is not None:
                getattr(qc_local, method_name)(qubit)
                getattr(qc_qpu, method_name)(qubit)

            # Highest bit first, so the bitstring reads in measurement order.
            for bit in reversed(range(max_word_length)):
                qc_local.append(Measure(), [qubit], [creg[bit]])
                qc_qpu.append(MidCircuitMeasure(), [qubit], [creg[bit]])

        circuits_for_local_test.append(qc_local)
        circuits_for_QPU.append(qc_qpu)

    # === Visualize the circuit ===
    print(f"Visualizing circuit for local test (first {visualize_first}):")
    for circuit in circuits_for_local_test[:visualize_first]:
        circuit.draw('mpl', style="iqp")

    print(f"Visualizing circuit for QPU (first {visualize_first}):")
    for circuit in circuits_for_QPU[:visualize_first]:
        circuit.draw('mpl', style="iqp")

    return circuits_for_local_test, circuits_for_QPU




def create_list_of_circuits_v2(max_word_length, visualize_first=24):  # refer to https://quantum.cloud.ibm.com/docs/en/api/qiskit/qiskit.synthesis.OneQubitEulerDecomposer
    """Build one circuit per single-qubit Clifford (all 24), each followed by repeated measurement.

    Every Clifford is synthesized into Rz/SX gates with ``OneQubitEulerDecomposer``
    (basis 'ZSX') and then measured ``max_word_length`` times into register ``c``.
    Bits are filled from index ``max_word_length - 1`` down to 0.

    The Cliffords are collected by random sampling but returned in a fixed order
    (sorted by their tableau string). Circuit ``k`` is therefore the same Clifford
    on every call, so PUB indices of saved job results stay meaningful when the
    circuits are regenerated.

    Args:
        max_word_length: Number of measurements (size of the classical register ``c``).
        visualize_first: How many local-test circuits to draw (default 24, i.e. all).

    Returns:
        (circuits_for_local_test, circuits_for_QPU): two lists of 24 independent
        circuits with identical Clifford prefixes. The local list ends in
        ``Measure`` and the QPU list in ``MidCircuitMeasure`` instructions.
    """
    qreg_q = QuantumRegister(1, 'q')
    creg_c = ClassicalRegister(max_word_length, 'c')

    # === 1. Collect the 24 single-qubit Cliffords ===
    # "Full twirling" averages over the whole single-qubit Clifford group (size 24).
    # Sample until every element has been seen, de-duplicating on the tableau string.
    cliffords_by_key = {}
    while len(cliffords_by_key) < 24:
        candidate = random_clifford(1)
        cliffords_by_key.setdefault(str(candidate), candidate)
    # Sampling order is random; sort so the circuit order is reproducible.
    clifford_ops = [cliffords_by_key[key] for key in sorted(cliffords_by_key)]

    # === 2. Decompose into Basis Gates ===
    # 'ZSX' (common on IBM hardware) yields sequences of Rz and SX gates.
    decomposer = OneQubitEulerDecomposer(basis='ZSX')

    twirling_circuits = []
    for c_op in clifford_ops:
        # With the default simplify=True the identity becomes an empty circuit, which
        # is the correct idle for twirling.
        decomposed_qc = decomposer(c_op.to_matrix())

        qc = QuantumCircuit(qreg_q, creg_c)
        qc.compose(decomposed_qc, qubits=[0], inplace=True)
        twirling_circuits.append(qc)

    # === 3. Independent copies for local and QPU usage ===
    circuits_for_local_test = copy.deepcopy(twirling_circuits)
    circuits_for_QPU = copy.deepcopy(twirling_circuits)

    # === 4. Add Measurements ===
    for i in range(max_word_length-1, -1, -1):
        for circ in circuits_for_local_test:
            circ.append(Measure(), [0], [i])
        for circ in circuits_for_QPU:
            circ.append(MidCircuitMeasure(), [0], [i])

    # === Visualize ===
    print(f"Generated {len(circuits_for_local_test)} circuits (Clifford Group Size).")
    print("Visualizing a few sample circuits:")
    for k in range(min(visualize_first, len(circuits_for_local_test))):
        print(f"Circuit {k}:")
        circuits_for_local_test[k].draw('mpl', style="iqp")

    return circuits_for_local_test, circuits_for_QPU
def create_non_trivial_list_of_circuits(family_names=('repeated-X-MCM',), max_word_length=4, num_qubits=2):
    """
    Generates a list of circuits based on specific non-trivial families of MCM sequences.

    Args:
        family_names: Iterable of strings specifying which circuit families to generate.
                    A single string is also accepted and treated as one family name.
                    Currently supports: 'repeated-X-MCM'.
        max_word_length: Length of the classical register (number of measurements).
        num_qubits: Number of qubits to generate circuits for.

    Returns:
        Tuple of (circuits_for_local_test, circuits_for_QPU). Each list holds one
        circuit per recognized family name. Unrecognized names print a warning and
        are skipped.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(family_names, str):
        family_names = [family_names]

    circuits_for_local_test = []
    circuits_for_QPU = []

    qreg_q = QuantumRegister(num_qubits, 'q')
    # Create one classical register per qubit, each of size max_word_length
    cregs = [ClassicalRegister(max_word_length, f'c{i}') for i in range(num_qubits)]

    for family in family_names:
        if family == 'repeated-X-MCM':
            # Create ONE set of circuits for this family, applying logic to ALL qubits in parallel
            qc_local = QuantumCircuit(qreg_q, *cregs)
            qc_qpu = QuantumCircuit(qreg_q, *cregs)

            # Apply X gate then Measure, repeated max_word_length times
            # This creates a toggling sequence: |0> -> X -> |1> -> Meas(1) -> X -> |0> -> Meas(0) ...
            # Bits are filled from the highest index down, so the bitstring reads in measurement order.
            for i in range(max_word_length-1, -1, -1):
                # Apply X gate to ALL qubits
                for q in range(num_qubits):
                    qc_local.x(q)
                    qc_qpu.x(q)

                # Apply Measurement to ALL qubits
                for q in range(num_qubits):
                    # Local test uses standard Measure
                    qc_local.append(Measure(), [q], [cregs[q][i]])
                    # QPU uses MidCircuitMeasure
                    qc_qpu.append(MidCircuitMeasure(), [q], [cregs[q][i]])

            circuits_for_local_test.append(qc_local)
            circuits_for_QPU.append(qc_qpu)
        else:
            print(f"Warning: Family '{family}' not implemented.")

    return circuits_for_local_test, circuits_for_QPU
def plot_job_results(job_result, num_qubits, backend_name=None, is_parallel_mode=True, plot_first_n=2):
    """
    Collect per-register counts from a job result and plot grouped histograms.

    One bar chart is drawn per classical-register index (qubit group), with one
    bar series per PUB that contributed counts to that group.

    Args:
        job_result: PrimitiveResult object from Qiskit Runtime (iterable of PUB results).
        num_qubits: Number of qubits (groups) to split the results into.
        backend_name: Optional name of the backend used for execution; printed
            at the end when given.
        is_parallel_mode: If True, every PUB holds data for ALL qubits, one
            register per qubit, looked up by name ``c{q}`` with a fallback to
            the q-th data key. If False, PUBs are grouped by qubit: the first
            ``len(job_result) / num_qubits`` PUBs target register 0, the next
            block targets register 1, and so on. In this mode the counts are
            also printed.
        plot_first_n: Maximum number of histograms to draw. Values <= 0 draw
            nothing; the counts are still collected and returned.

    Returns:
        list: A list of length num_qubits, where each element is a list of count dictionaries
              for that classical register index.
    """
    # One list of count dicts per qubit / classical register index
    results_by_qubit = [[] for _ in range(num_qubits)]
    num_pubs = len(job_result)

    if is_parallel_mode:
        # Parallel Mode: extract counts for ALL qubits from EACH PUB
        for pub_result in job_result:
            for q_idx in range(num_qubits):
                reg_name = f'c{q_idx}'

                # In SamplerV2, pub_result.data has attributes matching register names
                if hasattr(pub_result.data, reg_name):
                    counts = getattr(pub_result.data, reg_name).get_counts()
                    results_by_qubit[q_idx].append(counts)
                else:
                    # Fallback: access by index if names don't match
                    keys = list(pub_result.data.keys())
                    if q_idx < len(keys):
                        counts = pub_result.data[keys[q_idx]].get_counts()
                        results_by_qubit[q_idx].append(counts)

    elif num_qubits > 0 and num_pubs > 0:
        # Grouped Mode (Old behavior / Accuracy circuits):
        # PUBs are grouped by qubit. First K PUBs for q0, next K for q1...
        for i, pub_result in enumerate(job_result):
            # Same as int(i // (num_pubs / num_qubits)), but in exact integer arithmetic
            c_reg_index = i * num_qubits // num_pubs

            # We assume the relevant register is at the corresponding key index
            keys = list(pub_result.data.keys())
            if c_reg_index < len(keys):
                counts = pub_result.data[keys[c_reg_index]].get_counts()
                results_by_qubit[c_reg_index].append(counts)
                print(f"\nCounts for classical register index {c_reg_index} from PUB {i}:")
                print(counts)

    palette = ['blue', 'orange', 'green', 'red', 'purple', 'brown', 'cyan', 'magenta']
    plot_counter = 0

    # Plot histograms for each qubit group
    for q_idx, pub_counts_list in enumerate(results_by_qubit):
        # Check BEFORE plotting so that plot_first_n <= 0 draws nothing
        if plot_counter >= plot_first_n:
            break
        if not pub_counts_list:
            continue

        # Sort bitstrings for consistent ordering across PUBs
        sorted_bitstrings = sorted({b for counts in pub_counts_list for b in counts})

        num_pubs_in_group = len(pub_counts_list)  # > 0 here
        bar_width = 0.8 / num_pubs_in_group
        x = np.arange(len(sorted_bitstrings))

        # Label gap relative to the tallest bar, so it scales with the shot count
        max_count = max((max(counts.values(), default=0) for counts in pub_counts_list), default=0)
        label_pad = 0.01 * max_count

        _, ax = plt.subplots(figsize=(10, 6))

        # Plot bars for each PUB in this group
        for k, counts in enumerate(pub_counts_list):
            values = [counts.get(bitstring, 0) for bitstring in sorted_bitstrings]
            offset = (k - num_pubs_in_group / 2 + 0.5) * bar_width
            bars = ax.bar(x + offset, values, bar_width,
                          label=f'Circuit {k}', color=palette[k % len(palette)])

            # Add count labels above each non-empty bar
            for bar, val in zip(bars, values):
                if val > 0:
                    ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + label_pad,
                            str(val), ha='center', va='bottom', fontsize=8, rotation=90)

        ax.set_xlabel('Bitstring')
        ax.set_ylabel('Counts')
        ax.set_title(f'Comparison of Results for Classical Register Index {q_idx}')
        ax.set_xticks(x)
        ax.set_xticklabels(sorted_bitstrings, rotation=45)
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')

        plt.tight_layout()
        plt.show()

        plot_counter += 1

    if backend_name:
        print(f"\nReal hardware run on {backend_name}")

    return results_by_qubit
def counts_to_probabilities(
    counts_input: Dict[str, int] | List[Dict[str, int]],
    max_len: int = None
) -> Dict[str, float]:
    """
    Convert count dictionary/dictionaries to a single probability dictionary.

    This function takes either a single dictionary or a list of dictionaries containing
    bit-string counts, aggregates all counts, and normalizes them to probabilities.
    Probabilities for shorter strings are computed by marginalizing (summing) over
    the longer strings that extend them.

    Args:
        counts_input: Either a single dict mapping bit-strings to counts, or a list of such dicts.
        max_len: Maximum length of binary strings to include. If None, inferred from input data.

    Returns:
        Dictionary mapping all binary strings (up to max_len) to their probabilities.
        Shorter strings' probabilities are computed by summing over extensions.

    Example:
        >>> counts1 = {'00': 80, '01': 20}
        >>> counts2 = {'00': 70, '10': 30}
        >>> counts_to_probabilities([counts1, counts2])
        {'0': 0.75, '1': 0.25, '00': 0.75, '01': 0.1, '10': 0.15, '11': 0.0}
    """
    # Normalize input to list of dictionaries
    if isinstance(counts_input, dict):
        counts_list = [counts_input]
    else:
        counts_list = counts_input

    if not counts_list:
        return {}

    # Determine max_len if not provided
    if max_len is None:
        max_len = max(len(bitstring) for counts_dict in counts_list 
                     for bitstring in counts_dict.keys())

    # Aggregate all counts
    aggregated_counts = {}
    total_counts = 0

    for counts_dict in counts_list:
        for bitstring, count in counts_dict.items():
            if bitstring in aggregated_counts:
                aggregated_counts[bitstring] += count
            else:
                aggregated_counts[bitstring] = count
            total_counts += count

    # Initialize probability dictionary
    prob_dict = {}

    # First, normalize the max_len strings
    for bitstring in get_all_binary_strings(max_len):
        if bitstring in aggregated_counts:
            prob_dict[bitstring] = aggregated_counts[bitstring] / total_counts
        else:
            prob_dict[bitstring] = 0.0

    # Now compute probabilities for shorter strings by marginalization
    for length in range(max_len - 1, 0, -1):
        for bitstring in get_all_binary_strings(length):
            # Sum probabilities of all extensions (bitstring + '0' and bitstring + '1')
            prob = 0.0
            for bit in ['0', '1']:
                extended = bitstring + bit
                if extended in prob_dict:
                    prob += prob_dict[extended]
            prob_dict[bitstring] = prob

    return prob_dict
def compare_instrument_to_empirical(
    instrument: Instrument2x2,
    empirical_probs: Dict[str, float],
    v0: np.ndarray | None = None,
    max_len: int = 4
) -> Tuple[Dict[str, float], float, pd.DataFrame]:
    """
    Calculate theoretical probabilities from an instrument and compare to empirical data.

    Args:
        instrument: Instrument2x2 object to evaluate
        empirical_probs: Dictionary of empirical probabilities from hardware
        v0: Initial state distribution. None (the default) means the maximally
            mixed state [0.5, 0.5]. A fresh array is built on every call, so no
            array is shared between calls.
        max_len: Maximum length of binary strings to consider

    Returns:
        Tuple of:
            - Dictionary of theoretical probabilities
            - Overall RMSE between theoretical and empirical, over the strings
              present in both dictionaries (NaN if there are none)
            - DataFrame with detailed comparison
    """
    # Avoid a shared mutable default array
    if v0 is None:
        v0 = np.array([0.5, 0.5])

    # Calculate theoretical probabilities
    theoretical_probs = calculate_exact_all_string_probabilities_from_v0_and_instrument(
        inst=instrument,
        v0=v0,
        max_len=max_len
    )

    # Get common keys (strings present in both dicts)
    common_keys = sorted(set(theoretical_probs.keys()) & set(empirical_probs.keys()))

    # Calculate differences
    comparison_data = []
    squared_errors = []

    for key in common_keys:
        theo_val = theoretical_probs[key]
        emp_val = empirical_probs[key]
        diff = theo_val - emp_val
        abs_diff = abs(diff)
        rel_error = abs_diff / emp_val if emp_val != 0 else np.inf

        squared_errors.append(diff**2)

        comparison_data.append({
            'String': key,
            'Length': len(key),
            'Theoretical': f"{theo_val:.8f}",
            'Empirical (ibm)': f"{emp_val:.8f}",
            'Difference': f"{diff:+.6e}",
            'Abs Diff': f"{abs_diff:.6e}",
            'Rel Error': f"{rel_error:.4%}" if rel_error != np.inf else "inf"
        })

    # RMSE; explicit NaN when nothing overlaps (np.mean([]) would warn)
    rmse = np.sqrt(np.mean(squared_errors)) if squared_errors else np.nan

    # Create DataFrame
    df_comparison = pd.DataFrame(comparison_data)

    return theoretical_probs, rmse, df_comparison
def display_job_details(job):
    """
    Print a short summary of a Qiskit Runtime job.

    Job ID, backend, status and creation time are always printed. Tags and
    execution metadata are printed on a best-effort basis. Each falls back to
    "Not available" independently, so a failure to fetch one neither hides nor
    mislabels the other.

    Args:
        job: Qiskit Runtime Job object. Fetching the metadata calls
            ``job.result()``; NOTE(review): for a still-running job this
            presumably waits for completion — confirm.
    """
    print(f"Job ID: {job.job_id()}")
    print(f"Backend: {job.backend()}")
    print(f"Status: {job.status()}")
    print(f"Creation Time: {job.creation_date}")

    # Best effort: tags may be missing on some job objects.
    try:
        tags = job.tags
    except Exception:
        print("Tags: Not available")
    else:
        print(f"Tags: {tags}")

    # Best effort: result() can fail (e.g. failed/cancelled jobs) or lack 'execution'.
    try:
        execution_metadata = job.result().metadata['execution']
    except Exception:
        print("Metadata: Not available")
    else:
        print(f"Metadata: {execution_metadata}")

When you initialize the Sampler, use the mode parameter to specify the mode you want it to run in. Possible values are batch, session, or backend objects for batch, session, and job execution mode, respectively. For more information, see Introduction to Qiskit Runtime execution modes. Note that Open Plan users cannot submit session jobs.

Create Jobs

TODO: Change the I and Z circuits according to the instructions in https://quantum.cloud.ibm.com/docs/en/api/qiskit/qiskit.synthesis.OneQubitEulerDecomposer

# Sampler options: number of shots per circuit (PUB).
options = {
    "default_shots": 1_000_000,
    # "simulator": {"seed_simulator": 42},
}

# Number of MCM rounds per qubit in each circuit.
max_word_length = 10

# Physical device qubits used for the parallel run (passed as the transpiler's initial layout).
chosen_qubits = [0, 5, 10, 15, 56, 64, 69, 74, 103, 117]
num_qubits = len(chosen_qubits)

# Learning circuits: local-test and QPU (MCM) variants.
circuits_for_local, circuits_for_QPU = create_list_of_circuits(
    max_word_length, num_qubits, visualize_first=0
)

# Accuracy-test circuits: repeated X + MCM on all qubits in parallel.
accuracy_circuits_local, accuracy_circuits_QPU = create_non_trivial_list_of_circuits(
    max_word_length=max_word_length,
    num_qubits=num_qubits,
    family_names=['repeated-X-MCM'],
)
1
2
Visualizing circuit for local test (first 0):
Visualizing circuit for QPU (first 0):

Run Jobs

ibm_QPU.png

# Alternative: let the service pick the least busy device.
# backend = service.least_busy(operational=True, simulator=False, min_num_qubits=1)

# Candidate devices (swap the name below to switch):
#   "ibm_torino"      -- Heron r1
#   "ibm_fez"         -- Heron r2, does not support MCM
#   "ibm_marrakesh"   -- Heron r2, many jobs pending
#   "ibm_pittsburgh"  -- Heron r3, supports MCM, but long queue
backend = service.backend("ibm_kingston")  # Heron r2, supports MCM

# backend = FakeFez()  # for testing purposes only

# Transpile without optimization, pinning the logical qubits to chosen_qubits.
pm = generate_preset_pass_manager(
    backend=backend,
    optimization_level=0,
    initial_layout=chosen_qubits,
)

# PUB order: learning circuits first, then the accuracy-test circuits.
# The *_QPU variants use MidCircuitMeasure (needs a QPU with MCM support);
# use circuits_for_local + accuracy_circuits_local (plain Measure) otherwise:
# circuits_after_transpile = [pm.run(qc) for qc in (*circuits_for_local, *accuracy_circuits_local)]
circuits_after_transpile = [pm.run(qc) for qc in (*circuits_for_QPU, *accuracy_circuits_QPU)]

# # take a look at the transpiled circuits to run on QPU
# for circ in circuits_after_transpile:
#     circ.draw('mpl', style="iqp")
# 🔴 !! Only run this cell when ready to submit to real hardware
# Set RUN_ON_HARDWARE = True to submit. It is False by default so that re-running
# the notebook never spends QPU time by accident.
RUN_ON_HARDWARE = False

if RUN_ON_HARDWARE:

    print(f"🖥️ Running on {backend.name},\non qubit(s) with index: {chosen_qubits},\na total of {len(circuits_after_transpile)} circuits,\neach circuit with {options['default_shots']} shots.")

    # Initialize the Sampler in job execution mode (mode=backend) and run all circuits as PUBs
    sampler = Sampler(mode=backend, options=options)
    job = sampler.run(circuits_after_transpile)
    job_result = job.result()  # waits for the job to finish

    print("done 😃😃😃.")

Retrieve jobs

jobs_list = [
    'd4ekjg8lslhc73cuq4f0', # good for demo
    'd4el3rkcdebc73ev7eh0', # demo for low fidelity MCM
    'd4f1u5h2bisc73a22400', # margin_tol = 3e-4 might be tight enough
    'd4f21892bisc73a22740', # margin_tol = 5e-5 might be tight enough
    'd4f23rolslhc73cv8fhg', # margin_tol = ? (TODO: value not noted)
    'd4f60fh2bisc73a268ag', # margin_tol = 0.0000126 will give a single-point valid t region
    'd4f78fglslhc73cvdk8g', # margin_tol = 0,
    'd4f7g0glslhc73cvdsk0', # margin_tol = 0, a perfect reconstruction

    'd4f8vbccdebc73evs2r0', # margin_tol = 0.0000037, Kingston qubit 105, trial 1, MEASURE_2 error from ibm: 3.418e-3
    'd4f90aulo8as739oe4qg', # margin_tol = 0.00027, Kingston qubit 105, trial 2, seems bad and inconsistent with trial 1, maybe I should let it rest

    'd4f9570lslhc73cvffeg', # margin_tol = 0.00008, forgot which qubit

    'd4fb9692bisc73a2bhd0', # 1e6 shots for each of the 4 circuits on q48, perfect reconstruction with 0 margin_tol
    'd4fcer6lo8as739ohimg', # 1e5 shots for each of the 4 circuits on q48, perfect reconstruction with 0 margin_tol (trial 1)
    'd4fchbulo8as739ohkug', # 1e5 shots for each of the 4 circuits on q48, margin_tol = 0.00014 (trial 2)

    'd4fk7kccdebc73f08fqg', # 1e5 shots for each of the 4 circuits, on qubit #4, perfect reconstruction with 0 margin_tol
    'd4ii9iiv0j9c73e2cd3g', # 1e5 shots for each of the 4 circuits, on qubit #4, perfect reconstruction with 0 margin_tol (trial 2)

    # ==== Parallel 10-qubit runs ===

    'd4jkbls3tdfc73dnj9u0', # 10-parallel learning, 1e6 shots, ibm_kingston, max_word_length=6
    'd4jkdt2v0j9c73e3erf0', # 10-parallel learning, 1e6 shots, ibm_pittsburgh, max_word_length=6
    'd4jkm3l74pkc738653sg', # 10-parallel learning, 1e6 shots, ibm_kingston, max_word_length=10
    'd4jkmi574pkc738654a0', # 10-parallel learning, 1e6 shots, ibm_pittsburgh, max_word_length=10
]

# Job to retrieve. Any string of 5 characters or fewer (e.g. "" or "skip") skips
# retrieval and keeps the existing `job_result` from this session.
i_wanna_retrieve_job_with_id = "d4jkm3l74pkc738653sg"
max_word_length = 10 # must match the max_word_length the retrieved job was built with (differs between jobs)

if len(i_wanna_retrieve_job_with_id) > 5:
    # NOTE(review): the instance CRN is hard-coded; consider loading it from a saved account instead.
    service = QiskitRuntimeService(
        channel='ibm_quantum_platform',
        instance='crn:v1:bluemix:public:quantum-computing:us-east:a/ed5d7d2fb3b249c6baea6864058116cb:e46b4031-d4ee-48a6-88cd-1ae20a49ec21::'
    )
    job = service.job(i_wanna_retrieve_job_with_id)
    job_result = job.result()
    print(f"Retrieved job result for job ID: {i_wanna_retrieve_job_with_id}\n")
    display_job_details(job)
else:
    print("No job ID specified for retrieval; proceeding with existing job_result (if it exists).")
1
2
3
4
5
6
7
8
Retrieved job result for job ID: d4jkm3l74pkc738653sg

Job ID: d4jkm3l74pkc738653sg
Backend: <IBMBackend('ibm_kingston')>
Status: DONE
Creation Time: 2025-11-26 12:46:38.452118-06:00
Tags: ['10-parallel learning', '1e6 shots', 'max length 10']
Metadata: {'execution_spans': ExecutionSpans([DoubleSliceSpan(<start='2025-11-26 23:52:42', stop='2025-11-26 23:59:42', size=1000000>), DoubleSliceSpan(<start='2025-11-27 00:00:09', stop='2025-11-27 00:07:08', size=1000000>), DoubleSliceSpan(<start='2025-11-27 00:07:41', stop='2025-11-27 00:14:40', size=1000000>), DoubleSliceSpan(<start='2025-11-27 00:15:41', stop='2025-11-27 00:22:40', size=1000000>), DoubleSliceSpan(<start='2025-11-27 00:23:44', stop='2025-11-27 00:30:44', size=1000000>), DoubleSliceSpan(<start='2025-11-27 00:31:15', stop='2025-11-27 00:38:14', size=1000000>), DoubleSliceSpan(<start='2025-11-27 00:38:41', stop='2025-11-27 00:45:41', size=1000000>), DoubleSliceSpan(<start='2025-11-27 00:46:36', stop='2025-11-27 00:53:40', size=1000000>), DoubleSliceSpan(<start='2025-11-27 00:54:41', stop='2025-11-27 01:01:40', size=1000000>), DoubleSliceSpan(<start='2025-11-27 01:02:35', stop='2025-11-27 01:09:34', size=1000000>)])}

Additional Tasks

  1. Obtain the framework consistency/integrity (%) -> done.

  2. Use the learned MCM to bound the state-preparation (SP) error.

  3. Apply QEM using the learned SP error model (which includes the initial SP error and the MCM back-action) and the measurement (M) error model, then compare against IBM's state-of-the-art QEM performance.

  4. Applications:

  5. Superdense Coding: https://quantum.cloud.ibm.com/learning/en/courses/basics-of-quantum-information/entanglement-in-action/qiskit-implementation#superdense-coding

  6. Long-range entanglement with dynamic circuits: https://quantum.cloud.ibm.com/docs/en/tutorials/long-range-entanglement?utm_source=chatgpt.com


References:

  • Blog post announcing MCM availability (Nov 19): https://www.ibm.com/quantum/blog/utility-scale-dynamic-circuits

0. Utilities for Additional Tasks

# for [2. Obtain SP error bounds from learned MCM]

# pauli_labels must match the order used in create_list_of_circuits
PAULI_LABELS = ("I", "X", "Y", "Z")

def extract_counts_by_qubit_and_pauli(
    job_result,
    num_qubits: int,
    pauli_labels: Tuple[str, ...] = PAULI_LABELS,
) -> Dict[Tuple[int, str], Dict[str, int]]:
    """
    From a SamplerV2 PrimitiveResult, reconstruct:
        (qubit_index, Pauli_label) -> counts over repeated-MCM bitstrings.

    Supports both:
        1. Parallel Mode (New): 4 PUBs total (I, X, Y, Z), each containing data for ALL qubits.
        2. Sequential Mode (Old): (num_qubits * 4) PUBs, grouped by qubit.

    Returns
    -------
    result_counts : dict[(int,str), dict[str,int]]
        Example key: (0, "X") -> {'000000': ..., '111000': ..., ...}
    """
    num_pubs = len(job_result)
    num_paulis = len(pauli_labels)

    result_counts: Dict[Tuple[int, str], Dict[str, int]] = {}

    # --- Case 1: Parallel Mode ---
    # If we have exactly one PUB per Pauli label (e.g., 4 PUBs), we assume parallel execution.
    if num_pubs == num_paulis:
        for i, pub in enumerate(job_result):
            P = pauli_labels[i]  # The i-th PUB corresponds to the i-th Pauli applied to ALL qubits

            # Iterate over all qubits to extract their specific counts from this single PUB
            for q in range(num_qubits):
                creg_name = f'c{q}'

                # In SamplerV2, data attributes usually match register names (c0, c1, ...)
                if hasattr(pub.data, creg_name):
                    data_attr = getattr(pub.data, creg_name)
                    counts = data_attr.get_counts()
                    result_counts[(q, P)] = dict(counts)
                else:
                    # Fallback: access by index if attribute name lookup fails
                    # We assume the registers were added in order c0, c1, ...
                    keys = list(pub.data.keys())
                    if q < len(keys):
                        key = keys[q]
                        counts = pub.data[key].get_counts()
                        result_counts[(q, P)] = dict(counts)
                    else:
                        print(f"Warning: Could not find register for qubit {q} in PUB {i} (Pauli {P})")

    # --- Case 2: Sequential Mode (Legacy) ---
    # If we have (num_qubits * num_paulis) PUBs, we assume the old sequential ordering.
    elif num_pubs == num_qubits * num_paulis:
        for i, pub in enumerate(job_result):
            q = i // num_paulis          # which qubit this circuit targets
            p_idx = i % num_paulis       # which Pauli in PAULI_LABELS
            P = pauli_labels[p_idx]

            # pub.data keys are in the order of classical registers (c0, c1, ...)
            keys = list(pub.data.keys())
            if q >= len(keys):
                raise RuntimeError(
                    f"For PUB {i}, expected at least {q+1} classical subsets, got {len(keys)}"
                )

            # In the sequential construction, all registers existed, but we targeted the q-th one
            creg_key = keys[q]
            counts = pub.data[creg_key].get_counts()

            result_counts[(q, P)] = dict(counts)

    else:
        raise ValueError(
            f"Unexpected number of PUBs: {num_pubs}. "
            f"Expected {num_paulis} (Parallel) or {num_qubits * num_paulis} (Sequential)."
        )

    return result_counts
# for [2. Obtain SP error bounds from learned MCM]

def get_prob_meas_0_from_counts(
        result_counts: Dict[Tuple[int, str], Dict[str, int]],
        qubit_index: int,
        filter_paulis: Tuple[str, ...] = ("I", "Z"),
        ) -> float:
    """
    Return the pooled probability that the left-most bit read out for a qubit is '0'.

    The counts of every Pauli preparation in ``filter_paulis`` (by default 'I'
    and 'Z') for ``qubit_index`` are pooled and passed to
    ``counts_to_probabilities``. Its marginal for the one-character prefix '0'
    is returned. NOTE(review): whether the left-most bit is the first MCM
    depends on how the circuits map measurements to clbits — confirm against
    create_list_of_circuits.

    Both the '0' and '1' marginals are printed.

    Args:
        result_counts: Mapping (qubit_index, Pauli_label) -> counts, as produced
            by ``extract_counts_by_qubit_and_pauli``.
        qubit_index: Qubit whose data is used.
        filter_paulis: Pauli labels whose counts are pooled.

    Returns:
        The pooled probability of prefix '0'. This is 0.0 when no counts match,
        and a warning is printed in that case.
    """
    # Keep only this qubit's data for the selected Pauli preparations
    counts_list = [
        counts
        for (q, pauli), counts in result_counts.items()
        if q == qubit_index and pauli in filter_paulis
    ]
    if not counts_list:
        print(f"Warning: no counts found for qubit {qubit_index} with Pauli(s) {filter_paulis}; "
              f"the probabilities below are placeholders, not data.")

    # Convert the pooled counts into a probability dict (with prefix marginals)
    filtered_ibm_prob_list = counts_to_probabilities(counts_list)

    prob_0 = filtered_ibm_prob_list.get('0', 0.0)

    print(f"Qubit {qubit_index}: prob of reading out 0 (noisy state prep): {prob_0:.5e}")
    print(f"Qubit {qubit_index}: prob of reading out 1 (noisy state prep): {filtered_ibm_prob_list.get('1', 0.0):.5e}")

    return prob_0
# for [2. Obtain SP error bounds from learned MCM]

def calculate_and_display_refined_intervals(
    reconstructed_MCM: Instrument2x2 | Any,
    refined_t_bound: Tuple[float, float],
    resolution: int = 1000
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Calculate and display min/max intervals of an MCM instrument over a gauge range.

    The instrument's ``M0``/``M1`` matrices are gauge-transformed with
    ``gauge_transform_instrument_numerically`` at sampled values of ``t`` in
    ``refined_t_bound``. The min/max over those samples of each matrix entry,
    and of four derived error quantities, are displayed and returned.

    Args:
        reconstructed_MCM: The instrument at the center of the gauge (t=0). It
            must expose ``M0`` and ``M1`` (2x2 per the entry labels below).
        refined_t_bound: The valid interval (t_min, t_max) for the gauge parameter.
        resolution: Number of evenly spaced samples in the interval. At least 2
            are always used for a non-degenerate interval, so both endpoints are
            included. An interval narrower than 1e-15 is sampled at t_min only.

    Returns:
        Tuple containing:
            - DataFrame for derived quantities intervals.
            - DataFrame for matrix entries intervals.

    Note:
        Extrema are taken over the samples only. If an entry is not monotonic
        in t, extrema between samples can be missed.
    """
    t_min, t_max = refined_t_bound

    # Handle the case where t_min == t_max (single point)
    if abs(t_max - t_min) < 1e-15:
        t_values = np.array([t_min])
    else:
        # resolution=1 would drop t_max and resolution<=0 would yield no samples
        t_values = np.linspace(t_min, t_max, max(int(resolution), 2))

    # Row per t value; columns = M0.flatten() (indices 0-3) then M1.flatten() (indices 4-7)
    transformed_entries = []
    for t_val in t_values:
        M0_prime, M1_prime = gauge_transform_instrument_numerically(
            reconstructed_MCM.M0, reconstructed_MCM.M1, t_val
        )
        transformed_entries.append(np.concatenate((M0_prime.flatten(), M1_prime.flatten())))
    transformed_entries = np.array(transformed_entries)

    # Format an interval as "(center ± half-width)e±XX" with a mantissa in [1, 10)
    def fmt_val_err(min_v, max_v):
        c = (min_v + max_v) / 2
        h = (max_v - min_v) / 2

        ref = abs(c) if abs(c) > 0 else abs(h)
        if ref == 0:
            return "(0.000 ± 0.000)e+0"

        exponent = int(np.floor(np.log10(ref)))
        # Rounding to 3 decimals can push the mantissa to 10.000 (e.g. 9.9996); renormalize
        if round(ref * 10.0 ** (-exponent), 3) >= 10:
            exponent += 1
        scale = 10.0 ** (-exponent)

        return f"({c * scale:.3f} ± {h * scale:.3f})e{exponent:+d}"

    def interval_row(name_key, name, block_values):
        min_val = np.min(block_values)
        max_val = np.max(block_values)
        return {
            name_key: name,
            'Min': f"{min_val:.8f}",
            'Max': f"{max_val:.8f}",
            'Width': f"{max_val - min_val:.5e}",
            'Value': fmt_val_err(min_val, max_val)
        }

    # ===== 1. Matrix Entries =====
    # Label order matches the flattened column order: [M]_{s',s} = p_s^(o,s')
    entry_labels = [
        "M^0[0,0] = p_0^(0,0)",
        "M^0[0,1] = p_1^(0,0)",
        "M^0[1,0] = p_0^(0,1)",
        "M^0[1,1] = p_1^(0,1)",
        "M^1[0,0] = p_0^(1,0)",
        "M^1[0,1] = p_1^(1,0)",
        "M^1[1,0] = p_0^(1,1)",
        "M^1[1,1] = p_1^(1,1)"
    ]
    entry_ranges_data = [
        interval_row('Entry', label, transformed_entries[:, idx])
        for idx, label in enumerate(entry_labels)
    ]

    # ===== 2. Derived Quantities =====
    derived_quantities = {
        # outcome 1 from pre-state 0: p_0^(1,0) + p_0^(1,1)
        "prep 0 meas 1": transformed_entries[:, 4] + transformed_entries[:, 6],
        # outcome 0 from pre-state 1: p_1^(0,0) + p_1^(0,1)
        "prep 1 meas 0": transformed_entries[:, 1] + transformed_entries[:, 3],
        # post-state 1 from pre-state 0: p_0^(0,1) + p_0^(1,1)
        "prep 0 excite to 1": transformed_entries[:, 2] + transformed_entries[:, 6],
        # post-state 0 from pre-state 1: p_1^(0,0) + p_1^(1,0)
        "prep 1 decay to 0": transformed_entries[:, 1] + transformed_entries[:, 5]
    }
    quantity_ranges_data = [
        interval_row('Quantity', name, values)
        for name, values in derived_quantities.items()
    ]

    # ===== Display DataFrames =====
    print("="*80)
    print(f"Refined Intervals over t ∈ [{t_min:.6e}, {t_max:.6e}]")
    print("="*80)

    print("\nDerived Quantities (Readout & Back-action Errors):")
    df_quantities = pd.DataFrame(quantity_ranges_data)
    display(df_quantities)

    print("\nMatrix Entries:")
    df_entries = pd.DataFrame(entry_ranges_data)
    display(df_entries)
    print("="*80 + "\n")

    return df_quantities, df_entries
# for [4. Accuracy test]

def predict_and_compare_mcm_readouts(
    epsilon_bound: Tuple[float, float],
    bounded_quantities: pd.DataFrame,
    bounded_entries: pd.DataFrame,
    measure_2_error: float,
    pauli_x_error: float,
    ibm_verify_prob_list: Dict[str, float],
    max_word_length: int = 6
):
    """
    Predicts the probability of measuring '0' at each step of a repeated X-MCM circuit
    using three models (Learned, IBM Simplified, Ideal) and compares them with
    empirical QPU data.

    The circuit consists of repeated cycles of: X gate -> Mid-Circuit Measurement.
    Each model only propagates p(1), the probability of the qubit being in |1>,
    so every prediction is the probability of outcome '0' at that step,
    marginalized over all earlier outcomes.

    The learned model uses interval arithmetic. Every update rule below is
    affine in each of its inputs separately, so its range over a box of input
    intervals is attained at the box corners; evaluating all corners gives
    exact bounds without assuming the signs of coefficients (e.g. px < 0.5).

    Args:
        epsilon_bound: Tuple (min, max) for the state preparation error epsilon
                       (the learned model's initial p(1)).
        bounded_quantities: DataFrame with 'Quantity', 'Min', 'Max' columns holding
                        bounds for the MCM properties "prep 0 meas 1", "prep 1 meas 0",
                        "prep 0 excite to 1" and "prep 1 decay to 0".
        bounded_entries: DataFrame containing min/max values for MCM entries (unused;
                        kept for interface compatibility).
        measure_2_error: The single scalar error rate provided by IBM for MCM readout.
        pauli_x_error: The single scalar error rate for X gate (depolarizing-like).
        ibm_verify_prob_list: Dictionary mapping outcome strings to empirical
                        probabilities from the QPU.
        max_word_length: Number of MCMs in the sequence.

    Returns:
        pd.DataFrame with one row per MCM step and columns 'MCM Index', 'Ideal',
        'IBM Simple', 'Learned Min', 'Learned Max', 'Empirical', plus the
        model-minus-empirical columns 'Diff Ideal', 'Diff IBM',
        'Diff Learned Min' and 'Diff Learned Max'.

    Side effects:
        Shows two matplotlib figures and displays a comparison table.
    """
    import itertools

    # --- Helpers ---
    def get_interval(name):
        """Return the (min, max) bounds of quantity `name` as floats."""
        row = bounded_quantities[bounded_quantities['Quantity'] == name].iloc[0]
        return float(row['Min']), float(row['Max'])

    def corner_range(func, *intervals):
        """Return (min, max) of `func` evaluated at every corner of the box spanned by `intervals`."""
        values = [func(*corner) for corner in itertools.product(*intervals)]
        return min(values), max(values)

    # --- 1. Extract Learned Model Parameters (Intervals) ---
    # Readout errors
    p0_m1_int = get_interval("prep 0 meas 1")
    p1_m0_int = get_interval("prep 1 meas 0")

    # Back-action (state update) probabilities
    p0_excite_int = get_interval("prep 0 excite to 1")
    p1_decay_int = get_interval("prep 1 decay to 0")

    px = pauli_x_error

    # --- 2. Initialize p(1) for all 3 models ---
    # A. Learned Model: noisy |0>, v0 = [1-eps, eps] -> p(1) = eps (interval)
    current_p1_interval = tuple(epsilon_bound)
    # B. IBM Simplified Model: perfect |0> -> p(1) = 0
    current_p1_ibm = 0.0
    # C. Ideal Model: perfect |0> -> p(1) = 0
    current_p1_ideal = 0.0

    predictions = []

    # --- 3. Simulate the Circuit Step-by-Step: Repeat(X -> Measure) ---
    for step in range(1, max_word_length + 1):
        # --- Step 3a: Apply X Gate ---
        # v_after = [(1 - px) X + px I] v_before
        # p1_new = (1 - px)(1 - p1) + px * p1 = (1 - px) + p1 * (2*px - 1)
        current_p1_interval = corner_range(
            lambda p1: (1 - px) + p1 * (2 * px - 1),
            current_p1_interval,
        )
        current_p1_ibm = (1 - px) + current_p1_ibm * (2 * px - 1)
        current_p1_ideal = 1.0 - current_p1_ideal  # perfect X

        # --- Step 3b: Probability of reading out '0' ---
        # P(r=0) = (1 - p0m1)(1 - p1) + p1m0 * p1
        pred_learned = corner_range(
            lambda p1, e01, e10: (1 - e01) * (1 - p1) + e10 * p1,
            current_p1_interval, p0_m1_int, p1_m0_int,
        )
        # IBM: a single symmetric readout error for both states
        pred_ibm = (1 - measure_2_error) * (1 - current_p1_ibm) + measure_2_error * current_p1_ibm
        pred_ideal = 1.0 - current_p1_ideal

        # --- Step 3c: Apply Back-action (State Update) ---
        # p1_new = u01 (1 - p1) + (1 - u10) p1
        current_p1_interval = corner_range(
            lambda p1, u01, u10: u01 * (1 - p1) + (1 - u10) * p1,
            current_p1_interval, p0_excite_int, p1_decay_int,
        )
        # IBM and Ideal models: no back-action modeled (identity) -> p(1) unchanged.

        # --- Step 3d: Empirical P(r=0) at this step ---
        # Sum probabilities of all strings of length `step` ending in '0'
        # (assumes the last character is the outcome of the step-th MCM).
        emp_prob_0 = sum(
            (prob for s, prob in ibm_verify_prob_list.items()
             if len(s) == step and s.endswith('0')),
            0.0,
        )

        predictions.append({
            "MCM Index": step,
            "Ideal": pred_ideal,
            "IBM Simple": pred_ibm,
            "Learned Min": pred_learned[0],
            "Learned Max": pred_learned[1],
            "Empirical": emp_prob_0
        })

    # --- 4. Visualization ---
    df_res = pd.DataFrame(predictions)

    # Calculate differences relative to Empirical
    df_res["Diff Ideal"] = df_res["Ideal"] - df_res["Empirical"]
    df_res["Diff IBM"] = df_res["IBM Simple"] - df_res["Empirical"]
    df_res["Diff Learned Min"] = df_res["Learned Min"] - df_res["Empirical"]
    df_res["Diff Learned Max"] = df_res["Learned Max"] - df_res["Empirical"]

    # LaTeX backslashes are escaped explicitly: '\e' / '\i' are invalid Python
    # escape sequences (SyntaxWarning today, an error in future Python versions).
    eps_caption = f'(State Prep Error $\\epsilon \\in [{epsilon_bound[0]:.1e}, {epsilon_bound[1]:.1e}]$)'

    # --- Plot 1: Absolute Probabilities ---
    plt.figure(figsize=(12, 6))

    # Plot IBM Simple Model
    plt.plot(df_res["MCM Index"], df_res["IBM Simple"], 'b--o', label="from IBM's calibration data")

    # Plot Ideal Model
    plt.plot(df_res["MCM Index"], df_res["Ideal"], 'k:', alpha=0.75, label='Ideal')

    # Plot Empirical Data
    plt.plot(df_res["MCM Index"], df_res["Empirical"], 'r-x', linewidth=0.5, label='Empirical Data')

    # Plot Learned Interval (Shaded Region)
    plt.fill_between(df_res["MCM Index"], df_res["Learned Min"], df_res["Learned Max"],
                     color='green', alpha=0.75, label='Our protocol (as interval)')

    plt.xlabel('MCM Index (Step in Circuit)')
    plt.ylabel('Probability of Measuring 0')
    plt.title('Repeated X-MCM: Absolute Probabilities (P(0))\n' + eps_caption)
    plt.xticks(df_res["MCM Index"])
    plt.grid(True, alpha=0.3, which='both')
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.show()

    # --- Plot 2: Deviations from Empirical ---
    plt.figure(figsize=(12, 6))

    # Plot IBM Simple Model Deviation
    plt.plot(df_res["MCM Index"], df_res["Diff IBM"], 'b--o', label="from IBM's calibration data")

    # Plot Ideal Model Deviation
    plt.plot(df_res["MCM Index"], df_res["Diff Ideal"], 'k:', alpha=0.5, label='Ideal Deviation')

    # Plot Empirical Data Baseline (at y=0)
    plt.axhline(y=0, color='r', linestyle='-', linewidth=0.5, label='Empirical Baseline (0)')

    # Plot Learned Interval (Shaded Region) relative to Empirical (which is at 0)
    plt.fill_between(df_res["MCM Index"], df_res["Diff Learned Min"], df_res["Diff Learned Max"],
                     color='green', alpha=0.5, label='Our protocol (as interval)')

    plt.xlabel('MCM Index (Step in Circuit)')
    plt.ylabel('Difference from Empirical Probability (P(0))')
    plt.title('Repeated X-MCM: Model Deviations from Empirical Data\n' + eps_caption)
    plt.xticks(df_res["MCM Index"])
    plt.grid(True, alpha=0.3, which='both')
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.show()

    # Display Table
    print("\nDetailed Predictions vs Empirical:")
    # Format columns for display
    df_disp = df_res.copy()

    # Calculate deviations for table
    df_disp['Learned Mid'] = (df_disp['Learned Min'] + df_disp['Learned Max']) / 2
    df_disp['Dev IBM'] = df_disp['IBM Simple'] - df_disp['Empirical']
    df_disp['Dev Learned Mid'] = df_disp['Learned Mid'] - df_disp['Empirical']

    # Calculate ratio of absolute errors (IBM Error / Learned Error)
    # A value > 1 means IBM error is larger (Learned is better)
    df_disp['Error Ratio (IBM/Learned)'] = df_disp['Dev IBM'].abs() / df_disp['Dev Learned Mid'].abs()

    df_disp['Learned Interval'] = df_disp.apply(lambda x: f"[{x['Learned Min']:.4f}, {x['Learned Max']:.4f}]", axis=1)
    df_disp['In Interval?'] = df_disp.apply(lambda x: "✅" if x['Learned Min'] <= x['Empirical'] <= x['Learned Max'] else "❌", axis=1)

    cols = ['MCM Index', 'Ideal', 'IBM Simple', 'Learned Interval', 'Empirical', 'In Interval?', 'Dev IBM', 'Dev Learned Mid', 'Error Ratio (IBM/Learned)']
    display(df_disp[cols])

    return df_res
1
2
3
4
5
6
7
8
<>:186: SyntaxWarning: "\e" is an invalid escape sequence. Such sequences will not work in the future. Did you mean "\\e"? A raw string is also an option.
<>:211: SyntaxWarning: "\e" is an invalid escape sequence. Such sequences will not work in the future. Did you mean "\\e"? A raw string is also an option.
<>:186: SyntaxWarning: "\e" is an invalid escape sequence. Such sequences will not work in the future. Did you mean "\\e"? A raw string is also an option.
<>:211: SyntaxWarning: "\e" is an invalid escape sequence. Such sequences will not work in the future. Did you mean "\\e"? A raw string is also an option.
/var/folders/dk/_dd6ng8n3yv_crq6vrdt8qzc0000gn/T/ipykernel_57129/3470257201.py:186: SyntaxWarning: "\e" is an invalid escape sequence. Such sequences will not work in the future. Did you mean "\\e"? A raw string is also an option.
  plt.title(f'Repeated X-MCM: Absolute Probabilities (P(0))\n(State Prep Error $\epsilon \in [{epsilon_bound[0]:.1e}, {epsilon_bound[1]:.1e}]$)')
/var/folders/dk/_dd6ng8n3yv_crq6vrdt8qzc0000gn/T/ipykernel_57129/3470257201.py:211: SyntaxWarning: "\e" is an invalid escape sequence. Such sequences will not work in the future. Did you mean "\\e"? A raw string is also an option.
  plt.title(f'Repeated X-MCM: Model Deviations from Empirical Data\n(State Prep Error $\epsilon \in [{epsilon_bound[0]:.1e}, {epsilon_bound[1]:.1e}]$)')

😃⏱️ Reconstruction (choose which qubit to analyze)

# 🔔‼️ don't need to run this cell if just changing `which_qubit_to_analyze` below but not changing the job itself
# because this cell takes ~ 1min to run on average
#
# Loads the retrieved `job`, switches `backend` to the one it ran on, and builds the per-qubit
# lists that later cells index with `which_qubit_to_analyze`.

display_job_details(job)
backend = job.backend() # switch to the backend used in the retrieved job

# First 4 entries of job_result -> per-qubit counts used to learn the MCM ("Reconstruction" cell).
counting_lists = plot_job_results(job_result[:4], backend_name=backend.name, num_qubits=num_qubits, plot_first_n=1)
# Last entry of job_result -> per-qubit counts of the accuracy-verification circuit ("4. Accuracy test").
accuracy_verification_lists = plot_job_results(job_result[-1:], backend_name=backend.name, num_qubits=num_qubits, plot_first_n=1)

# for [2. Obtain SP error bounds from learned MCM]
# NOTE(review): this uses job_result[:-1] while counting_lists uses job_result[:4]; the two slices
# coincide only if job_result has exactly 5 entries — confirm this matches the submitted job.
result_counts = extract_counts_by_qubit_and_pauli(job_result[:-1], num_qubits=num_qubits) # have to take subset of job_result if needed to exclude accuracy verification circuits
1
2
3
4
5
6
Job ID: d4jkm3l74pkc738653sg
Backend: <IBMBackend('ibm_kingston')>
Status: DONE
Creation Time: 2025-11-26 12:46:38.452118-06:00
Tags: ['10-parallel learning', '1e6 shots', 'max length 10']
Metadata: {'execution_spans': ExecutionSpans([DoubleSliceSpan(<start='2025-11-26 23:52:42', stop='2025-11-26 23:59:42', size=1000000>), DoubleSliceSpan(<start='2025-11-27 00:00:09', stop='2025-11-27 00:07:08', size=1000000>), DoubleSliceSpan(<start='2025-11-27 00:07:41', stop='2025-11-27 00:14:40', size=1000000>), DoubleSliceSpan(<start='2025-11-27 00:15:41', stop='2025-11-27 00:22:40', size=1000000>), DoubleSliceSpan(<start='2025-11-27 00:23:44', stop='2025-11-27 00:30:44', size=1000000>), DoubleSliceSpan(<start='2025-11-27 00:31:15', stop='2025-11-27 00:38:14', size=1000000>), DoubleSliceSpan(<start='2025-11-27 00:38:41', stop='2025-11-27 00:45:41', size=1000000>), DoubleSliceSpan(<start='2025-11-27 00:46:36', stop='2025-11-27 00:53:40', size=1000000>), DoubleSliceSpan(<start='2025-11-27 00:54:41', stop='2025-11-27 01:01:40', size=1000000>), DoubleSliceSpan(<start='2025-11-27 01:02:35', stop='2025-11-27 01:09:34', size=1000000>)])}

svg

1
Real hardware run on ibm_kingston

svg

1
Real hardware run on ibm_kingston
1
2
3
4
which_qubit_to_analyze = 9  #📊 change this index to analyze different qubits from the chosen_qubits list, i.e., 0,1,2,3,...

# `which_qubit_to_analyze` is a position in `chosen_qubits`; the physical qubit on the backend
# is chosen_qubits[which_qubit_to_analyze].
print(f"Total number of chosen qubits: {len(chosen_qubits)}")
print(f"\nAnalyzing qubit index {which_qubit_to_analyze},\nwhich is the physical qubit on {backend.name} with index {chosen_qubits[which_qubit_to_analyze]}...")
1
2
3
4
Total number of chosen qubits: 10

Analyzing qubit index 9,
which is the physical qubit on ibm_kingston with index 117...
# Per-qubit counts for the qubit selected by `which_qubit_to_analyze` (change it in the cell above).
counting_list = counting_lists[which_qubit_to_analyze]

chosen_sol_index = 0  # Choose which solution to analyze (no real effect because the 2 solutions belong to the same gauge family)
print(f"\nAnalyzing results for qubit index {chosen_qubits[which_qubit_to_analyze]}:\n")

# Empirical outcome-string probabilities -> invariants used to reconstruct the instrument.
ibm_prob_list = counts_to_probabilities(counting_list)
# display(ibm_prob_list)

emp_inv = derived_constraints_from_empirical_probs(ibm_prob_list)
# print(emp_inv)

sols = reconstruct_instrument_from_invariants_mixed_det(
    trM0 = emp_inv[0],
    detM0 = emp_inv[1],
    trM1 = emp_inv[2],
    detM1 = emp_inv[3],
    S0 = emp_inv[6] * 2, # S0 = sum of all elements of M0 = 2 * probability of observing 0 from v0=(0.5,0.5)
    gauge_p00 = 1.0
)

print(f"\nthere are {len(sols)} solutions from empirical data,\nwe choose solution with index {chosen_sol_index}.\n")

# Fail with a clear message instead of an opaque IndexError on sols[chosen_sol_index].
if not sols:
    raise RuntimeError(
        "No instrument solutions could be reconstructed from the empirical invariants; "
        "check the counts for this qubit before continuing."
    )

chosen_sol = sols[chosen_sol_index]
reconstructed_MCM, t_bound = plot_gauge_transformation_effects(
    Instrument2x2(chosen_sol[0], chosen_sol[1]),
    t_width_factor=1.5,
    verbose=False,
    MCM_reference=[],
    resolution=1e5,
    render_plots=False,
    margin_tol=0.000
    )

print("\nReconstructed Instrument from invariants (empirical data):")
if reconstructed_MCM is not None:
    reconstructed_MCM.reveal()
else:
    print("Warning: plot_gauge_transformation_effects returned no instrument; cells below that use `reconstructed_MCM` may not work.")
1
2
3
4
5
6
7
8
9
Analyzing results for qubit index 117:


there are 2 solutions from empirical data,
we choose solution with index 0.

================================================================================
Valid Intervals for Derived Quantities (Readout & Back-action Errors)
================================================================================
Quantity Min Max Width Value
0 prep 0 meas 1 0.00005638 0.00022670 1.70321e-04 (1.415 ± 0.852)e-4
1 prep 1 meas 0 0.00050388 0.00067420 1.70321e-04 (5.890 ± 0.852)e-4
2 prep 0 excite to 1 0.00016797 0.00017037 2.39591e-06 (1.692 ± 0.012)e-4
3 prep 1 decay to 0 0.01422474 0.01422714 2.39591e-06 (1.423 ± 0.000)e-2
1
2
3
================================================================================
Valid Intervals for Matrix Entries
================================================================================
Entry Min Max Width Value
0 M^0[0,0] = p_0^(0,0) 0.99977326 0.99977330 4.61347e-08 (9.998 ± 0.000)e-1
1 M^0[0,1] = p_1^(0,0) 0.00027067 0.00044104 1.70367e-04 (3.559 ± 0.852)e-4
2 M^0[1,0] = p_0^(0,1) -0.00000000 0.00017037 1.70367e-04 (8.518 ± 8.518)e-5
3 M^0[1,1] = p_1^(0,1) 0.00023316 0.00023321 4.61347e-08 (2.332 ± 0.000)e-4
4 M^1[0,0] = p_0^(1,0) 0.00005638 0.00005873 2.34978e-06 (5.755 ± 0.117)e-5
5 M^1[0,1] = p_1^(1,0) 0.01378610 0.01395407 1.67971e-04 (1.387 ± 0.008)e-2
6 M^1[1,0] = p_0^(1,1) -0.00000000 0.00016797 1.67971e-04 (8.399 ± 8.399)e-5
7 M^1[1,1] = p_1^(1,1) 0.98553970 0.98554205 2.34978e-06 (9.855 ± 0.000)e-1
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
================================================================================


Reconstructed Instrument from invariants (empirical data):

MCM Instrument:

M0 matrix (outcome 0):
  [[p_0^(0,0), p_1^(0,0)]] = [[0.99977328, 0.00035585]]
  [[p_0^(0,1), p_1^(0,1)]]   [[0.00008518, 0.00023318]]

M1 matrix (outcome 1):
  [[p_0^(1,0), p_1^(1,0)]] = [[0.00005755, 0.01387008]]
  [[p_0^(1,1), p_1^(1,1)]]   [[0.00008399, 0.98554088]]

1. Obtain the framework consistency/integrity (%)

  • We can use the MCM reconstructed from length-3 probabilities to verify the probabilities of length-4 and longer strings, which gives the framework's consistency (%).

  • The RMSE for probability strings of length 3 should be of the same magnitude as the shot noise; for lengths > 3 the RMSE may increase slightly but should stay close to the shot-noise level.

  • RMSE will also be affected by leakage errors and single-qubit gate errors.

# Use the reconstructed MCM to compare with IBM data
if reconstructed_MCM is not None:
    theoretical_probs, rmse, df_comparison = compare_instrument_to_empirical(
        instrument=reconstructed_MCM,
        empirical_probs=ibm_prob_list,
        v0=np.array([0.5, 0.5]),
        max_len=max_word_length
    )

    print(f"Overall RMSE between reconstructed MCM and IBM data: {rmse:.6e}")
    print(f"\nNumber of strings compared: {len(df_comparison)}")
    print("\nDetailed comparison:")
    # sort df_comparison's rows first by string length (ascending), then lexicographically by string
    df_comparison = df_comparison.sort_values(by=['Length', 'String']).reset_index(drop=True)

    display(df_comparison)

    # Summary statistics by string length
    print("\n" + "="*80)
    print("RMSE by string length:")
    print("="*80)
    for length in range(1, max_word_length + 1):
        mask = df_comparison['Length'] == length
        if mask.any():
            # 'Difference' holds signed, formatted strings (e.g. "+7.256240e-09"); parse the
            # selected rows in one vectorized pass instead of re-filtering the DataFrame per row.
            length_diffs = (
                df_comparison.loc[mask, 'Difference']
                .astype(str)
                .str.replace('+', '', regex=False)
                .astype(float)
                .to_numpy()
            )
            length_rmse = np.sqrt(np.mean(length_diffs ** 2))
            print(f"Length {length}: RMSE = {length_rmse:.6e} ({mask.sum()} strings)")
1
2
3
4
5
Overall RMSE between reconstructed MCM and IBM data: 1.558673e-05

Number of strings compared: 2046

Detailed comparison:
String Length Theoretical Empirical (ibm) Difference Abs Diff Rel Error
0 0 1 0.50022375 0.50022375 +0.000000e+00 0.000000e+00 0.0000%
1 1 1 0.49977625 0.49977625 -2.775558e-16 2.775558e-16 0.0000%
2 00 2 0.49999388 0.49999388 +7.256240e-09 7.256240e-09 0.0000%
3 01 2 0.00022987 0.00022987 -7.256239e-09 7.256239e-09 0.0032%
4 10 2 0.00725312 0.00725313 -7.052202e-09 7.052202e-09 0.0001%
... ... ... ... ... ... ... ...
2041 1111111011 10 0.00010480 0.00014500 -4.020295e-05 4.020295e-05 27.7262%
2042 1111111100 10 0.00641989 0.00641913 +7.690554e-07 7.690554e-07 0.0120%
2043 1111111101 10 0.00010516 0.00017112 -6.596615e-05 6.596615e-05 38.5485%
2044 1111111110 10 0.00643071 0.00646600 -3.528592e-05 3.528592e-05 0.5457%
2045 1111111111 10 0.43835699 0.43812462 +2.323663e-04 2.323663e-04 0.0530%

2046 rows × 7 columns

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
================================================================================
RMSE by string length:
================================================================================
Length 1: RMSE = 1.962616e-16 (2 strings)
Length 2: RMSE = 7.154948e-09 (4 strings)
Length 3: RMSE = 3.079771e-05 (8 strings)
Length 4: RMSE = 3.119521e-05 (16 strings)
Length 5: RMSE = 3.019554e-05 (32 strings)
Length 6: RMSE = 2.759806e-05 (64 strings)
Length 7: RMSE = 2.208088e-05 (128 strings)
Length 8: RMSE = 1.807662e-05 (256 strings)
Length 9: RMSE = 1.495366e-05 (512 strings)
Length 10: RMSE = 1.150058e-05 (1024 strings)

2. Obtain SP error bounds from learned MCM

1
2
3
4
5
# Probability of reading out '0' after the (noisy) |0⟩ state preparation for the selected qubit,
# taken from `result_counts`. Note: qubit_index here is the position in chosen_qubits
# (which_qubit_to_analyze), not the physical qubit index.
prob_meas_0_val = get_prob_meas_0_from_counts(result_counts, qubit_index=which_qubit_to_analyze) 

print(f"\nAnalyzing results for qubit index {chosen_qubits[which_qubit_to_analyze]}:\n")
print(f"t_bound from gauge transformation analysis: [ {t_bound[0]:.5e}, {t_bound[1]:.5e} ]")
print(f"probability of measuring '0' from noisy state preparation: {prob_meas_0_val:.5e}")
1
2
3
4
5
6
7
Qubit 9: prob of reading out 0 (noisy state prep): 9.91684e-01
Qubit 9: prob of reading out 1 (noisy state prep): 8.31575e-03

Analyzing results for qubit index 117:

t_bound from gauge transformation analysis: [ -8.52299e-05, 8.52155e-05 ]
probability of measuring '0' from noisy state preparation: 9.91684e-01
# Refine the gauge-parameter bound `t_bound` using the measured probability of reading out '0'
# after noisy |0⟩ preparation; also yields the bound on the |0⟩ preparation error epsilon,
# which the accuracy test below uses as the learned model's initial p(1).
refined_t_bound, epsilon_bound = refine_bounds_with_sp_error(
    reconstructed_MCM,
    t_bound,
    prob_meas_0_val,
    margin_tol=0.00,
    verbose=True
)

print(f"\nRefined t bound: [ {refined_t_bound[0]:.5e}, {refined_t_bound[1]:.5e} ]")
print(f"😍 Estimated SP error bound for noisy |0⟩ preparation (epsilon):\n{epsilon_bound[0]:.4e}, {epsilon_bound[1]:.4e}") 
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
--- Refine Bounds Debug ---
P(outcome=1): 0.008316
Baseline P(1|0): 0.000142, P(1|1): 0.999411
Input t_bound: [-0.000085, 0.000085]
Epsilon range on input t: [0.008096, 0.008264]
Target epsilon range: [0.000000, 1.000000]
Refined t_bound: [-0.000085, 0.000085]
Refined epsilon: [0.008096, 0.008264]
---------------------------

Refined t bound: [ -8.52299e-05, 8.52155e-05 ]
😍 Estimated SP error bound for noisy |0⟩ preparation (epsilon):
8.0964e-03, 8.2640e-03
# Recompute the readout/back-action intervals and the matrix-entry intervals over the refined t bound;
# `bounded_quantities` and `bounded_entries` are passed to the accuracy test below.
bounded_quantities, bounded_entries = calculate_and_display_refined_intervals(reconstructed_MCM, refined_t_bound)
1
2
3
4
5
================================================================================
Refined Intervals over t ∈ [-8.522992e-05, 8.521550e-05]
================================================================================

Derived Quantities (Readout & Back-action Errors):
Quantity Min Max Width Value
0 prep 0 meas 1 0.00005637 0.00022669 1.70321e-04 (1.415 ± 0.852)e-4
1 prep 1 meas 0 0.00050387 0.00067419 1.70321e-04 (5.890 ± 0.852)e-4
2 prep 0 excite to 1 0.00016797 0.00017037 2.39591e-06 (1.692 ± 0.012)e-4
3 prep 1 decay to 0 0.01422474 0.01422714 2.39591e-06 (1.423 ± 0.000)e-2
1
Matrix Entries:
Entry Min Max Width Value
0 M^0[0,0] = p_0^(0,0) 0.99977326 0.99977331 4.61347e-08 (9.998 ± 0.000)e-1
1 M^0[0,1] = p_1^(0,0) 0.00027067 0.00044104 1.70367e-04 (3.559 ± 0.852)e-4
2 M^0[1,0] = p_0^(0,1) -0.00000000 0.00017037 1.70367e-04 (8.518 ± 8.518)e-5
3 M^0[1,1] = p_1^(0,1) 0.00023315 0.00023320 4.61347e-08 (2.332 ± 0.000)e-4
4 M^1[0,0] = p_0^(1,0) 0.00005637 0.00005872 2.34978e-06 (5.755 ± 0.117)e-5
5 M^1[0,1] = p_1^(1,0) 0.01378610 0.01395407 1.67971e-04 (1.387 ± 0.008)e-2
6 M^1[1,0] = p_0^(1,1) -0.00000000 0.00016797 1.67971e-04 (8.399 ± 8.399)e-5
7 M^1[1,1] = p_1^(1,1) 0.98553971 0.98554206 2.34978e-06 (9.855 ± 0.000)e-1
1
================================================================================

3. Learn Pauli-X error rate from learned MCM

For now, let's trust IBM's official calibration data, because I believe in randomized benchmarking.

4. Accuracy test

# Calibration snapshots exported from IBM, one per backend.
# Remember to add/update the entry here when switching to a different machine or a newer calibration export.
CALIBRATION_FILES = {
    'ibm_pittsburgh': "./calibration_data/ibm_pittsburgh_calibrations_2025-11-29T21_05_13Z.csv",  # pittsburgh 1st version
    'ibm_kingston': "./calibration_data/ibm_kingston_calibrations_2025-11-26T22_59_05Z.csv",  # kingston 1st version
}

if backend.name not in CALIBRATION_FILES:
    # Fail fast: otherwise `calibration_df` would be undefined below, or silently
    # left over from an earlier run on a different backend.
    raise ValueError(f"No calibration data available for backend {backend.name}")
calibration_df = pd.read_csv(CALIBRATION_FILES[backend.name])

ibm_verify_prob_list = counts_to_probabilities(accuracy_verification_lists[which_qubit_to_analyze])

# Look up this qubit's calibration row once, and fail clearly if it is missing.
physical_qubit = chosen_qubits[which_qubit_to_analyze]
qubit_calibration = calibration_df[calibration_df['Qubit'] == physical_qubit]
if qubit_calibration.empty:
    raise ValueError(f"Qubit {physical_qubit} not found in the calibration data for backend {backend.name}")

pauli_x_error = qubit_calibration['Pauli-X error'].values[0]
measure_2_error = qubit_calibration['MEASURE_2 error'].values[0]

print(f"The Pauli-X error for qubit with index {physical_qubit} on backend {backend.name} is: {pauli_x_error:.5e}")
print(f"The MEASURE_2 error for qubit with index {physical_qubit} on backend {backend.name} is: {measure_2_error:.5e}\n")

predict_and_compare_mcm_readouts(
    epsilon_bound = epsilon_bound,
    bounded_quantities = bounded_quantities,
    bounded_entries = bounded_entries,
    pauli_x_error = pauli_x_error, # read from calibration data
    measure_2_error = measure_2_error, # read from calibration data
    ibm_verify_prob_list = ibm_verify_prob_list,
    max_word_length = max_word_length
)

print(f"Analyzed qubit index {which_qubit_to_analyze}, physical qubit on {backend.name} with index {physical_qubit}.")
1
2
The Pauli-X error for qubit with index 117 on backend ibm_kingston is: 2.34645e-04
The MEASURE_2 error for qubit with index 117 on backend ibm_kingston is: 3.41797e-03

svg

svg

1
Detailed Predictions vs Empirical:
MCM Index Ideal IBM Simple Learned Interval Empirical In Interval? Dev IBM Dev Learned Mid Error Ratio (IBM/Learned)
0 1 0.0 0.003651 [0.0088, 0.0092] 0.008130 -0.004479 0.000864 5.184778
1 2 1.0 0.996116 [0.9770, 0.9773] 0.978823 0.017294 -0.001687 10.248965
2 3 0.0 0.004117 [0.0232, 0.0235] 0.022620 -0.018503 0.000759 24.388292
3 4 1.0 0.995651 [0.9628, 0.9631] 0.963814 0.031837 -0.000849 37.479874
4 5 0.0 0.004582 [0.0372, 0.0375] 0.036748 -0.032166 0.000590 54.498696
5 6 1.0 0.995185 [0.9490, 0.9494] 0.949516 0.045669 -0.000305 149.951474
6 7 0.0 0.005047 [0.0507, 0.0511] 0.050234 -0.045187 0.000653 69.250996
7 8 1.0 0.994721 [0.9357, 0.9360] 0.935632 0.059089 0.000233 253.663807
8 9 0.0 0.005511 [0.0639, 0.0642] 0.063346 -0.057834 0.000690 83.853161
9 10 1.0 0.994257 [0.9227, 0.9231] 0.922078 0.072178 0.000833 86.637467
1
Analyzed qubit index 9, physical qubit on ibm_kingston with index 117.
# Open question: how much more (or less) would the predicted curve (green) and IBM's curve (blue) deviate from the observed data (red) if we fed partially (or entirely) wrongly learned data/parameters into the prediction?

5. Store good jobs

# Job IDs whose data looked good. The comments under each ID are per-qubit quality notes
# (presumably positions in chosen_qubits, as used by which_qubit_to_analyze) observed with margin_tol = 0.
good_jobs = [
    'd4jkbls3tdfc73dnj9u0', # 10-parallel learning, 1e6 shots, ibm_kingston, max_word_length=6
    # indexes with 0 margin_tol = [
    # 1 (meh),
    # 2 (really good),
    # 4 (quite good),
    # 5 (bad),
    # 7 (quite good),
    # 8 (meh good but a bit strange, a bit wide)
    # 9 (quite good, though all experiment points lie outside the prediction, but the error reduction is ~10x)
    # ],
    'd4jkm3l74pkc738653sg', # 10-parallel learning, 1e6 shots, ibm_kingston, max_word_length=10
    # 0 (margin_tol=0.00013)
    # indexes with 0 margin_tol = [
    # 1 (bad, error reduction ~7x),
    # 2 (quite good),
    # 3 (meh, error reduction ~10x),
    # 4 (quite good, error reduction ~30x),
    # 5 (bad),
    # 7 (pretty good, similar to 2),
    # 8 (meh good and also a bit strange, a bit wide, error reduction ~10x),
    # 9 (very good, though all experiment points lie outside the prediction, but the error reduction is ~80x)
    # ],
    # Trailing comma so that appending a new ID below cannot silently concatenate two strings.
    'd4k76a90i6jc73deha7g', # 10-parallel learning, 1e6 shots, ibm_pittsburgh, max_word_length=10
    # indexes with 0 margin_tol = [
    # 0 (really good, though all experiment points lie outside the prediction, but the error reduction is ~100x)
    # 1 (really bad, what's happening? ~3x)
    # ],
]
def update_qubit_error_in_json(
    qubit_index: int,
    bounded_quantities: pd.DataFrame,
    json_output_path: str = "ibm_kingston.json"
):
    """
    Updates the error analysis results for a single qubit in a Plotly JSON file.

    The file holds one scatter trace per error category ("prep 0 meas 1",
    "prep 1 meas 0", "prep 0 excite to 1", "prep 1 decay to 0"). Each trace has
    qubit indices on x, interval centers on y, and interval half-widths as
    symmetric error bars. Points already stored for other qubits are kept, and
    the point for `qubit_index` is inserted or overwritten. Traces with other
    names and any existing layout are NOT preserved: the file is rewritten with
    the four category traces and a fixed layout.

    Args:
        qubit_index: The physical index of the qubit being analyzed. Any
            integer-like value (including NumPy integers) is accepted.
        bounded_quantities: DataFrame containing 'Quantity', 'Min', and 'Max'
            columns; 'Min'/'Max' may be numbers or numeric strings.
        json_output_path: Path to the JSON file to update or create.
    """

    # 1. Initialize Data Structure
    # Structure: { category_name: { qubit_idx: {'val': center, 'err': error} } }
    # The 4 expected categories fix a consistent trace ordering.
    categories = [
        "prep 0 meas 1",
        "prep 1 meas 0",
        "prep 0 excite to 1",
        "prep 1 decay to 0"
    ]

    def empty_map():
        return {cat: {} for cat in categories}

    # NumPy integer scalars (e.g. an element of a NumPy array of qubit indices) are not
    # JSON-serializable; normalize up front so serialization below cannot fail on them.
    qubit_index = int(qubit_index)
    data_map = empty_map()

    # 2. Load Existing Data if file exists
    if os.path.exists(json_output_path):
        try:
            with open(json_output_path, 'r') as f:
                existing_json = json.load(f)

            if isinstance(existing_json, dict) and 'data' in existing_json:
                for trace in existing_json['data']:
                    name = trace.get('name')
                    if name not in data_map:
                        continue
                    # Keys may be missing or explicitly null in exported/hand-edited files.
                    x_vals = trace.get('x') or []
                    y_vals = trace.get('y') or []
                    err_vals = (trace.get('error_y') or {}).get('array') or []

                    # If error_y is missing or length mismatch, fill with 0
                    if len(err_vals) != len(x_vals):
                        err_vals = [0] * len(x_vals)

                    for x, y, err in zip(x_vals, y_vals, err_vals):
                        data_map[name][int(x)] = {'val': y, 'err': err}
        except (ValueError, TypeError, AttributeError) as e:
            # ValueError also covers json.JSONDecodeError and non-integer x values.
            print(f"Warning: Could not parse existing JSON ({e}). Starting fresh.")
            data_map = empty_map()  # drop any partially loaded points

    # 3. Process New Input
    for _, row in bounded_quantities.iterrows():
        quantity = row['Quantity']

        # Only process known categories
        if quantity in data_map:
            min_val = float(row['Min'])
            max_val = float(row['Max'])

            center = (min_val + max_val) / 2.0
            error = max_val - center  # symmetric half-width of [Min, Max]

            # Update or insert the data for this qubit
            data_map[quantity][qubit_index] = {'val': center, 'err': error}

    # 4. Rebuild Plotly Traces (points sorted by qubit index)
    new_traces = []
    for cat in categories:
        qubit_data = data_map[cat]
        sorted_qubits = sorted(qubit_data.keys())

        new_traces.append({
            "x": sorted_qubits,
            "y": [qubit_data[q]['val'] for q in sorted_qubits],
            "visible": "legendonly",
            "error_y": {
                "type": "data",
                "array": [qubit_data[q]['err'] for q in sorted_qubits],
                "visible": True
            },
            "mode": "markers",
            "name": cat,
            "type": "scatter"
        })

    # Define standard layout
    layout = {
        "xaxis": {"title": "Qubit Index", "type": "category"}, # categorical ensures integer ticks
        "yaxis": {"title": "Error Probability", "tickformat": ".1e", "exponentformat": "e", "type": "log"}, # log scale often useful for errors
        "height": 600,
        "legend": {"title": {"text": "Error Type"}},
        "margin": {  # tight 5px margins on all sides
            "t": 5,
            "b": 5,
            "l": 5,
            "r": 5
        },
    }

    final_json = {
        "data": new_traces,
        "layout": layout
    }

    # 5. Save. Serialize before opening the file for writing, so a serialization
    # error cannot leave a truncated (emptied) file behind.
    payload = json.dumps(final_json, indent=2)
    with open(json_output_path, 'w') as f:
        f.write(payload)

    print(f"Successfully updated {json_output_path} for qubit {qubit_index}.")
1
2
3
4
5
# update_qubit_error_in_json(
#     qubit_index=chosen_qubits[which_qubit_to_analyze],
#     bounded_quantities=bounded_quantities,
#     json_output_path="ibm_kingston.json" if "kingston" in backend.name else "ibm_pittsburgh.json"
# )

See all functions defined in this notebook

import inspect, textwrap

def user_functions(namespace=None):
    """Return the plain Python functions defined interactively in a namespace.

    Args:
        namespace: Mapping of names to objects to scan. ``None`` (default)
            scans ``globals()`` of the module that defines this function,
            which in a notebook kernel is ``__main__``.

    Returns:
        List of ``(name, function)`` pairs, sorted case-insensitively by name.
        Only objects for which ``inspect.isfunction`` is true and whose
        ``__module__`` is ``None`` or ``"__main__"`` are kept, so functions
        imported from library modules are filtered out.
    """
    ns = globals() if namespace is None else namespace
    funcs = [
        (name, obj)
        for name, obj in ns.items()
        # keep only functions defined in this notebook/kernel
        if inspect.isfunction(obj)
        and getattr(obj, "__module__", None) in (None, "__main__")
    ]
    return sorted(funcs, key=lambda item: item[0].lower())


def _wrap_docstring(doc, width):
    """Wrap ``doc`` line by line, keeping blank lines and leading indentation.

    Calling ``textwrap.fill`` on the whole text would replace every newline
    with a space, merging a structured docstring (summary, ``Args:``,
    ``Returns:`` ...) into one paragraph. Here each line is wrapped on its
    own, and its continuation lines reuse that line's indentation.
    """
    wrapped = []
    for line in doc.splitlines():
        text = line.lstrip()
        if not text:
            wrapped.append("")  # preserve paragraph breaks
            continue
        indent = line[: len(line) - len(text)]
        wrapped.append(
            textwrap.fill(
                text, width=width, initial_indent=indent, subsequent_indent=indent
            )
        )
    return "\n".join(wrapped)


def show_functions_with_docs(namespace=None, width=88):
    """Print the signature and docstring of every notebook-defined function.

    Functions are discovered with ``user_functions(namespace)``. For each one,
    this prints:

    - a header ``🟢<name><signature>``;
    - the cleaned docstring (``inspect.getdoc``), or ``(no docstring)``;
    - a separator made of ``width`` dashes.

    The docstring keeps its line structure. Only lines longer than ``width``
    are re-flowed.

    Args:
        namespace: Mapping to scan; ``None`` uses ``user_functions``' default.
        width: Target line width for the docstring text and the separator.
    """
    for name, fn in user_functions(namespace):
        sig = str(inspect.signature(fn))
        doc = inspect.getdoc(fn) or "(no docstring)"
        doc_wrapped = _wrap_docstring(doc, width)
        print(f"🟢{name}{sig}\n{doc_wrapped}\n" + "-" * width)

# show_functions_with_docs()

Sandbox / Scratch Pad

# reconstruct_instrument
1
2
3
# gt = random_instrument(fidelity=0.95)

# gt.reveal()
# gauge_transform_instrument_numerically(gt.M0, gt.M1, t=0.155)
# gt = random_instrument(fidelity=0.95)
# gt.reveal()


# GT_inv = summarize_instrument(gt.M0, gt.M1)
# # print(GT_inv)

# # cheating
# sols = reconstruct_instrument_from_invariants_mixed_det(
#     trM0 = gt.M0.trace(),
#     detM0 = np.linalg.det(gt.M0),
#     trM1 = gt.M1.trace(),
#     detM1 = np.linalg.det(gt.M1),
#     S0 = gt.M0.sum(),
#     gauge_p00 = 1.0
# )

# print(f"\nthere are {len(sols)} solutions from empirical data.")
# chosen_sol_index = 0  # Choose which solution to analyze

# reconstructed_MCM, _ = plot_gauge_transformation_effects(
#     Instrument2x2(M0=sols[chosen_sol_index][0], M1=sols[chosen_sol_index][1]), 
#     t_width_factor=1.5, 
#     verbose=False,
#     MCM_reference=[gt],
#     resolution=1e5
#     )

# print("Original Instrument:")
# gt.reveal()

# print("\nReconstructed Instrument from invariants (empirical data):")
# if reconstructed_MCM is not None:
#     reconstructed_MCM.reveal()
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
import json

def plot_gauge_transformation_effects(
    MCM_to_transform, 
    t_width_factor: float = 1.20, 
    verbose: bool = False, 
    MCM_reference: List[Instrument2x2] = None,
    resolution: float = 1e5,
    p00_min: float = 0.5,
    margin_tol: float = 0.0,
    render_plots: bool = True,
    json_output_path: str = None
):
    """
    Analyzes and plots the effect of a gauge transformation on an instrument.

    This function takes an instrument, applies a gauge transformation over a range
    of the gauge parameter 't', and plots how each of the 8 matrix entries evolves.
    It uses analytical methods to determine valid gauge parameter regions.

    Note: t=0.5 is excluded from the analysis as the gauge transformation matrix
    is non-invertible at that point.

    Args:
        MCM_to_transform: An Instrument2x2 object to be transformed.
        t_width_factor: Factor to scale the plotting range around valid regions. Default 1.20.
                        1.00 means plot exactly the valid regions, >1.00 adds padding.
        verbose: If True, print detailed information about valid intervals. Default False.
        MCM_reference: List of reference Instrument2x2 objects to compare RMSE against. Default [].
        resolution: Number of points to sample per unit t-range.
        p00_min: Minimum value for M0[0,0] when identifying focus regions. Default 0.5.
        margin_tol: Margin tolerance for allowed regions. Default 0.0.
        render_plots: If True, generate and show plots. If False, skip plotting. Default True.
        json_output_path: If provided, saves a JSON file containing Plotly data/layout for the first two subplots.

    Returns:
        Tuple containing:
            - Instrument2x2 object constructed from center points of valid entry ranges, or None if no valid ranges.
            - List of valid t-regions for the center_instrument (relative to itself).
    """
    if MCM_reference is None:
        MCM_reference = []

    # Helper for formatting value column
    def fmt_val_err(min_v, max_v):
        c = (min_v + max_v) / 2
        h = (max_v - min_v) / 2

        # Determine exponent from the larger of abs(c) or abs(h) to avoid tiny numbers if c is near zero
        ref = abs(c) if abs(c) > 0 else abs(h)
        if ref == 0:
            return "(0.000 ± 0.000)e+0"

        exponent = int(np.floor(np.log10(ref)))
        scale = 10.0 ** (-exponent)

        c_s = c * scale
        h_s = h * scale

        return f"({c_s:.3f} ± {h_s:.3f})e{exponent:+d}"

    if verbose:
        print("Original Instrument to be transformed:")
        MCM_to_transform.reveal()

    # Use analytical method to find valid t-regions
    valid_t_regions = allowed_t_regions_for_list(
        [MCM_to_transform.M0, MCM_to_transform.M1], 
        tol=1e-24,
        margin_tol=margin_tol
    )

    if not valid_t_regions:
        raise ValueError(f"No valid gauge parameter regions found for this instrument (margin_tol={margin_tol}).")

    if verbose:
        print(f"\nAnalytically determined valid t-regions (total: {len(valid_t_regions)}) with margin {margin_tol}:")
        for i, (t_lo, t_hi) in enumerate(valid_t_regions, 1):
            print(f"  Region {i}: t ∈ [{t_lo:.6f}, {t_hi:.6f}] (width: {t_hi - t_lo:.6f})")

    # Determine plotting range based on valid regions and width factor
    all_t_mins = [r[0] for r in valid_t_regions if not np.isinf(r[0])]
    all_t_maxs = [r[1] for r in valid_t_regions if not np.isinf(r[1])]

    if all_t_mins and all_t_maxs:
        t_plot_min = min(all_t_mins)
        t_plot_max = max(all_t_maxs)
        t_center = (t_plot_min + t_plot_max) / 2
        t_half_span = (t_plot_max - t_plot_min) / 2

        # Apply width factor
        t_plot_min = t_center - t_half_span * t_width_factor
        t_plot_max = t_center + t_half_span * t_width_factor
    else:
        # Fallback if regions are unbounded
        t_plot_min = -0.5
        t_plot_max = 1.5

    # Ensure we don't include t=0.5 in our sampling
    if abs(t_plot_min - 0.5) < 1e-6:
        t_plot_min = 0.5 - 1e-6
    if abs(t_plot_max - 0.5) < 1e-6:
        t_plot_max = 0.5 + 1e-6

    # Generate t values for plotting, excluding t=0.5
    n_points = int(resolution * (t_plot_max - t_plot_min))
    if t_plot_min < 0.5 < t_plot_max:
        t_values_left = np.linspace(t_plot_min, 0.5 - 1e-6, n_points // 2)
        t_values_right = np.linspace(0.5 + 1e-6, t_plot_max, n_points // 2)
        t_values = np.concatenate([t_values_left, t_values_right])
    elif t_plot_max < 0.5:
        t_values = np.linspace(t_plot_min, t_plot_max, n_points)
    else:
        t_values = np.linspace(t_plot_min, t_plot_max, n_points)

    # Store the 8 entries of the transformed instrument for each value of t
    transformed_entries = []

    for t_val in t_values:
        M0_prime, M1_prime = gauge_transform_instrument_numerically(
            MCM_to_transform.M0, MCM_to_transform.M1, t_val
        )
        entries = np.concatenate((M0_prime.flatten(), M1_prime.flatten()))
        transformed_entries.append(entries)

    transformed_entries = np.array(transformed_entries)

    # --- Generate dedicated statistics samples from the first valid region ---
    # This ensures table values are independent of plotting width factor.
    stats_entries = None
    if len(valid_t_regions) > 0:
        t_lo_stat, t_hi_stat = valid_t_regions[0]

        # Handle potential infinite bounds for stats (clip to reasonable range if needed)
        t_start_stat = t_lo_stat if not np.isinf(t_lo_stat) else -5.0
        t_end_stat = t_hi_stat if not np.isinf(t_hi_stat) else 5.0

        width_stat = t_end_stat - t_start_stat
        # Ensure sufficient points for statistics
        n_stats = max(200, int(resolution * width_stat))
        t_stats = np.linspace(t_start_stat, t_end_stat, n_stats)

        stats_entries_list = []
        for t_val in t_stats:
            M0_p, M1_p = gauge_transform_instrument_numerically(
                MCM_to_transform.M0, MCM_to_transform.M1, t_val
            )
            stats_entries_list.append(np.concatenate((M0_p.flatten(), M1_p.flatten())))
        stats_entries = np.array(stats_entries_list)
    # -----------------------------------------------------------------------

    # Use rebase_and_anchor_instrument to find focus regions with ok_p00
    rebased_results = rebase_and_anchor_instrument(
        MCM_to_transform.M0, 
        MCM_to_transform.M1, 
        valid_t_regions, 
        p00_min=p00_min
    )

    # Find first region with ok_p00 for focus plot
    focus_region_info = None
    for res in rebased_results:
        if res['ok_p00']:
            focus_region_info = res
            break

    # Determine number of subplots
    has_focus_plot = focus_region_info is not None or len(MCM_reference) > 0

    if render_plots:
        if has_focus_plot:
            fig = plt.figure(figsize=(14, 28))
            gs = fig.add_gridspec(4, 1, height_ratios=[1.2, 1.2, 1.2, 2.4], hspace=0.15)

            ax1 = fig.add_subplot(gs[0])
            ax2 = fig.add_subplot(gs[1], sharex=ax1)
            axes = [ax1, ax2]

            if len(MCM_reference) > 0:
                ax3 = fig.add_subplot(gs[2], sharex=ax1)
                axes.append(ax3)

            ax4 = fig.add_subplot(gs[3])
            axes.append(ax4)
        else:
            fig, axes = plt.subplots(2, 1, figsize=(14, 14), sharex=True)
            if isinstance(axes, plt.Axes):
                axes = [axes]

        ax1 = axes[0]
        ax2 = axes[1]

        # ===== First subplot: Individual matrix entries =====
        labels = [
            r"$M^0$[0,0] = $p_0^{(0,0)}$", 
            r"$M^0$[0,1] = $p_1^{(0,0)}$", 
            r"$M^0$[1,0] = $p_0^{(0,1)}$",
            r"$M^0$[1,1] = $p_1^{(0,1)}$",
            r"$M^1$[0,0] = $p_0^{(1,0)}$",
            r"$M^1$[0,1] = $p_1^{(1,0)}$",
            r"$M^1$[1,0] = $p_0^{(1,1)}$",
            r"$M^1$[1,1] = $p_1^{(1,1)}$"
        ]

        for i in range(8):
            ax1.plot(t_values, transformed_entries[:, i], linewidth=1.0, label=labels[i])

        ax1.axvline(x=0.5, color='red', linestyle=':', linewidth=1.0, label='Singularity (t=0.5)')
        ax1.axhline(y=0, color='k', linestyle='--', linewidth=1.0, label='Prob Boundary (0,1)')
        ax1.axhline(y=1, color='k', linestyle='--', linewidth=1.0)

        if margin_tol > 0:
            ax1.axhline(y=-margin_tol, color='gray', linestyle=':', linewidth=1.0, label=f'Margin (±{margin_tol})')
            ax1.axhline(y=1+margin_tol, color='gray', linestyle=':', linewidth=1.0)

        # Highlight valid regions using analytical results
        for i, (t_lo, t_hi) in enumerate(valid_t_regions):
            # Clip to plotting range
            t_lo_plot = max(t_lo, t_plot_min) if not np.isinf(t_lo) else t_plot_min
            t_hi_plot = min(t_hi, t_plot_max) if not np.isinf(t_hi) else t_plot_max

            if t_lo_plot < t_hi_plot:
                label = 'Valid Gauge Region (t)' if i == 0 else ""
                ax1.axvspan(t_lo_plot, t_hi_plot, color='green', alpha=0.2, label=label)

        ax1.set_ylabel("Value of Instrument Matrix Entry")
        ax1.set_title(f"Evolution of Instrument Entries (Valid Interval: [{-margin_tol}, {1+margin_tol}])")
        ax1.legend(loc='center left', bbox_to_anchor=(1, 0.5))
        ax1.grid(True, linestyle=':', alpha=0.6)
        ax1.set_ylim(max(np.min(transformed_entries), -0.1 - margin_tol), min(np.max(transformed_entries), 1.1 + margin_tol))

    # ===== Compute valid ranges for matrix entries =====
    entry_labels = [
        "M^0[0,0] = p_0^(0,0)",
        "M^0[0,1] = p_1^(0,0)",
        "M^0[1,0] = p_0^(0,1)",
        "M^0[1,1] = p_1^(0,1)",
        "M^1[0,0] = p_0^(1,0)",
        "M^1[0,1] = p_1^(1,0)",
        "M^1[1,0] = p_0^(1,1)",
        "M^1[1,1] = p_1^(1,1)"
    ]

    entry_ranges_data = []
    center_values = []  # Store center values for constructing return instrument

    for entry_idx, entry_label in enumerate(entry_labels):
        if stats_entries is not None:
            block_values = stats_entries[:, entry_idx]
            min_val = np.min(block_values)
            max_val = np.max(block_values)
            width = max_val - min_val
            center = (min_val + max_val) / 2

            center_values.append(center)

            row_data = {
                'Entry': entry_label,
                'Min': f"{min_val:.8f}",
                'Max': f"{max_val:.8f}",
                'Width': f"{width:.5e}",
                'Value': fmt_val_err(min_val, max_val)
            }

            # Add comparison columns for each reference instrument
            for ref_idx, ref_inst in enumerate(MCM_reference):
                ref_entries = np.concatenate((ref_inst.M0.flatten(), ref_inst.M1.flatten()))
                ref_val = ref_entries[entry_idx]

                # Calculate absolute discrepancy from center
                abs_disc = ref_val - center

                # Check if reference value is within range
                if min_val <= ref_val <= max_val:
                    row_data[f'Ref{ref_idx+1}'] = f"IN ({abs_disc:+.5e})"
                else:
                    direction = "above" if ref_val > max_val else "below"
                    row_data[f'Ref{ref_idx+1}'] = f"OUT {direction} ({abs_disc:+.5e})"

            entry_ranges_data.append(row_data)

    # Construct Instrument2x2 from center values
    center_instrument = None
    if len(center_values) == 8:
        M0_center = np.array([[center_values[0], center_values[1]],
                              [center_values[2], center_values[3]]], dtype=float)
        M1_center = np.array([[center_values[4], center_values[5]],
                              [center_values[6], center_values[7]]], dtype=float)
        center_instrument = Instrument2x2(M0=M0_center, M1=M1_center)

    # ===== Second subplot: Derived quantities =====
    prep0_meas1 = transformed_entries[:, 4] + transformed_entries[:, 6]
    prep1_meas0 = transformed_entries[:, 1] + transformed_entries[:, 3]
    prep0_excite = transformed_entries[:, 2] + transformed_entries[:, 6]
    prep1_decay = transformed_entries[:, 1] + transformed_entries[:, 5]

    derived_quantities = {
        "prep 0 meas 1": prep0_meas1,
        "prep 1 meas 0": prep1_meas0,
        "prep 0 excite to 1": prep0_excite,
        "prep 1 decay to 0": prep1_decay
    }

    # Calculate derived quantities for stats
    stats_derived = {}
    if stats_entries is not None:
        s_prep0_meas1 = stats_entries[:, 4] + stats_entries[:, 6]
        s_prep1_meas0 = stats_entries[:, 1] + stats_entries[:, 3]
        s_prep0_excite = stats_entries[:, 2] + stats_entries[:, 6]
        s_prep1_decay = stats_entries[:, 1] + stats_entries[:, 5]

        stats_derived = {
            "prep 0 meas 1": s_prep0_meas1,
            "prep 1 meas 0": s_prep1_meas0,
            "prep 0 excite to 1": s_prep0_excite,
            "prep 1 decay to 0": s_prep1_decay
        }

    quantity_valid_ranges = {}
    quantity_ranges_data = []
    for quantity_name, quantity_values in derived_quantities.items():
        if quantity_name in stats_derived:
            block_values = stats_derived[quantity_name]

            min_val = np.min(block_values)
            max_val = np.max(block_values)
            width = max_val - min_val
            center = (min_val + max_val) / 2
            quantity_valid_ranges[quantity_name] = (min_val, max_val)

            row_data = {
                'Quantity': quantity_name,
                'Min': f"{min_val:.8f}",
                'Max': f"{max_val:.8f}",
                'Width': f"{width:.5e}",
                'Value': fmt_val_err(min_val, max_val)
            }

            # Add comparison columns for each reference instrument
            for ref_idx, ref_inst in enumerate(MCM_reference):
                # Compute reference quantity value
                if quantity_name == "prep 0 meas 1":
                    ref_val = ref_inst.M1[0, 0] + ref_inst.M1[1, 0]
                elif quantity_name == "prep 1 meas 0":
                    ref_val = ref_inst.M0[0, 1] + ref_inst.M0[1, 1]
                elif quantity_name == "prep 0 excite to 1":
                    ref_val = ref_inst.M0[1, 0] + ref_inst.M1[1, 0]
                elif quantity_name == "prep 1 decay to 0":
                    ref_val = ref_inst.M0[0, 1] + ref_inst.M1[0, 1]
                else:
                    ref_val = 0.0

                # Calculate absolute discrepancy from center
                abs_disc = ref_val - center

                # Check if reference value is within range
                if min_val <= ref_val <= max_val:
                    row_data[f'Ref{ref_idx+1}'] = f"IN ({abs_disc:+.5e})"
                else:
                    direction = "above" if ref_val > max_val else "below"
                    row_data[f'Ref{ref_idx+1}'] = f"OUT {direction} ({abs_disc:+.5e})"

            quantity_ranges_data.append(row_data)

    if render_plots:
        # Plot the derived quantities
        labels_with_ranges = [
            (r"prep 0 meas 1: $p_0^{(1,0)} + p_0^{(1,1)}$", "prep 0 meas 1"),
            (r"prep 1 meas 0: $p_1^{(0,0)} + p_1^{(0,1)}$", "prep 1 meas 0"),
            (r"prep 0 excite to 1: $p_0^{(0,1)} + p_0^{(1,1)}$", "prep 0 excite to 1"),
            (r"prep 1 decay to 0: $p_1^{(0,0)} + p_1^{(1,0)}$", "prep 1 decay to 0")
        ]

        quantity_list = list(derived_quantities.items())
        for idx, ((base_label, quantity_key), (quantity_name, quantity_values)) in enumerate(zip(labels_with_ranges, quantity_list)):
            if quantity_key in quantity_valid_ranges:
                min_val, max_val = quantity_valid_ranges[quantity_key]
                label = f"{base_label}\n∈ [{min_val:.6f}, {max_val:.6f}]"
            else:
                label = base_label
            ax2.plot(t_values, quantity_values, label=label, linewidth=1)

        ax2.axvline(x=0.5, color='red', linestyle=':', linewidth=1.0, label='Singularity (t=0.5)')
        ax2.axhline(y=0, color='k', linestyle='--', linewidth=1.0, label='Prob Boundary (0,1)')
        ax2.axhline(y=1, color='k', linestyle='--', linewidth=1.0)

        if margin_tol > 0:
            ax2.axhline(y=-margin_tol, color='gray', linestyle=':', linewidth=1.0, label=f'Margin (±{margin_tol})')
            ax2.axhline(y=1+margin_tol, color='gray', linestyle=':', linewidth=1.0)

        for i, (t_lo, t_hi) in enumerate(valid_t_regions):
            t_lo_plot = max(t_lo, t_plot_min) if not np.isinf(t_lo) else t_plot_min
            t_hi_plot = min(t_hi, t_plot_max) if not np.isinf(t_hi) else t_plot_max

            if t_lo_plot < t_hi_plot:
                label = 'Valid Gauge Region (t)' if i == 0 else ""
                ax2.axvspan(t_lo_plot, t_hi_plot, color='green', alpha=0.2, label=label)

        ax2.set_xlabel(r"Gauge Parameter $(t)$")
        ax2.set_ylabel("Derived Quantity Value")
        ax2.set_title(f"Derived Quantities (Valid Interval: [{-margin_tol}, {1+margin_tol}])")

        legend = ax2.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=9, 
                            labelspacing=1.2, handlelength=2)
        ax2.grid(True, linestyle=':', alpha=0.6)
        ax2.set_ylim(-0.1 - margin_tol, 1.1 + margin_tol)

    # ===== Display DataFrames for valid intervals =====
    print("="*80)
    print("Valid Intervals for Derived Quantities (Readout & Back-action Errors)")
    print("="*80)
    if quantity_ranges_data:
        df_quantities = pd.DataFrame(quantity_ranges_data)
        display(df_quantities)
    else:
        print("No valid intervals found for derived quantities.")

    print("\n" + "="*80)
    print("Valid Intervals for Matrix Entries")
    print("="*80)
    if entry_ranges_data:
        df_entries = pd.DataFrame(entry_ranges_data)
        display(df_entries)
    else:
        print("No valid intervals found for matrix entries.")

    if len(MCM_reference) > 0:
        print("\nNote: Reference comparison format:")
        print("  'IN (±X.XXXe±YY)' - value is within range, absolute discrepancy from center")
        print("  'OUT above/below (±X.XXXe±YY)' - value is outside range, absolute discrepancy from center")
    print("="*80 + "\n")

    # ===== Third subplot: RMSE to reference instruments =====
    best_match_instruments = []

    if len(MCM_reference) > 0:
        if render_plots:
            ax3 = axes[2]

        for ref_idx, ref_inst in enumerate(MCM_reference):
            ref_entries = np.concatenate((ref_inst.M0.flatten(), ref_inst.M1.flatten()))
            rmse_values = np.sqrt(np.mean((transformed_entries - ref_entries)**2, axis=1))

            min_rmse_idx = np.argmin(rmse_values)
            min_rmse = rmse_values[min_rmse_idx]
            t_min_rmse = t_values[min_rmse_idx]

            if render_plots:
                label = f"Ref {ref_idx+1}: min RMSE={min_rmse:.6e} at t={t_min_rmse:.4f}"
                ax3.plot(t_values, rmse_values, label=label, linewidth=1.5)

                ax3.plot(t_min_rmse, min_rmse, 'o', markersize=8)

            M0_best, M1_best = gauge_transform_instrument_numerically(
                MCM_to_transform.M0, MCM_to_transform.M1, t_min_rmse
            )
            best_match_instruments.append(Instrument2x2(M0=M0_best, M1=M1_best))

        if render_plots:
            ax3.axvline(x=0.5, color='red', linestyle=':', linewidth=1.0, label='Singularity (t=0.5)')

            for i, (t_lo, t_hi) in enumerate(valid_t_regions):
                t_lo_plot = max(t_lo, t_plot_min) if not np.isinf(t_lo) else t_plot_min
                t_hi_plot = min(t_hi, t_plot_max) if not np.isinf(t_hi) else t_plot_max

                if t_lo_plot < t_hi_plot:
                    label = 'Valid Gauge Region (t)' if i == 0 else ""
                    ax3.axvspan(t_lo_plot, t_hi_plot, color='green', alpha=0.2, label=label)

            ax3.set_ylabel("RMSE to Reference")
            ax3.set_title("RMSE Between Gauge-Transformed and Reference Instruments")
            ax3.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=9)
            ax3.grid(True, linestyle=':', alpha=0.6)
            ax3.set_yscale('log')

    # ===== Fourth subplot: Focused view =====
    if has_focus_plot and focus_region_info is not None:
        if render_plots:
            ax4 = axes[-1]

        # Use the adjusted region from rebase_and_anchor_instrument
        anchor_t = focus_region_info['anchor_t']
        s_min, s_max = focus_region_info['adjusted_region']

        # Convert back to absolute t values
        # t3 = t1 + s - 2*t1*s => given t1=anchor_t, s in [s_min, s_max]
        t_focus_exact_min = anchor_t + s_min - 2*anchor_t*s_min
        t_focus_exact_max = anchor_t + s_max - 2*anchor_t*s_max

        if t_focus_exact_min > t_focus_exact_max:
            t_focus_exact_min, t_focus_exact_max = t_focus_exact_max, t_focus_exact_min

        # Apply t_width_factor to the focus region
        t_focus_center = (t_focus_exact_min + t_focus_exact_max) / 2
        t_focus_half_span = (t_focus_exact_max - t_focus_exact_min) / 2

        t_focus_min = t_focus_center - t_focus_half_span * t_width_factor
        t_focus_max = t_focus_center + t_focus_half_span * t_width_factor

        if verbose:
            print(f"\nFocus region: anchor_t={anchor_t:.6f}, "
                  f"local s∈[{s_min:.6f}, {s_max:.6f}]")
            print(f"  Exact valid t∈[{t_focus_exact_min:.6f}, {t_focus_exact_max:.6f}]")
            print(f"  Plotted t∈[{t_focus_min:.6f}, {t_focus_max:.6f}] (with factor {t_width_factor:.2f})")

        if render_plots:
            focus_mask = (t_values >= t_focus_min) & (t_values <= t_focus_max)
            t_focus = t_values[focus_mask]
            entries_focus = transformed_entries[focus_mask]

            for i in range(8):
                ax4.plot(t_focus, entries_focus[:, i], linewidth=0.8, alpha=0.5, label=labels[i])

            for quantity_name, quantity_values in derived_quantities.items():
                ax4.plot(t_focus, quantity_values[focus_mask], linewidth=1.5, label=quantity_name)

            if len(MCM_reference) > 0:
                for ref_idx, ref_inst in enumerate(MCM_reference):
                    ref_entries = np.concatenate((ref_inst.M0.flatten(), ref_inst.M1.flatten()))
                    rmse_focus = np.sqrt(np.mean((entries_focus - ref_entries)**2, axis=1))
                    rmse_normalized = rmse_focus / (rmse_focus.max() + 1e-18)
                    ax4.plot(t_focus, rmse_normalized, linewidth=2, linestyle='--', 
                            label=f"Ref {ref_idx+1} RMSE (normalized)\n original max={rmse_focus.max():.2e}")

            if 0.5 >= t_focus_min and 0.5 <= t_focus_max:
                ax4.axvline(x=0.5, color='red', linestyle=':', linewidth=1.0, label='Singularity (t=0.5)')

            ax4.axhline(y=0, color='k', linestyle='--', linewidth=1.0, alpha=0.5)
            ax4.axhline(y=1, color='k', linestyle='--', linewidth=1.0, alpha=0.5)

            # Plot valid regions - only those overlapping with focus window
            for i, (t_lo, t_hi) in enumerate(valid_t_regions):
                if t_hi >= t_focus_min and t_lo <= t_focus_max:
                    plot_start = max(t_lo, t_focus_min)
                    plot_end = min(t_hi, t_focus_max)
                    label = 'Valid Gauge Region (t)' if i == 0 else ""
                    ax4.axvspan(float(plot_start), float(plot_end), color='green', alpha=0.2, label=label)

            ax4.set_xlabel(r"Gauge Parameter $(t)$ [Focused View]")
            ax4.set_ylabel("Quantity Values")
            ax4.set_title(f"Focused View: t ∈ [{t_focus_min:.4f}, {t_focus_max:.4f}]\n"
                        f"Valid region: [{t_focus_exact_min:.4f}, {t_focus_exact_max:.4f}] "
                        f"(width factor: {t_width_factor:.2f})")
            ax4.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=8, ncol=2)
            ax4.grid(True, linestyle=':', alpha=0.6)
            ax4.set_ylim(-0.1, 1.1)
            ax4.set_xlim(t_focus_min, t_focus_max)
    else:
        if render_plots:
            axes[-1].set_xlabel(r"Gauge Parameter $(t)$")

    if render_plots:
        plt.tight_layout()
        plt.show()

    # ===== JSON Export Logic =====
    if json_output_path:
        # Prepare data for Plotly
        plotly_data = []

        # Helper to sanitize numpy arrays for JSON
        def sanitize_array(arr):
            # Replace infinity/nan with None (null in JSON)
            # Also handle the singularity at t=0.5 explicitly if needed, 
            # though numpy usually handles division by zero as inf/nan.
            # We convert to list and replace non-finite values.
            l = []
            for x in arr:
                if np.isfinite(x):
                    l.append(float(x))
                else:
                    l.append(None)
            return l

        t_list = sanitize_array(t_values)

        # 1. Matrix Entries (Subplot 1 -> xaxis, yaxis)
        entry_names = [
            "M^0[0,0]", "M^0[0,1]", "M^0[1,0]", "M^0[1,1]",
            "M^1[0,0]", "M^1[0,1]", "M^1[1,0]", "M^1[1,1]"
        ]
        for i in range(8):
            trace = {
                "x": t_list,
                "y": sanitize_array(transformed_entries[:, i]),
                "type": "scatter",
                "mode": "lines",
                "name": entry_names[i],
                "visible": "legendonly",
                "xaxis": "x",
                "yaxis": "y",
                "legendgroup": "entries"
            }
            plotly_data.append(trace)

        # 2. Derived Quantities (Subplot 2 -> xaxis2, yaxis2)
        for q_name, q_vals in derived_quantities.items():
            trace = {
                "x": t_list,
                "y": sanitize_array(q_vals),
                "type": "scatter",
                "mode": "lines",
                "name": q_name,
                "visible": "legendonly",
                "xaxis": "x2",
                "yaxis": "y2",
                "legendgroup": "derived"
            }
            plotly_data.append(trace)

        # 3. Valid Regions as Traces (Interactive)
        # We create a filled area for each valid region.
        # To make it a "region", we can define a polygon or just a filled line trace.
        # A simple way is to use a scatter trace that goes (t_lo, -10) -> (t_lo, 10) -> (t_hi, 10) -> (t_hi, -10) -> close
        # But 'fill' in plotly usually fills to zero or next trace.
        # Better approach for "background" regions in traces: use 'fill="toself"' with a closed path.

        # We'll add one trace per region per subplot so they show up on both.
        # Or just one trace per region and map to both axes? Plotly doesn't support one trace on multiple axes easily.
        # We will add them to both subplots.

        # Y-bounds for the shaded region (large enough to cover the view)
        y_min_shade = -0.2
        y_max_shade = 1.2

        for i, (t_lo, t_hi) in enumerate(valid_t_regions):
            # Clip to plot range
            r_start = max(t_lo, t_plot_min) if not np.isinf(t_lo) else t_plot_min
            r_end = min(t_hi, t_plot_max) if not np.isinf(t_hi) else t_plot_max

            if r_start >= r_end:
                continue

            # Create a closed polygon for the region
            # (x0, y0) -> (x0, y1) -> (x1, y1) -> (x1, y0) -> (x0, y0)
            x_poly = [r_start, r_start, r_end, r_end, r_start]
            y_poly = [y_min_shade, y_max_shade, y_max_shade, y_min_shade, y_min_shade]

            common_legend_group = "valid_regions"

            # Show legend ONLY for the very first region of the loop
            show_legend_flag = True if i == 0 else False 

            # Add to Subplot 1
            plotly_data.append({
                "x": x_poly,
                "y": y_poly,
                "type": "scatter",
                "mode": "lines",
                "fill": "toself",
                "visible": "legendonly",
                "fillcolor": "rgba(0, 128, 0, 0.1)",
                "line": {"width": 0},
                "name": "Valid Gauge Region", # Unified name
                "xaxis": "x",
                "yaxis": "y",
                "legendgroup": common_legend_group, # Linked!
                "showlegend": show_legend_flag
            })

            # Add to Subplot 2
            plotly_data.append({
                "x": x_poly,
                "y": y_poly,
                "type": "scatter",
                "mode": "lines",
                "fill": "toself",
                "visible": "legendonly",
                "fillcolor": "rgba(0, 128, 0, 0.1)",
                "line": {"width": 0},
                "name": "Valid Gauge Region",
                "xaxis": "x2",
                "yaxis": "y2",
                "legendgroup": common_legend_group, # Linked!
                "showlegend": False # Never show legend for subplot 2 duplicates
            })

        # 4. Probability Boundaries (0 and 1) as Traces
        # We create a line from t_plot_min to t_plot_max
        x_line = [t_plot_min, t_plot_max]

        for y_val in [0.0, 1.0]:
            unique_group = f"boundary_{y_val}"

            # Subplot 1
            plotly_data.append({
                "x": x_line,
                "y": [y_val, y_val],
                "type": "scatter",
                "mode": "lines",
                "visible": "legendonly",
                "line": {"color": "black", "dash": "dash", "width": 1},
                "name": f"Prob Boundary {y_val}",
                "xaxis": "x",
                "yaxis": "y",
                "legendgroup": unique_group, # Unique group!
                "showlegend": True # Show legend for this boundary
            })
            # Subplot 2 (Duplicate line, linked to same unique group)
            plotly_data.append({
                "x": x_line,
                "y": [y_val, y_val],
                "type": "scatter",
                "mode": "lines",
                "visible": "legendonly",
                "line": {"color": "black", "dash": "dash", "width": 1},
                "name": f"Prob Boundary {y_val}",
                "xaxis": "x2",
                "yaxis": "y2",
                "legendgroup": unique_group, # Linked to the one above!
                "showlegend": False # Don't duplicate in legend
            })

        # Layout
        layout = {
            "grid": {"rows": 2, "columns": 1, "pattern": "independent", "ygap": 0.05},
            "xaxis": {"title": "Gauge Parameter (t)", "anchor": "y"},
            "yaxis": {"title": "Matrix Entries", "anchor": "x", "range": [-0.2, 1.2]},
            "xaxis2": {"title": "Gauge Parameter (t)", "anchor": "y2"},
            "yaxis2": {"title": "Derived Quantities", "anchor": "x2", "range": [-0.2, 1.2]},
            "height": 450,
            "width": 1000,
            "showlegend": True,
            "margin": {
                "t": 0,  # Top margin (reduced from ~100)
                "b": 0,  # Bottom margin
                "l": 20,  # Left margin (for axis labels)
                "r": 20   # Right margin
            }
        }

        # Save to file
        with open(json_output_path, 'w') as f:
            json.dump({"data": plotly_data, "layout": layout}, f, indent=2)

        if verbose:
            print(f"Plotly JSON saved to {json_output_path}")

    center_t_regions = []
    if center_instrument is not None:
        center_t_regions = allowed_t_regions_for_list(
            [center_instrument.M0, center_instrument.M1], 
            tol=1e-24,
            margin_tol=margin_tol
        )
    # filter the center_t_regions such that the regions that has the smallest sum of absolute values for 2 end points is chosen
    if center_t_regions:
        center_t_regions = sorted(center_t_regions, key=lambda x: abs(x[0]) + abs(x[1]))
        center_t_regions = center_t_regions[0]

    return center_instrument, center_t_regions
1
2
3
4
5
6
7
8
# reconstructed_MCM, _ = plot_gauge_transformation_effects(
#     Instrument2x2(M0=sols[chosen_sol_index][0], M1=sols[chosen_sol_index][1]), 
#     t_width_factor=1.5, 
#     verbose=False,
#     MCM_reference=[gt],
#     resolution=1e3,
#     # json_output_path="/Users/trainerblade/Documents/03_mySlides/06_APS_2026_March_Meeting/public/gauge.json"
#     )
# ibm_prob_list
# emp_inv = derived_constraints_from_empirical_probs(ibm_prob_list)
# print(emp_inv)

# sols = reconstruct_instrument_from_invariants_mixed_det(
#     trM0 = emp_inv[0],
#     detM0 = emp_inv[1], 
#     trM1 = emp_inv[2],
#     detM1 = emp_inv[3],
#     S0 = emp_inv[6] * 2, # S0 = sum of all elements of P0 = 2 * probability of observing 0 from v0=(0.5,0.5)
#     gauge_p00 = 1.0
# )

Crop SVG

import xml.etree.ElementTree as ET
import glob
import os

def crop_svg(file_path, output_path, crop_box, *, verbose=True):
    """
    Crop an SVG by rewriting the root element's ``viewBox``.

    The drawing content is left untouched; only the visible window changes.
    The root ``width``/``height`` attributes are set to the crop width/height
    (written without units) so the cropped image keeps the crop window's
    aspect ratio when displayed in slides/browsers.

    NOTE(review): if the source SVG declares its size in other units
    (e.g. ``width="720pt"``), those units are dropped here, which can change
    the rendered size. Unitless sources are unaffected.

    Args:
        file_path (str): Path to the input SVG.
        output_path (str): Path to save the cropped SVG (overwritten if present).
        crop_box (tuple): ``(min_x, min_y, width, height)`` in the SVG's user
            coordinate system. ``width`` and ``height`` must be positive,
            finite numbers.
        verbose (bool): If True (default), print the output path after saving.

    Raises:
        ValueError: If ``crop_box`` does not have exactly four entries, or if
            its width/height is not a positive finite number.
    """
    # Unpacking enforces exactly four entries (raises ValueError otherwise).
    min_x, min_y, width, height = crop_box

    # SVG spec: a negative viewBox width/height is an error and zero disables
    # rendering, so reject them before writing an unusable file. The chained
    # comparison also rejects NaN and infinity. float() is used only for this
    # check; the original values are written verbatim below.
    for label, value in (("width", width), ("height", height)):
        if not 0 < float(value) < float("inf"):
            raise ValueError(
                f"crop_box {label} must be a positive finite number, got {value!r}"
            )

    # Register the SVG/xlink namespaces so ElementTree writes them back with
    # their usual prefixes instead of generated "ns0:"-style prefixes.
    ET.register_namespace('', "http://www.w3.org/2000/svg")
    ET.register_namespace('xlink', "http://www.w3.org/1999/xlink")

    tree = ET.parse(file_path)
    root = tree.getroot()

    # 1. Update the viewBox to "zoom in" on the desired area.
    # Format: "min_x min_y width height"
    root.set('viewBox', f"{min_x} {min_y} {width} {height}")

    # 2. Match the displayed size to the crop window so the aspect ratio is kept.
    root.set('width', str(width))
    root.set('height', str(height))

    tree.write(output_path)
    if verbose:
        print(f"Saved: {output_path}")

# ==========================================
# 🔧 CONFIGURATION (Adjust these values!)
# ==========================================

# 1. Define your crop area: (x, y, width, height) in the SVG's user units.
# You will need to trial-and-error these numbers.
# Tip: Open your original SVG in a text editor or browser inspector
# to see the original 'width', 'height' and 'viewBox' to get a baseline.
CROP_SETTINGS = (0, 0, 1000, 330)

# 2. Input pattern: all SVGs in the current user's Downloads folder.
# Built from the home directory rather than a hard-coded absolute path so the
# script works for any user/OS; results are sorted because glob's order is
# arbitrary and a stable processing order makes batch runs reproducible.
INPUT_PATTERN = os.path.join(os.path.expanduser("~"), "Downloads", "*.svg")
INPUT_FILES = sorted(glob.glob(INPUT_PATTERN))

# 3. Batch crop (uncomment to run): writes "cropped_<name>" copies of every
#    input file into output_dir, creating the folder if needed.
# output_dir = "cropped_svgs"
# os.makedirs(output_dir, exist_ok=True)

# for f_path in INPUT_FILES:
#     file_name = os.path.basename(f_path)
#     out_path = os.path.join(output_dir, f"cropped_{file_name}")

#     crop_svg(f_path, out_path, CROP_SETTINGS)

# print("\nDone! Check the 'cropped_svgs' folder.")