返回列表 发布新帖
查看: 720|回复: 9

让spark机器FOLLOW

11

主题

17

回帖

231

积分

版主

积分
231
发表于 2024-7-5 19:05:01 | 查看全部 |阅读模式
这段代码是一个ROS(Robot Operating System)节点,主要用于检测图像中的物体,并结合机器人跟随功能。以下是代码的主要部分和功能:
1. 导入所需的库,包括TensorFlow、OpenCV、ROS消息类型等。
2. 设置模型和标签地图的路径。这里使用了SSD(Single Shot MultiBox Detector)MobileNet模型,针对COCO(Common Objects in Context)数据集训练。
3. 定义一个`detector`类,该类包含一个图像回调函数`image_cb`,用于处理订阅的图像消息。在回调函数中,使用TensorFlow模型对图像进行物体检测,并将检测结果发布出去。
4. 在`image_cb`函数中,首先将ROS图像消息转换为OpenCV图像,然后使用TensorFlow模型进行物体检测,最后将检测结果发布出去。
5. `object_predict`函数用于根据TensorFlow模型的输出预测物体的类别和位置。
6. `main`函数初始化ROS节点,并创建一个`detector`对象来处理图像消息。
7. 最后,定义了一个ROS launch文件,用于启动该节点和其他相关节点。
此外,代码还包括了机器人跟随功能的部分参考代码,如`cal_center`函数用于计算物体中心的坐标,`check_inbox`函数用于检查物体是否在特定的区域内,`start_following`函数用于根据检测到的物体位置控制机器人的运动。

这个节点可以用于机器人视觉导航、物体识别和跟随等应用。它订阅了一个图像话题,使用TensorFlow模型进行物体检测,并将检测结果发布到多个话题。物体检测结果包括物体的类别、位置和置信度分数。


#!/usr/bin/env python
## Author: Rohit
## Date: July, 25, 2017
# Purpose: Ros node to detect objects using tensorflow
# 加载必要的功能函数,包括tensorflow       
import os
import sys
import cv2
import numpy as np
try:
    import tensorflow as tf
except ImportError:
    print("unable to import TensorFlow. Is it installed?")
    print("  sudo apt install python-pip")
    print("  sudo pip install tensorflow")
    sys.exit(1)
# 加载ros相关message文件
import rospy
from std_msgs.msg import String , Header
from sensor_msgs.msg import Image
from cv_bridge import CvBridge, CvBridgeError
from vision_msgs.msg import Detection2D, Detection2DArray, ObjectHypothesisWithPose
from dl_msgs.msg import DetectionArray, ObjectInfo
# 加载物体识别的功能包
import object_detection
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util
# 设置GPU占用率如果利用GPU进行学习
GPU_FRACTION = 0.4
######### 设置学习的模型 ############
MODEL_NAME =  'ssd_mobilenet_v1_coco_11_06_2017'
# By default models are stored in data/models/
MODEL_PATH = os.path.join(os.path.dirname(sys.path[0]),'data','models' , MODEL_NAME)
# Path to frozen detection graph. This is the actual model that is used for the object detection.
PATH_TO_CKPT = MODEL_PATH + '/frozen_inference_graph.pb'
######### 设置物体标签地图 ###########
LABEL_NAME = 'mscoco_label_map.pbtxt'
# By default label maps are stored in data/labels/
PATH_TO_LABELS = os.path.join(os.path.dirname(sys.path[0]),'data','labels', LABEL_NAME)
######### Set the number of classes here #########
NUM_CLASSES = 90
detection_graph = tf.Graph()
with detection_graph.as_default():
  od_graph_def = tf.GraphDef()
  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
    serialized_graph = fid.read()
    od_graph_def.ParseFromString(serialized_graph)
    tf.import_graph_def(od_graph_def, name='')
## 加载标签地图
# 标签地图通过序号索引到物体名称,所以当卷积网络预测到数字5时,我们知道是对应着’airplane’
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)
# 设置 GPU 选项来 使用我们之前设置的GPU占用率
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = GPU_FRACTION
# 开始识别
with detection_graph.as_default():
  with tf.Session(graph=detection_graph,config=config) as sess:
    class detector:
      def __init__(self):
        self.image_pub = rospy.Publisher("debug_image",Image, queue_size=1)
        self.object_pub = rospy.Publisher("objects", Detection2DArray, queue_size=1)
                    self.obj_info_pub = rospy.Publisher("objects_info", ObjectInfo, queue_size=1)
        self.objectinfo_list_pub = rospy.Publisher("object_info_list", DetectionArray, queue_size=1)
        self.bridge = CvBridge()
        self.image_sub = rospy.Subscriber("image", Image, self.image_cb, queue_size=1, buff_size=2**24)
