Files
ProjectOctopus/agent-common/SplitProject/ranjing-python-devfusion/KF_V2.py
2025-03-05 14:46:36 +08:00

280 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import datetime
from os import error
import numpy as np
from config import *
def calculate_euclidean_distances(A, B):
    """Find the row of ``B`` closest to ``A`` in Euclidean distance.

    ``A`` must broadcast against ``B`` so that ``A - B`` is 2-D; each row of
    the difference is one candidate. Ties resolve to the first row.

    Returns:
        tuple ``(distance, row_index)`` for the closest row.
    """
    gaps = np.linalg.norm(A - B, axis=1)  # one distance per row of B
    nearest = np.argmin(gaps)
    return gaps[nearest], nearest
def are_lists_equal(listA, listB):
    """Tell whether two lists of pairs hold the same entries, ignoring order.

    Both lists are ordered by the first two items of each entry before being
    compared, so ``[deviceId, objectId]`` lists match regardless of the order
    they were collected in. An empty ``listA`` never matches anything (not
    even another empty list).
    """
    if not len(listA):
        return False

    def pair_key(entry):
        return entry[0], entry[1]

    return sorted(listA, key=pair_key) == sorted(listB, key=pair_key)
def sigmoid(x, a=10, b=0.1, shift=None):
    """Shifted logistic curve plus a constant offset.

    Computes ``1 / (1 + exp(-a * (x - shift))) + b``. The curve equals
    ``0.5 + b`` at ``x == shift``, rises with steepness ``a``, and its range
    is ``(b, 1 + b)``.

    Args:
        x: input value (scalar or numpy array).
        a: steepness of the curve.
        b: constant offset added to the logistic value.
        shift: centre of the curve. When ``None`` (the default) the
            module-level ``shift_value`` is used, looked up at call time
            (presumably supplied by the ``config`` star import — verify).

    Returns:
        numpy float (or array, for array input).
    """
    if shift is None:
        shift = shift_value
    return 1 / (1 + np.exp(-a * (x - shift))) + b
