自动驾驶云服务 OCTOPUS-数据脱敏作业:示例代码

时间:2024-09-06 18:25:54

示例代码

下面是rosbag脱敏的算子示例:

# mask.py
import json
import logging
import multiprocessing as mp
import os
import shutil
import time
from pathlib import Path
from typing import cast

import av
import numpy as np
import open3d
from rosbags.highlevel import AnyReader
from rosbags.interfaces import ConnectionExtRosbag1, ConnectionExtRosbag2
from rosbags.rosbag1 import Writer as Writer1
from rosbags.rosbag2 import Writer as Writer2
from rosbags.serde import cdr_to_ros1, serialize_cdr
from rosbags.typesys import get_types_from_msg, register_types
from rosbags.typesys.types import builtin_interfaces__msg__Time as Time
from rosbags.typesys.types import \
    sensor_msgs__msg__CompressedImage as CompressedImage
from rosbags.typesys.types import std_msgs__msg__Header as Header

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
# Module logger. The original line carried a stray leading space, which is an
# IndentationError at module level in Python.
LOG = logging.getLogger(__file__)

# Environment variables injected by the Octopus data service when it starts
# the container image.
input_path = os.getenv('input_path', 'data/hangyan-move.bag.bak')
raw_dir = os.getenv('raw_dir', 'empty_dir/raw')  # directory for extracted raw files
desens_dir = os.getenv('desensitized_dir', 'empty_dir/desens')  # directory for desensitized files
output_dir = os.getenv('output_dir', 'empty_dir/output')
# Number of lidar extraction worker processes. os.getenv returns a str when the
# variable is set, but mp.Pool()/range() need an int, so convert explicitly.
lidar_process_num = int(os.getenv('lidar_process_num', '5'))


def _split_topics(raw):
    '''Split a comma-separated topic list, trimming whitespace and dropping empty entries.'''
    return [topic.strip() for topic in raw.split(',') if topic.strip()]


# User-defined environment variables.
rosbag_version = os.getenv('rosbag_version', '1')  # rosbag version: '1' or '2'
image_topics = _split_topics(os.getenv('image_topics', '/camera_encoded_1'))  # image topics
gnss_topic = os.getenv('gnss_topic', '/inspvax')  # gnss topic; only a single topic is supported
# NOTE: read from 'lidar_topic' (singular), unlike 'image_topics'.
lidar_topics = _split_topics(os.getenv('lidar_topic', '/pandar'))  # point-cloud topics

# Register the project's custom ROS message definitions (read from the local
# msgs/ directory) with rosbags' type system, so the readers/writers in this
# script can handle messages of these types.
# Each entry is (definition file, fully-qualified message type name).
_CUSTOM_MSG_DEFS = (
    ('msgs/Video_encoded_data.msg', 'kyber_msgs/msg/Video_encoded_data'),
    ('msgs/NovatelMessageHeader.msg',
     'novatel_gps_msgs/msg/NovatelMessageHeader'),
    ('msgs/NovatelExtendedSolutionStatus.msg',
     'novatel_gps_msgs/msg/NovatelExtendedSolutionStatus'),
    ('msgs/NovatelReceiverStatus.msg',
     'novatel_gps_msgs/msg/NovatelReceiverStatus'),
    ('msgs/Inspvax.msg', 'novatel_gps_msgs/msg/Inspvax'),
)
# Read every definition file up front, before any parsing takes place.
_custom_msg_sources = [
    (Path(msg_file).read_text(), msg_type)
    for msg_file, msg_type in _CUSTOM_MSG_DEFS
]
add_types = {}
for _msg_text, _msg_type in _custom_msg_sources:
    add_types.update(get_types_from_msg(_msg_text, _msg_type))
register_types(add_types)

# Location of the JSON file written by extract_gnss; only its parent directory
# is created here (the file itself is written later).
gnss_file_path = Path(raw_dir).joinpath('gnss', f'{gnss_topic}.json'.strip('/'))
gnss_file_path.parent.mkdir(parents=True, exist_ok=True)