#图像回调
      def image_cb(self, data):
        objArray = Detection2DArray()
        objDetection = DetectionArray()
        obj_info = ObjectInfo()
        try:
          cv_image = self.bridge.imgmsg_to_cv2(data, "bgr8")
        except CvBridgeError as e:
          print(e)
        image=cv2.cvtColor(cv_image,cv2.COLOR_BGR2RGB)
       # 基于数组的图像会在后面被用作带有方框和标签的图像结果
        image_np = np.asarray(image)
        # 把维度展开因为模型认为图像有格式[1,None,None,3]
        image_np_expanded = np.expand_dims(image_np, axis=0)
        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        # 每个方框boxes都代表着被识别到的物体的范围
        boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
        # 每个分数score都代表着对识别到的物体的精确度,分数是与方框和名字一起呈现在屏幕上的.
        scores = detection_graph.get_tensor_by_name('detection_scores:0')
        classes = detection_graph.get_tensor_by_name('detection_classes:0')
        num_detections = detection_graph.get_tensor_by_name('num_detections:0')
        (boxes, scores, classes, num_detections) = sess.run([boxes, scores, classes, num_detections],
            feed_dict={image_tensor: image_np_expanded})
        objects=vis_util.visualize_boxes_and_labels_on_image_array(
            image,
            np.squeeze(boxes),
            np.squeeze(classes).astype(np.int32),
            np.squeeze(scores),
            category_index,
            use_normalized_coordinates=True,
            line_thickness=2)
        # 定义识别到的物体信息
        objArray.detections =[]
        obj_info = ObjectInfo()
        objDetection.objectinfo = []
      
        objArray.header=data.header
        object_count=1
        for i in range(len(objects)):
          object_count+=1
          obj, obj_info = self.object_predict(objects,data.header,image_np,cv_image)
          objArray.detections.append(obj)
          objDetection.objectinfo.append(obj_info)
        self.obj_info_pub.publish(obj_info)
        self.objectinfo_list_pub.publish(objDetection)
        self.object_pub.publish(objArray)
        img=cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
        image_out = Image()
        try:
          image_out = self.bridge.cv2_to_imgmsg(img,"bgr8")
        except CvBridgeError as e:
          print(e)
        image_out.header = data.header
        self.image_pub.publish(image_out)
      def object_predict(self,object_data, header, image_np,image):
        image_height,image_width,channels = image.shape
        obj=Detection2D()
        obj_info=ObjectInfo()
        obj_hypothesis= ObjectHypothesisWithPose()
        object_id=object_data[0]
        object_score=object_data[1]
        dimensions=object_data[2]
                 object_name=category_index[object_id][‘name’]
        obj.header=header
        obj_hypothesis.id = object_id
        obj_hypothesis.score = object_score
        obj.results.append(obj_hypothesis)
        obj.bbox.size_y = int((dimensions[2]-dimensions[0])*image_height)
        obj.bbox.size_x = int((dimensions[3]-dimensions[1])*image_width)
        obj.bbox.center.x = int((dimensions[1] + dimensions[3])*image_height/2)
        obj.bbox.center.y = int((dimensions[0] + dimensions[2])*image_width/2)
def main(args):
  rospy.init_node('detector_node')
  obj=detector()
  try:
    rospy.spin()
  except KeyboardInterrupt:
    print("ShutDown")
  cv2.destroyAllWindows()
if __name__=='__main__':
  main(sys.argv)