class KalmanFilter:
    """Constant-velocity Kalman filter tracking a single fused target.

    The state ``m`` is ``[x, y, z, vx, vy, vz]``; the measurement model ``H``
    observes only the position. Alongside the filter the object keeps the
    bookkeeping used by ``DataFusion``:
    - a survival score;
    - an association counter;
    - the ``[deviceId, objectId]`` pairs it has been associated with;
    - histories of the state and of the survival score.
    """

    def __init__(self, measurement, com_id, measurement_variance=1, process_variance=1e-1):
        """Start a track from its first measurement.

        Args:
            measurement: numpy array whose first three entries are the initial
                position (``tolist()`` is called on it).
            com_id: ``[deviceId, objectId]`` pair of the spawning measurement.
            measurement_variance: scale of the measurement-noise matrix ``R``.
            process_variance: scale of the process-noise matrix ``Q``.
        """
        current_time = datetime.datetime.now()
        timestamp = int(current_time.timestamp() * 1000000)
        ms = measurement.tolist()
        # dtype=float: an integer measurement would otherwise give an integer
        # state array that cannot hold the float Kalman correction.
        self.m = np.array([ms[0], ms[1], ms[2], 0, 0, 0], dtype=float)
        self.origin = [com_id]  # strongest (latest) associated observation
        self.source = self.origin  # observations associated in the latest frame
        self.survive = np.array(survive_initial)  # initial survival score
        self.duration = 0
        self.counter = 0
        # NOTE(review): a microsecond timestamp plus a random offset < 1000 is
        # not guaranteed unique; tracks born in the same burst may share an id.
        self.id = str(timestamp % 3600000000 + np.random.randint(1000))
        # Constant-velocity transition: each position component gains its velocity.
        self.F = np.array([[1, 0, 0, 1, 0, 0],
                           [0, 1, 0, 0, 1, 0],
                           [0, 0, 1, 0, 0, 1],
                           [0, 0, 0, 1, 0, 0],
                           [0, 0, 0, 0, 1, 0],
                           [0, 0, 0, 0, 0, 1]])
        # Only the position is measured.
        self.H = np.array([[1, 0, 0, 0, 0, 0],
                           [0, 1, 0, 0, 0, 0],
                           [0, 0, 1, 0, 0, 0]])
        self.R = measurement_variance * np.eye(3)
        self.Q = process_variance * np.eye(6)
        for k in range(3, 6):
            self.Q[k, k] *= 1e-3  # much smaller process noise on the velocity states
        self.P = np.eye(6) * 0.1
        self.I = np.eye(6)
        self.expend = 1  # gate multiplier: shrinks on a match, grows while unmatched
        self.v = np.array([0, 0, 0])
        self.born_time = int(current_time.timestamp() * 1000)  # milliseconds
        self.latest_update = self.born_time
        # Copy so the first history row never aliases the live state array.
        self.m_history = self.m.copy()
        self.s_history = []
        self.origin_set = [self.origin]

    @staticmethod
    def _unique_entries(entries):
        """Return ``entries`` without duplicates, keeping first occurrences and element types.

        Done in pure Python rather than with ``np.unique``: ``np.unique`` would
        coerce mixed str/int ``[deviceId, objectId]`` pairs to a single string
        dtype, and the membership test in ``correlation`` would then never match.
        """
        unique = []
        for entry in entries:
            if entry not in unique:
                unique.append(entry)
        return unique

    def predict(self):
        """Advance the state one step, decay survival and deduplicate ``origin_set``."""
        F = self.F
        self.m = np.dot(F, self.m)  # m is 1-D, so no transposes are needed
        self.P = np.dot(np.dot(F, self.P), F.T) + self.Q
        # Apply the decay factor (module-level ``decay``, presumably from config).
        self.survive = self.survive * decay
        self.origin_set = self._unique_entries(self.origin_set)

    def update(self, res, run_timestamp, gate):
        """Correct the track with its closest candidate measurement, if it falls inside the gate.

        Args:
            res: dict produced by ``correlation``, holding the parallel lists
                ``distances``, ``measurements`` and ``key_ids``.
            run_timestamp: value stored in ``latest_update`` on a successful match.
            gate: base association gate; it is scaled by ``self.expend``.
        """
        self.duration += 0.6  # duration grows by 0.6 on every update call
        distances = res['distances']
        if len(distances) == 0:
            best = None
            mmd = 1e8
        else:
            best = np.argmin(distances)
            mmd = distances[best]
        if best is not None and mmd < gate * self.expend:
            measurement = res['measurements'][best]
            H = self.H
            self.expend = max(self.expend * 0.8, 1)
            innovation_cov = np.dot(np.dot(H, self.P), H.T) + self.R
            kalman_gain = np.dot(np.dot(self.P, H.T), np.linalg.pinv(innovation_cov))
            self.m = self.m + np.dot(kalman_gain, measurement - np.dot(H, self.m))
            self.P = np.dot(self.I - np.dot(kalman_gain, H), self.P)
            self.origin = [res['key_ids'][best]]
            self.counter += 1
            self.survive = sigmoid(self.counter)  # map association count to survival
            # Floor the velocity variances so the filter never gets overconfident about speed.
            for k in range(3, 6):
                self.P[k, k] = max(1e-1, self.P[k, k])
            self.v = self.m[3:6].copy()
            self.origin_set.append(self.origin)
            self.latest_update = run_timestamp
        else:
            # No association: widen the gate and loosen the velocity variances for the next search.
            self.expend = min(self.expend * 1.2, 1.5)
            for k in range(3, 6):
                self.P[k, k] = min(self.P[k, k] * 1.1, 1)
            self.counter = max(self.counter - 1, 0)
        self.m_history = np.vstack((self.m_history, self.m))
        self.s_history.append(self.survive)

    def one_correlation(self, data_matrix, id_list):
        """Return ``(object_id, distance, row)`` for the ``data_matrix`` row nearest the predicted position."""
        min_distance, min_index = calculate_euclidean_distances(self.m[0:3], data_matrix)
        return id_list[min_index], min_distance, data_matrix[min_index, :]

    def correlation(self, sensor_data):
        """Collect, per sensor, the nearest measurement to this track.

        Args:
            sensor_data: sequence of dicts with ``deviceId``, ``id_list`` and
                ``data_matrix``; sensors with an empty ``id_list`` are skipped.

        Returns:
            dict of parallel lists ``m_ids``, ``distances``, ``measurements``
            and ``key_ids`` (``[deviceId, objectId]``). Distances to pairs this
            track was associated with before are scaled by 0.2 to favour them.
        """
        res = {'m_ids': [], 'distances': [], 'measurements': [], 'key_ids': []}
        for value in sensor_data:
            if len(value['id_list']) == 0:
                continue
            m_id, min_distance, measurement = self.one_correlation(value['data_matrix'], value['id_list'])
            key = value['deviceId']
            res['m_ids'].append(m_id)
            res['measurements'].append(measurement)
            res['key_ids'].append([key, m_id])
            # origin_set entries are shaped like self.origin, i.e. [[deviceId, objectId]],
            # so the pair must be wrapped the same way for the membership test.
            if [[key, m_id]] in self.origin_set:
                min_distance = min_distance * 0.2
            res['distances'].append(min_distance)
        return res