def extract_image(input_rosbag):
    '''Extract HEVC-encoded image messages from the source rosbag as JPEG files.

    Every message on one of ``image_topics`` is expected to carry an H.265
    packet in its ``raw_data`` field. Decoded frames are converted to RGB and
    saved to ``<raw_dir>/tmp_image/<topic>/<timestamp>.jpg``, then moved to
    ``<raw_dir>/image/<topic>/<timestamp>.jpg`` (presumably so consumers of
    ``image/`` never observe a half-written file -- confirm).

    Frames that fail to decode or convert are logged and skipped; extraction
    continues with the next message.

    Args:
        input_rosbag: path of the source rosbag.
    '''
    LOG.info('Start extracting image.')
    codec = av.codec.Codec('hevc', 'r')
    # One decoder context per topic: HEVC decoding is stateful (frames refer to
    # previously decoded frames), so feeding packets from several cameras into
    # a single context would mix their reference state.
    decoders = {}
    with AnyReader([Path(input_rosbag)]) as reader:
        for connection, timestamp, data in reader.messages():
            topic = connection.topic
            if topic not in image_topics:
                continue
            deserialized_data = reader.deserialize(data, connection.msgtype)
            decoder = decoders.get(topic)
            if decoder is None:
                decoder = decoders[topic] = codec.create()
            try:
                packet = av.packet.Packet(deserialized_data.raw_data)
                img = None
                # If one packet yields several frames, only the last is kept.
                for frame in decoder.decode(packet):
                    if frame.format.name != 'rgb24':
                        frame = frame.reformat(format='rgb24')
                    img = frame.to_image()
                if img is None:
                    # The decoder produced no frame for this packet (e.g. it is
                    # still buffering); there is nothing to save.
                    LOG.warning('%s frame produced no decoded image, skipped.',
                                timestamp)
                    continue
                file_name = f'{timestamp}.jpg'
                f_path = Path(raw_dir, 'image') / topic.strip('/')
                tmp_path = Path(raw_dir, 'tmp_image') / topic.strip('/')
                tmp_path.mkdir(parents=True, exist_ok=True)
                f_path.mkdir(parents=True, exist_ok=True)
                tmp_file = tmp_path / file_name
                file = f_path / file_name
                img.save(tmp_file)
                # Wide-open permissions as in the original sample; presumably
                # required by other platform components -- TODO confirm.
                os.chmod(tmp_file, 0o777)
                os.chmod(file.parent, 0o777)
                shutil.move(tmp_file, file)
            except Exception as e:
                # Best effort per frame: a bad frame must not abort extraction.
                LOG.warning("%s frame can not trans to jpg, message: %s",
                            timestamp, str(e))
    LOG.info('Finish extracting image.')


