import datetime

import numpy as np

# config is expected to supply shift_value, survive_initial, decay, and show_thres
from config import *


def calculate_euclidean_distances(A, B):
    # Compute the Euclidean distance between point A and every row of B
    distances = np.linalg.norm(A - B, axis=1)
    # Locate the minimum distance and its index
    min_distance_index = np.argmin(distances)
    min_distance = distances[min_distance_index]
    return min_distance, min_distance_index


def are_lists_equal(listA, listB):
    # Sort the sub-lists of both lists before comparing
    if len(listA) == 0:
        return False
    sorted_listA = sorted(listA, key=lambda x: (x[0], x[1]))
    sorted_listB = sorted(listB, key=lambda x: (x[0], x[1]))
    # Compare the sorted lists for equality
    return sorted_listA == sorted_listB


def sigmoid(x, a=10, b=0.1):
    # Shifted sigmoid: the value is 0.5 + b at x = shift_value (taken from
    # config); a controls the steepness and b is a small offset
    return 1 / (1 + np.exp(-a * (x - shift_value))) + b
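

# With shift_value == 1, sigmoid(1) evaluates to 0.5 + b; repeated hits in
# KalmanFilter.update() push the survival score toward 1 + b, while
# predict() decays it by the `decay` factor each step.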


class KalmanFilter:
    def __init__(self, measurement, com_id, measurement_variance=1, process_variance=1e-1):
        current_time = datetime.datetime.now()
        timestamp = int(current_time.timestamp() * 1000000)
        ms = measurement.tolist()
        self.m = np.array([ms[0], ms[1], ms[2], 0, 0, 0])  # state vector, 6-dimensional
        self.origin = [com_id]     # origin holds the strongest response
        self.source = self.origin  # source holds all associated observations
        self.survive = np.array(survive_initial)  # initialize the survival score
        self.duration = 0
        self.counter = 0
        self.id = str(timestamp % 3600000000 + np.random.randint(1000))
        # constant-velocity state transition matrix
        self.F = np.array([[1, 0, 0, 1, 0, 0],
                           [0, 1, 0, 0, 1, 0],
                           [0, 0, 1, 0, 0, 1],
                           [0, 0, 0, 1, 0, 0],
                           [0, 0, 0, 0, 1, 0],
                           [0, 0, 0, 0, 0, 1]])
        # observation matrix: only the position components are measured
        self.H = np.array([[1, 0, 0, 0, 0, 0],
                           [0, 1, 0, 0, 0, 0],
                           [0, 0, 1, 0, 0, 0]])
        self.R = measurement_variance * np.eye(3)
        self.Q = process_variance * np.eye(6)
        # much smaller process noise on the velocity components
        self.Q[3, 3] *= 1e-3
        self.Q[4, 4] *= 1e-3
        self.Q[5, 5] *= 1e-3
        self.P = np.eye(6) * 0.1
        self.I = np.eye(6)
        self.expand = 1  # association-gate expansion factor
        self.v = np.array([0, 0, 0])
        self.born_time = int(current_time.timestamp() * 1000)
        self.latest_update = self.born_time

        self.m_history = self.m
        self.s_history = []
        # flat list of [deviceId, measurement_id] pairs ever associated, so the
        # membership test in correlation() can match them directly
        self.origin_set = [com_id]

    def predict(self):
        F = self.F
        self.m = np.dot(F, self.m.T)  # simple one-step prediction: x = F x
        self.m = self.m.T
        self.P = np.dot(np.dot(F, self.P), F.T) + self.Q  # P = F P F^T + Q
        self.survive = self.survive * decay  # apply the survival decay factor
        # deduplicate the association set; a plain-Python loop avoids
        # np.unique coercing mixed str/int ids to strings
        unique_pairs = []
        for pair in self.origin_set:
            if pair not in unique_pairs:
                unique_pairs.append(pair)
        self.origin_set = unique_pairs

    def update(self, res, run_timestamp, gate):
        self.duration += 0.6  # each update adds 0.6 to the track's duration
        if len(res['distances']) == 0:
            mmd = 1e8
        else:
            min_distance_index = np.argmin(res['distances'])
            mmd = res['distances'][min_distance_index]
            measurement = res['measurements'][min_distance_index]

        # run the measurement update
        if mmd < gate * self.expand:
            H = self.H
            I = self.I
            self.expand = max(self.expand * 0.8, 1)
            # K = P H^T (H P H^T + R)^{-1}
            kalman_gain = np.dot(np.dot(self.P, H.T), np.linalg.pinv(np.dot(np.dot(H, self.P), H.T) + self.R))
            self.m += np.dot(kalman_gain, (measurement.T - np.dot(H, self.m.T)))
            self.m = self.m.T
            self.P = np.dot((I - np.dot(kalman_gain, H)), self.P)
            self.origin = [res['key_ids'][min_distance_index]]
            self.counter += 1
            self.survive = sigmoid(self.counter)  # remap the hit counter to a survival score
            # keep the filter from becoming over-confident about velocity
            self.P[3, 3] = max(1e-1, self.P[3, 3])
            self.P[4, 4] = max(1e-1, self.P[4, 4])
            self.P[5, 5] = max(1e-1, self.P[5, 5])
            # slice out the velocity components
            self.v = self.m[3:6]
            self.origin_set.extend(self.origin)  # record the [deviceId, measurement_id] pair
            self.latest_update = run_timestamp  # record the update time
        else:
            self.expand = min(self.expand * 1.2, 1.5)  # no association: widen the gate and keep searching
            self.P[3, 3] = min(self.P[3, 3] * 1.1, 1)
            self.P[4, 4] = min(self.P[4, 4] * 1.1, 1)
            self.P[5, 5] = min(self.P[5, 5] * 1.1, 1)
            self.counter -= 1
            self.counter = max(self.counter, 0)

        self.m_history = np.vstack((self.m_history, self.m))
        self.s_history.append(self.survive)

    def one_correlation(self, data_matrix, id_list):
        # measure the gap between the current state and the rows of data_matrix
        min_distance, min_index = calculate_euclidean_distances(self.m[0:3], data_matrix)
        m_id = id_list[min_index]
        measurement = data_matrix[min_index, :]
        return m_id, min_distance, measurement

    def correlation(self, sensor_data):
        # iterate over the sensors and collect candidate associations
        res = {'m_ids': [], 'distances': [], 'measurements': [], 'key_ids': []}
        for value in sensor_data:
            if len(value['id_list']) > 0:
                m_id, min_distance, measurement = self.one_correlation(value['data_matrix'], value['id_list'])
                key = value['deviceId']
                res['m_ids'].append(m_id)
                res['measurements'].append(measurement)
                res['key_ids'].append([key, m_id])
                # give previously associated measurements a higher confidence
                if [key, m_id] in self.origin_set:
                    min_distance = min_distance * 0.2
                res['distances'].append(min_distance)
        return res


# Data fusion class
class DataFusion:
    def __init__(self, gate=25, interval=1, fusion_type=1,
                 measurement_variance=1, process_variance=1e-1):
        """
        Initialize the DataFusion class.
        """
        # self.task_id = task_id
        self.interval = interval
        self.gate = gate
        self.targets = []
        self.fusion_type = fusion_type
        self.existence_thres = 0.01
        self.show_thres = show_thres
        self.process_variance = process_variance
        self.measurement_variance = measurement_variance

    def set_parameter(self, fusion_parms):
        print("GO!!!!!!!!!")
        print(fusion_parms)

    def obtain_priority(self, sensor_data):
        self.priority_dict = dict()
        for data in sensor_data:
            if data.get('priority'):
                self.priority_dict[data['deviceId']] = data['priority']
            else:
                self.priority_dict[data['deviceId']] = 1

    def out_transformer(self, target):
        out_former = {
            'objectId': target.id,
            'survive': target.survive.tolist(),
            'state': target.m.tolist(),
            'speed': np.linalg.norm(target.v).tolist() / self.interval,
            'source': target.source,
            'sigma': np.diag(target.P).tolist(),
            'X': target.m[0].tolist(),
            'Y': target.m[1].tolist(),
            'Z': target.m[2].tolist(),
            'Vx': target.v[0].tolist(),
            'Vy': target.v[1].tolist(),
            'Vz': target.v[2].tolist(),
            'born_time': str(target.born_time)
        }
        return out_former

    def run(self, sensor_data):
        current_time = datetime.datetime.now()
        run_timestamp = int(current_time.timestamp() * 1000)
        fusion_data = []
        selected_list = []
        self.obtain_priority(sensor_data)

        # iterate over all known targets
        for target in self.targets:
            print(f"Fusion target id:{target.id} with survive: {target.survive} at :{target.m}\n")
            if target.survive < self.existence_thres:
                continue
            target.predict()
            res = target.correlation(sensor_data)
            target.update(res, run_timestamp, self.gate)
            # ==================================================
            now_id = []
            t_sum = 0
            for r, distance in enumerate(res['distances']):
                if distance < self.gate:
                    now_id.append(res['key_ids'][r])
                    selected_list.append(res['key_ids'][r])
                    D_Id = res['key_ids'][r][0]
                    t_sum += self.priority_dict[D_Id]
            target.source = now_id
            # ==================================================
            if self.fusion_type == 2 and t_sum < 2:
                target.survive = target.survive * 0.5

            out_former = self.out_transformer(target)
            if target.survive > self.show_thres:  # only targets above the display threshold are emitted
                fusion_data.append(out_former)

        # spawn new tracks from measurements that were not associated
        self.selected_list = selected_list
        for data in sensor_data:
            self.new_born(data)

        self.remove_duplicates()
        # ==================================================
        self.fusion_process_log(fusion_data)

        return fusion_data

    def new_born(self, value):
        for j, id in enumerate(value['id_list']):
            key = value['deviceId']
            if [key, id] not in self.selected_list:
                if self.fusion_type == 3:
                    # fusion type 3 only spawns tracks from high-priority sensors
                    if value['priority'] > 50:
                        self.targets.append(KalmanFilter(value['data_matrix'][j, :], [key, id], self.measurement_variance, self.process_variance))
                else:
                    self.targets.append(KalmanFilter(value['data_matrix'][j, :], [key, id], self.measurement_variance, self.process_variance))
                self.selected_list.append([key, id])  # add the newly spawned target to the set

    def remove_duplicates(self):
        # collect the ids of targets to delete
        to_delete = []

        # iterate over all target indices
        for i in range(len(self.targets)):
            if self.targets[i].survive < self.existence_thres:
                to_delete.append(self.targets[i].id)
                continue
            if self.targets[i].survive < self.show_thres:
                continue
            for j in range(i + 1, len(self.targets)):
                # compare whether the two source lists match
                if are_lists_equal(self.targets[i].source, self.targets[j].source):
                    # if they match, mark the shorter-lived track for deletion
                    if self.targets[i].duration < self.targets[j].duration:
                        to_delete.append(self.targets[i].id)
                    else:
                        to_delete.append(self.targets[j].id)

        # remove by id to keep target management efficient
        for item_id in sorted(to_delete, reverse=True):
            for target in self.targets:
                if target.id == item_id:
                    self.targets.remove(target)
                    break

    def fusion_process_log(self, fusion_data):
        current_time = datetime.datetime.now()
        # format the time as year-month-day hour:minute:second
        formatted_time = current_time.strftime('%Y-%m-%d %H:%M:%S')
        with open('process_log.txt', 'a') as log_file:  # open the log in append mode
            log_file.write('=====================\n')    # separator
            log_file.write(f"time: {formatted_time}\n")  # timestamp
            log_file.write(f"data:\n {fusion_data}\n")   # message payload
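

# A minimal usage sketch, not part of the original module: it assumes that
# config defines shift_value, survive_initial, decay, and show_thres, and the
# sensor_data schema ('deviceId', 'id_list', 'data_matrix', 'priority') is
# inferred from the methods above. The device ids and measurements are made up
# for illustration.
if __name__ == '__main__':
    fusion = DataFusion(gate=25, interval=1, fusion_type=1)
    sensor_data = [
        {
            'deviceId': 'radar_01',  # hypothetical device id
            'id_list': ['a1', 'a2'],
            'data_matrix': np.array([[10.0, 5.0, 1.0],
                                     [40.0, -3.0, 2.0]]),
            'priority': 1,
        },
        {
            'deviceId': 'camera_01',  # hypothetical device id
            'id_list': ['b1'],
            'data_matrix': np.array([[10.5, 5.2, 1.1]]),
            'priority': 1,
        },
    ]
    # run a few fusion cycles on the same synthetic frame; the two nearby
    # measurements should end up associated with a single track
    for _ in range(3):
        fused = fusion.run(sensor_data)
    print(fused)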