# Fusion class: owns the set of KalmanFilter tracks and runs one fusion step per run() call
class DataFusion:
    """Multi-sensor, multi-target fusion that keeps one KalmanFilter per target."""

    def __init__(self, gate=25, interval=1, fusion_type=1,
                 measuremrnt_variance=1, process_variance=1e-1):
        """Initialise the fusion engine.

        Args:
            gate: association distance gate for matching measurements to tracks.
            interval: time between runs; the reported speed is divided by it.
            fusion_type: 2 halves the survival score of tracks whose associated
                sensors' priorities sum to less than 2; 3 lets only sensors
                with priority > 50 spawn new tracks; other values apply neither rule.
            measuremrnt_variance: measurement-noise scale for new tracks
                (the misspelled name is kept for caller compatibility).
            process_variance: process-noise scale for new tracks.
        """
        self.interval = interval
        self.gate = gate
        self.targets = []
        self.fusion_type = fusion_type
        self.existence_thres = 0.01  # tracks below this survival score are purged
        self.show_thres = show_thres  # tracks above this score are reported
        self.process_variance = process_variance
        self.measuremrnt_variance = measuremrnt_variance

    def set_parameter(self, fusion_parms):
        """Placeholder hook: currently only prints the parameters it receives."""
        print("GO!!!!!!!!!")
        print(fusion_parms)

    def obtain_priority(self, sensor_data):
        """Map each ``deviceId`` to its priority; a missing or falsy priority counts as 1."""
        self.priority_dict = {
            data['deviceId']: data.get('priority') or 1
            for data in sensor_data
        }

    def out_transformer(self, target):
        """Convert a KalmanFilter track into the output record dict."""
        out_former = {
            'objectId': target.id,
            'survive': target.survive.tolist(),
            'state': target.m.tolist(),
            'speed': np.linalg.norm(target.v).tolist() / self.interval,
            'source': target.source,
            'sigma': np.diag(target.P).tolist(),
            'X': target.m[0].tolist(),
            'Y': target.m[1].tolist(),
            'Z': target.m[2].tolist(),
            'Vx': target.v[0].tolist(),
            'Vy': target.v[1].tolist(),
            'Vz': target.v[2].tolist(),
            'born_time': str(target.born_time)
        }
        return out_former

    def run(self, sensor_data):
        """Run one fusion step over a batch of sensor reports.

        Args:
            sensor_data: sequence of per-sensor dicts with ``deviceId``,
                ``id_list``, ``data_matrix`` and optionally ``priority``. It is
                iterated several times, so it must not be a one-shot iterator.

        Returns:
            list of ``out_transformer`` records for tracks whose survival score
            exceeds ``show_thres``.
        """
        current_time = datetime.datetime.now()
        run_timestamp = int(current_time.timestamp() * 1000)
        fusion_data = []
        selected_list = []  # [deviceId, objectId] pairs claimed by existing tracks
        self.obtain_priority(sensor_data)
        for target in self.targets:
            print(f"Fusion target id:{target.id} with survive: {target.survive} at :{target.m}\n")
            if target.survive < self.existence_thres:
                continue  # dead track; purged in remove_duplicates()
            target.predict()
            res = target.correlation(sensor_data)
            target.update(res, run_timestamp, self.gate)
            now_id = []
            t_sum = 0  # summed priority of the sensors backing this track
            for distance, key_id in zip(res['distances'], res['key_ids']):
                if distance < self.gate:
                    now_id.append(key_id)
                    selected_list.append(key_id)
                    t_sum += self.priority_dict[key_id[0]]
            target.source = now_id
            if self.fusion_type == 2 and t_sum < 2:
                target.survive = target.survive * 0.5
            out_former = self.out_transformer(target)
            if target.survive > self.show_thres:  # report only sufficiently alive tracks
                fusion_data.append(out_former)
        # Measurements not claimed by any track may spawn new ones.
        self.selected_list = selected_list
        for data in sensor_data:
            self.new_born(data)
        self.remove_duplicates()
        self.fusion_process_log(fusion_data)
        return fusion_data

    def new_born(self, value,):
        """Create tracks for measurements of one sensor that no track has claimed.

        Every processed ``[deviceId, objectId]`` pair is added to
        ``self.selected_list``. With ``fusion_type == 3`` that happens even if
        the sensor's priority (missing/falsy -> 1, as in ``obtain_priority``)
        is not above 50 and no track is created.
        """
        priority = value.get('priority') or 1
        for j, obj_id in enumerate(value['id_list']):
            key = value['deviceId']
            if [key, obj_id] in self.selected_list:
                continue
            if self.fusion_type != 3 or priority > 50:
                self.targets.append(KalmanFilter(value['data_matrix'][j, :], [key, obj_id],
                                                 self.measuremrnt_variance, self.process_variance))
            self.selected_list.append([key, obj_id])

    def remove_duplicates(self):
        """Purge dead tracks and collapse reported tracks that share the same sources.

        A track is dead when its survival score is below ``existence_thres``.
        Only tracks at or above ``show_thres`` start a comparison. Two tracks
        with equal source lists (``are_lists_equal``) are duplicates; the one
        with the shorter ``duration`` is removed (on a tie, the later one).
        """
        doomed = set()
        for i, target in enumerate(self.targets):
            if target.survive < self.existence_thres:
                doomed.add(i)
                continue
            if target.survive < self.show_thres:
                continue
            for j in range(i + 1, len(self.targets)):
                other = self.targets[j]
                # A track that is about to be purged must not evict a live duplicate.
                if other.survive < self.existence_thres:
                    continue
                if are_lists_equal(target.source, other.source):
                    doomed.add(i if target.duration < other.duration else j)
        # Remove by position, not by id: ids are timestamp-based and can collide.
        self.targets[:] = [t for k, t in enumerate(self.targets) if k not in doomed]

    def fusion_process_log(self, fusion_data):
        """Append a timestamped dump of ``fusion_data`` to ``process_log.txt``."""
        formatted_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        # Explicit UTF-8 so non-ASCII content does not depend on the platform locale.
        with open('process_log.txt', 'a', encoding='utf-8') as log_file:
            log_file.write('=====================\n')
            log_file.write(f"time: {formatted_time}\n")
            log_file.write(f"data:\n {fusion_data}\n")