def extract_lidar(task_id, task_num, input_rosbag):
    '''Extract point-cloud messages from the source rosbag as PCD files.

    Meant to run as one of ``task_num`` parallel workers over the same bag:
    this worker handles the lidar messages whose ordinal (counted over lidar
    messages only) satisfies ``ordinal % task_num == task_id``. Each file is
    written to ``<raw_dir>/tmp_lidar/<topic>/<timestamp>.pcd`` and then moved
    to ``<raw_dir>/lidar/<topic>/<timestamp>.pcd``.

    Args:
        task_id: index of this worker, in ``range(task_num)``.
        task_num: total number of workers sharing the bag.
        input_rosbag: path of the source rosbag.
    '''
    LOG.info('Start extracting pcd.')
    lidar_ordinal = -1
    with AnyReader([Path(input_rosbag)]) as reader:
        for connection, timestamp, data in reader.messages():
            topic = connection.topic
            if topic not in lidar_topics:
                continue
            # Shard over lidar messages only. An index over *all* messages can
            # leave some workers idle and overload others when topics
            # interleave with a period sharing a factor with task_num. Every
            # worker reads the same bag in the same order, so the ordinals agree.
            lidar_ordinal += 1
            if lidar_ordinal % task_num != task_id:
                continue
            deserialized_data = reader.deserialize(data, connection.msgtype)
            # NOTE(review): assumes `data` is a flat x,y,z,x,y,z,... numeric
            # array -- confirm against the lidar message definition.
            points = deserialized_data.data.reshape(-1, 3)
            pcd = open3d.geometry.PointCloud()
            pcd.points = open3d.utility.Vector3dVector(points)

            file_name = f'{timestamp}.pcd'
            f_path = Path(raw_dir, 'lidar') / topic.strip('/')
            tmp_path = Path(raw_dir, 'tmp_lidar') / topic.strip('/')
            tmp_path.mkdir(parents=True, exist_ok=True)
            f_path.mkdir(parents=True, exist_ok=True)
            tmp_file = tmp_path / file_name
            file = f_path / file_name
            open3d.io.write_point_cloud(str(tmp_file), pcd)
            # Wide-open permissions as in the original sample; presumably
            # required by other platform components -- TODO confirm.
            os.chmod(tmp_file, 0o777)
            os.chmod(file.parent, 0o777)
            shutil.move(tmp_file, file)
    LOG.info('Finish extracting pcd.')


def extract_gnss(input_rosbag):
    '''Extract gnss positions from the source rosbag into a JSON file.

    Collects latitude/longitude/altitude of every message on ``gnss_topic``
    into ``{timestamp: {'latitude': ..., 'longitude': ..., 'altitude': ...}}``
    and writes it to ``gnss_file_path``. Because JSON object keys are strings,
    the integer timestamps appear as decimal strings in the file.

    Args:
        input_rosbag: path of the source rosbag.
    '''
    LOG.info('Start extracting gnss.')
    gnss = {}
    with AnyReader([Path(input_rosbag)]) as reader:
        for connection, timestamp, data in reader.messages():
            if connection.topic != gnss_topic:
                continue
            deserialized_data = reader.deserialize(data, connection.msgtype)
            # The message type must expose latitude/longitude/altitude
            # attributes (the sample was written with NavSatFix in mind).
            gnss[timestamp] = {
                'latitude': deserialized_data.latitude,
                'longitude': deserialized_data.longitude,
                'altitude': deserialized_data.altitude,
            }
    # Open the output only after reading succeeded, and via `with`, so a
    # failure neither leaks the handle nor leaves an empty file behind.
    with open(gnss_file_path, 'w', encoding='utf-8') as gnss_file:
        json.dump(gnss, gnss_file)
    LOG.info('Finish extracting gnss.')


def _get_masked_image(topic, timestamp):
    '''Load one desensitized JPEG frame as raw bytes.

    Looks for ``<desens_dir>/image/<topic>/<timestamp>.jpg`` (leading/trailing
    slashes stripped from the topic).

    Returns:
        A uint8 numpy array with the file's bytes, or None if no such file.
    '''
    candidate = Path(desens_dir).joinpath('image', topic.strip('/'),
                                          f'{timestamp}.jpg')
    return np.fromfile(candidate, dtype='uint8') if candidate.is_file() else None


