Minimum-Varianz-Fusion – Musterlösung

  1import numpy as np
  2import cv2
  3from misc import (
  4    IMAGE_SHAPE,
  5    SAMPLES_PER_CLUSTER,
  6    draw_cluster,
  7    draw_mahalanobis,
  8    draw_text,
  9    draw_axes,
 10)
 11
 12
 13def minimum_variance_fusion(mu1, cov1, mu2, cov2):
 14    """
 15    **TODO**: Calculate the minimum variance fusion result for
 16    two normal distributed measurements mu1 and mu2 and their
 17    respective covariances cov1 and cov2. Return both the fused
 18    measurement mu as well as the resulting covariance.
 19
 20    :param mu1: First measurement vector
 21    :param mu2: Second measurement vector
 22    :param cov1: Covariance of first measurement
 23    :param cov2: Covariance of second measurement
 24    :return: Tuple (mu, cov) containing resulting measurement and covariance of the result.
 25    """
 26    inv1 = np.linalg.inv(cov1)
 27    inv2 = np.linalg.inv(cov2)
 28    cov = np.linalg.inv(inv1 + inv2)
 29    mu = cov @ (inv1 @ mu1 + inv2 @ mu2)
 30
 31    return mu, cov
 32
 33
 34if __name__ == "__main__":
 35    # Create a correlated multivariate normal distribution
 36    mu1 = np.array([0.0, 0.0])
 37    cov1 = np.array([[1.0, 0.0], [0.0, 1.0]])
 38
 39    mu2 = np.array([2.0, 2.0])
 40    cov2 = np.array([[1.0, 0.0], [0.0, 1.0]])
 41
 42    control = 1
 43
 44    while True:
 45        image = np.ones(IMAGE_SHAPE)
 46
 47        mu, cov = minimum_variance_fusion(mu1, cov1, mu2, cov2)
 48
 49        cluster1 = np.random.multivariate_normal(mu1, cov1, size=SAMPLES_PER_CLUSTER)
 50        cluster2 = np.random.multivariate_normal(mu2, cov2, size=SAMPLES_PER_CLUSTER)
 51        cluster3 = np.random.multivariate_normal(mu, cov, size=SAMPLES_PER_CLUSTER)
 52
 53        draw_cluster(image, cluster1, col=(0.7, 0.8, 1.0))
 54        draw_cluster(image, cluster2, col=(1.00, 0.8, 0.7))
 55        draw_cluster(image, cluster3, col=(0.7, 1.0, 0.7))
 56
 57        draw_mahalanobis(image, mu1, cov1)
 58        draw_mahalanobis(image, mu2, cov2, col=(0.92, 0.14, 0.0))
 59        draw_mahalanobis(image, mu, cov, col=(0.14, 0.92, 0.14))
 60
 61        col = (0.6, 0.6, 0.6)
 62        if control == 1:
 63            col = (0.0, 0.14, 0.92)
 64        draw_text(image, mu1, cov1, col=col)
 65
 66        col = (0.6, 0.6, 0.6)
 67        if control == 2:
 68            col = (0.92, 0.14, 0.0)
 69        draw_text(image, mu2, cov2, yOffset=100, col=col)
 70
 71        draw_text(image, mu, cov, yOffset=200, col=(0.15, 0.92, 0.14))
 72
 73        draw_axes(image)
 74        cv2.imshow("Clusters", image)
 75        key = cv2.waitKey(0)
 76
 77        if key == ord("1"):
 78            control = 1
 79        if key == ord("2"):
 80            control = 2
 81
 82        if key == ord("w"):
 83            if control == 1:
 84                mu1[1] += 0.1
 85            else:
 86                mu2[1] += 0.1
 87        if key == ord("s"):
 88            if control == 1:
 89                mu1[1] -= 0.1
 90            else:
 91                mu2[1] -= 0.1
 92        if key == ord("a"):
 93            if control == 1:
 94                mu1[0] -= 0.1
 95            else:
 96                mu2[0] -= 0.1
 97        if key == ord("d"):
 98            if control == 1:
 99                mu1[0] += 0.1
100            else:
101                mu2[0] += 0.1
102
103        if key == ord("W"):
104            if control == 1:
105                cov1[1][1] += 0.1
106            else:
107                cov2[1][1] += 0.1
108        if key == ord("S"):
109            if control == 1:
110                cov1[1][1] -= 0.1
111            else:
112                cov2[1][1] -= 0.1
113
114        if key == ord("A"):
115            if control == 1:
116                cov1[0][0] -= 0.1
117            else:
118                cov2[0][0] -= 0.1
119        if key == ord("D"):
120            if control == 1:
121                cov1[0][0] += 0.1
122            else:
123                cov2[0][0] += 0.1
124
125        if key == ord("q"):
126            if control == 1:
127                cov1[0][1] += 0.05
128                cov1[1][0] = cov1[0][1]
129            else:
130                cov2[0][1] += 0.05
131                cov2[1][0] = cov2[0][1]
132        if key == ord("e"):
133            if control == 1:
134                cov1[0][1] -= 0.05
135                cov1[1][0] = cov1[0][1]
136            else:
137                cov2[0][1] -= 0.05
138                cov2[1][0] = cov2[0][1]
139
140        if key == 27:
141            break