def cal_center(self, box):
        xc = int((box[0] + box[2]) / 2)
        yc = int((box[0] + box[2]) / 2)
        return xc, yc
    # 检查是否在原位
    def check_inbox(self):
        global xc, yc, xc_prev, yc_prev
        inbox = False
        if abs(xc_prev - xc) < 40 and abs(yc_prev - yc) < 40:
            xc_prev, yc_prev = xc, yc
            inbox = True
        return inbox
        # 跟随代码
    def start_following(self, pointcloud, x_range, y_range, xc):
        cmd_pub = Twist()
        x = 0
        z = 0
        n = 0
        p = []
        # get the range of detected box
        for i in x_range:
            for j in y_range:
                pc = point_cloud2.read_points(pointcloud, field_names=("x","y","z"), skip_nans=True, uvs=[[i, j]])
                for p in pc:
                    if p[2] <= 2.5:
                        x += -p[0]
                        z += p[2]
                        n += 1
        if not p or n == 0:
            x_linear = 0
            z_angular = 0
        elif abs(xc - IMAGE_WIDTH / 2) < 20:
            z /= n
            x_linear = z - GOAL_DEPTH
            z_angular = 0
        else:
            x /= n
            z /= n
            rospy.loginfo("n: " + str(n) + ",x: " + str(x) + ",z: " + str(z))
            dist = math.sqrt(x * x + z * z)
            x_linear = z - GOAL_DEPTH
            z_angular = math.asin(x / dist)
            if abs(x_linear) < DEPTH_THRESHOLD:
                x_linear = 0
            elif x_linear > 1.2:
                x_linear = 1
            if z_angular > TURN_THRESHOLD:
                z_angular = TURN_THRESHOLD
            elif z_angular < -TURN_THRESHOLD:
                z_angular = -TURN_THRESHOLD
        cmd_pub.linear.x = x_linear
        cmd_pub.angular.z = z_angular
        rospy.loginfo(cmd_pub)
        self.pub_cmd.publish(cmd_pub)
        rospy.sleep(5)
        #发布消息
    def publish(self, detections, image_outgoing):
        """
        - publish detection information and drawn image
        """
        self.pub_det.publish(detections)
        self.pub_det_rgb.publish(image_outgoing)

设置启动文件
!-- 启动文件 第三课:人脸识别主人跟随 -->
<launch>
        <arg name="camera_types"        default="astrapro"/>
         <!--spark 驱动,机器人描述,相机,底盘-->
          <include file="$(find spark_bringup)/launch/driver_bringup.launch">
                  <arg name="camera_type_tel"        value="$(arg camera_types)" />
          </include>
        <node pkg= "tensorflow_object_detector" name="detect_ros" type="detect_ros_DLunit3.py"  output="screen">
                    <remap from="image" to="/camera/color/image_raw" if="$(eval arg('camera_types')=='d435')"/>
                <remap from="image" to="/camera/rgb/image_raw" if="$(eval arg('camera_types')=='astrapro')"/>
        </node>
        <arg name="master_name" value=”any_name”/>
        <node pkg="unit3_follow" name="face_recognition_node" type="face_recognizer.py" output="screen">
                <param name="master_name"  value="$(arg master_name)"/>
                <remap from="image" to="/camera/rgb/image_raw" />
        </node>
        <!-- rviz -->
          <arg name ="rviz" default="true" />
          <arg name ="rviz_file" default="$(find unit3_follow)/rviz/display.rviz"/>
        <node pkg ="rviz" type="rviz" name="rviz" output="screen" args= "-d $(arg rviz_file)" if="$(arg rviz)"/>
</launch>   


仅供参考,spark机器人跟随的部分参考代码

回复

举报

0

主题

11

回帖

36

积分

新手上路

积分
36
发表于 2024-7-6 17:42:11 | 查看全部
牛牛牛,很有创新
回复

举报

0

主题

7

回帖

18

积分

新手上路

积分
18
发表于 2024-7-6 17:47:43 | 查看全部
加油很有参考意义
回复

举报

11

主题

17

回帖

231

积分

版主

积分
231
 楼主| 发表于 2024-7-6 21:06:16 | 查看全部
刘子康 发表于 2024-7-6 17:42
牛牛牛,很有创新

谢谢
回复

举报

11

主题

17

回帖

231

积分

版主

积分
231
 楼主| 发表于 2024-7-6 21:06:46 | 查看全部
陈厚树 发表于 2024-7-6 17:47
加油很有参考意义

谢谢
回复

举报

0

主题

8

回帖

30

积分

新手上路

积分
30
发表于 2024-7-6 21:46:34 | 查看全部
很新颖的观点
回复

举报

0

主题

10

回帖

28

积分

新手上路

积分
28
发表于 2024-7-13 15:24:13 | 查看全部
继续加油,很有帮助
回复

举报

0

主题

11

回帖

30

积分

新手上路

积分
30
发表于 2024-7-13 15:31:50 | 查看全部
这个功能很新颖
回复

举报

0

主题

12

回帖

40

积分

新手上路

积分
40
发表于 2024-7-13 15:54:03 | 查看全部
这个功能很有趣
回复

举报

0

主题

10

回帖

36

积分

新手上路

积分
36
发表于 2024-7-13 15:57:59 | 查看全部
很有创新和有趣的功能
回复

举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

关灯 在本版发帖
扫一扫添加微信客服
返回顶部
快速回复 返回顶部 返回列表