def _get_conn_map(rosbag_version: str, reader, writer):
    '''Map every source connection id to a connection in the output bag.

    Connections on ``image_topics`` are redirected to a single ``/image``
    connection of type CompressedImage (generate_rosbag rewrites those frames
    as JPEG). All other connections are copied with their original metadata.

    NOTE: with several image topics, all cameras end up interleaved on the one
    ``/image`` topic, as in the original sample.

    Args:
        rosbag_version: '1' or '2' (the raw string from the environment).
        reader: opened AnyReader on the source bag.
        writer: opened rosbag1/rosbag2 Writer matching ``rosbag_version``.

    Returns:
        dict mapping source connection id -> output connection.

    Raises:
        ValueError: if ``rosbag_version`` is neither '1' nor '2'.
    '''
    if rosbag_version not in ('1', '2'):
        raise ValueError(f'Unsupported rosbag_version: {rosbag_version!r}')
    conn_map = {}
    image_conn = None
    for conn in reader.connections:
        if conn.topic in image_topics:
            # Add '/image' only once and share it: registering the identical
            # connection again is rejected by rosbags' writers (WriterError).
            if image_conn is None:
                image_conn = writer.add_connection(
                    '/image',
                    CompressedImage.__msgtype__,
                )
            conn_map[conn.id] = image_conn
        elif rosbag_version == '1':
            ext = cast(ConnectionExtRosbag1, conn.ext)
            conn_map[conn.id] = writer.add_connection(
                conn.topic,
                conn.msgtype,
                conn.msgdef,
                conn.md5sum,
                ext.callerid,
                ext.latching,
            )
        else:
            ext = cast(ConnectionExtRosbag2, conn.ext)
            conn_map[conn.id] = writer.add_connection(
                conn.topic,
                conn.msgtype,
                ext.serialization_format,
                ext.offered_qos_profiles,
            )
    return conn_map


def _serialize_data(rosbag_version, data, msgtype):
    '''Serialize a message object into the wire format of the target bag.

    Args:
        rosbag_version: '1' (ROS1 serialization, converted from CDR) or
            '2' (CDR).
        data: message instance to serialize.
        msgtype: fully-qualified message type name of ``data``.

    Returns:
        The serialized message bytes.

    Raises:
        ValueError: if ``rosbag_version`` is neither '1' nor '2' (previously
            this silently returned None, which was then written to the bag).
    '''
    if rosbag_version == '1':
        return cdr_to_ros1(serialize_cdr(data, msgtype), msgtype)
    if rosbag_version == '2':
        return serialize_cdr(data, msgtype)
    raise ValueError(f'Unsupported rosbag_version: {rosbag_version!r}')


def _masked_image_payload(topic, timestamp):
    '''Serialize the desensitized JPEG of one frame as CompressedImage, or None if absent.'''
    masked_data = _get_masked_image(topic, timestamp)
    if masked_data is None:
        return None
    image_msg = CompressedImage(
        Header(
            stamp=Time(
                sec=int(timestamp // 10**9),
                nanosec=int(timestamp % 10**9),
            ),
            frame_id='0',
        ),
        format='jpg',
        data=masked_data,
    )
    return _serialize_data(rosbag_version, image_msg,
                           CompressedImage.__msgtype__)


def _masked_gnss_payload(reader, connection, timestamp, data, gnss_data):
    '''Re-serialize one gnss message with desensitized coordinates, or None if absent.'''
    masked = gnss_data.get(str(timestamp))  # JSON keys are stringified timestamps
    if masked is None:
        return None
    gnss_msg = reader.deserialize(data, connection.msgtype)
    gnss_msg.latitude = masked.get('latitude')
    gnss_msg.longitude = masked.get('longitude')
    gnss_msg.altitude = masked.get('altitude')
    return _serialize_data(rosbag_version, gnss_msg, connection.msgtype)


def _masked_lidar_payload(reader, connection, timestamp, data):
    '''Re-serialize one point-cloud message with desensitized points, or None if absent.'''
    pcd_file = (Path(desens_dir, 'lidar') / connection.topic.strip('/')
                / f'{timestamp}.pcd')
    if not pcd_file.is_file():
        return None
    lidar_msg = reader.deserialize(data, connection.msgtype)
    points = np.asarray(
        open3d.io.read_point_cloud(str(pcd_file)).points).flatten()
    # open3d yields float64; cast back to the message's original element dtype
    # so the serialized array matches the message definition.
    lidar_msg.data = points.astype(lidar_msg.data.dtype, copy=False)
    return _serialize_data(rosbag_version, lidar_msg, connection.msgtype)


def generate_rosbag(input_rosbag, output_rosbag):
    '''Write the desensitized rosbag.

    Copies every message of ``input_rosbag`` to ``output_rosbag``, replacing
    image, gnss and lidar payloads with their desensitized counterparts from
    ``desens_dir``. Messages of those topics that have no desensitized
    counterpart are dropped (never copied raw). Touches
    ``<output_dir>/_SUCCESS`` when done.

    Args:
        input_rosbag: path of the source rosbag.
        output_rosbag: path of the rosbag to create.
    '''
    LOG.info('Start generating rosbag.')
    gnss_json = Path(desens_dir, 'gnss') / f'{gnss_topic}.json'.strip('/')
    with open(gnss_json, 'r', encoding='utf-8') as gnss_file:
        gnss_data = json.load(gnss_file)
    writer_cls = Writer1 if rosbag_version == '1' else Writer2
    with AnyReader([Path(input_rosbag)]) as reader, \
            writer_cls(Path(output_rosbag)) as writer:
        conn_map = _get_conn_map(rosbag_version, reader, writer)
        for connection, timestamp, data in reader.messages():
            topic = connection.topic
            if topic in image_topics:
                data = _masked_image_payload(topic, timestamp)
            elif topic == gnss_topic:
                data = _masked_gnss_payload(reader, connection, timestamp,
                                            data, gnss_data)
            elif topic in lidar_topics:
                # Previously the desensitized points were assigned but never
                # re-serialized, so the raw point cloud was written out.
                data = _masked_lidar_payload(reader, connection, timestamp,
                                             data)
            if data is None:
                # No desensitized data: drop the message rather than leak raw data.
                LOG.warning('No desensitized data for %s at %s, message dropped.',
                            topic, timestamp)
                continue
            writer.write(conn_map[connection.id], timestamp, data)
    # Create the _SUCCESS marker signalling the output rosbag is complete.
    Path(output_dir, '_SUCCESS').touch()
    LOG.info('Finish generating rosbag.')


if __name__ == "__main__":
    LOG.info('Start user operator.')
    # The env value may arrive as a str; Pool() and range() need an int.
    lidar_workers = int(lidar_process_num)
    process_image = mp.Process(target=extract_image, args=(input_path, ))
    pool_lidar = mp.Pool(processes=lidar_workers)
    lidar_results = [
        pool_lidar.apply_async(extract_lidar,
                               args=(i, lidar_workers, input_path))
        for i in range(lidar_workers)
    ]
    process_gnss = mp.Process(target=extract_gnss, args=(input_path, ))
    # Start the child processes.
    process_image.start()
    pool_lidar.close()
    process_gnss.start()
    process_image.join()
    pool_lidar.join()
    process_gnss.join()
    # apply_async keeps a worker's exception inside its AsyncResult until
    # .get() is called; re-raise it here so a failed shard is not reported as
    # success.
    for result in lidar_results:
        result.get()
    failed = [
        name for name, proc in (('image', process_image),
                                ('gnss', process_gnss))
        if proc.exitcode != 0
    ]
    if failed:
        raise RuntimeError(f'Extraction child processes failed: {failed}')
    LOG.info('Child processes exit.')
    # Create the _SUCCESS marker signalling raw-data extraction is complete.
    Path(raw_dir, '_SUCCESS').touch()

    # The output rosbag keeps the input rosbag's file name.
    output_rosbag_file = Path(output_dir, Path(input_path).name)
    # Create the output directory if it does not exist yet.
    output_rosbag_file.parent.mkdir(parents=True, exist_ok=True)
    # Poll until the desensitization job signals completion, then build the bag.
    desens_success = Path(desens_dir, '_SUCCESS')
    while not desens_success.is_file():
        time.sleep(1)
    generate_rosbag(Path(input_path), output_rosbag_file)
来源:https://support.huaweicloud.com/usermanual-octopus/octopus-15-0008.html