머신러닝

[Tensorflow Object Detection API] custom data로 재학습(retraining) 시키기 3 - training

skkim1080 2020. 5. 13. 17:28

 

1. pre-trained model 다운로드 하기

 

저는 ssd_mobilenet_v1_coco_11_06_2017 모델을 다운받았습니다. 

 

object_detection> wget http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_11_06_2017.tar.gz

 

 

2. config파일 수정

 

먼저 object_detection폴더에 training 폴더를 만들어주세요.

그 후 object_detection/samples/configs 안에 있는 ssd_mobilenet_v1_pets.config 파일을 training 폴더로 복사해주세요. 

그리고 config파일의 내용을 수정해야합니다. 

 

num_classes는 detection할 label의 개수로 지정해주세요.

 

 

3. object-detecion.pbtxt 파일 작성

 

object_detection/data 폴더에 object-detection.pbtxt파일을 만들어 줍니다.

label1, 2에 원하는 라벨 이름으로 작성해주세요.

item{
  id:1
  name: 'label1'
}

item{
  id:2
  name: 'label2'
}

 

 

4. Training

 

이제 training할 준비가 모두 되었습니다. 

train.py 파일은 legacy 폴더안에 있습니다. 복사해서 object_detection 폴더에서 바로 사용하셔도 되고 legacy폴더안에서 사용하셔도 됩니다. 저는 legacy 폴더에 둔 채로 사용했습니다. 

object_detection 폴더에서 다음 명령어를 실행해주세요

 

object_detection> python legacy/train.py --logtostderr --train_dir=training/ --pipeline_config_path=training/ssd_mobilenet_v1_pets.config

 

실행하면 cmd창에 step과 loss값이 찍히는 것을 볼 수 있습니다. 

loss가 어느정도 0에 가까워지면 종료해주시면 됩니다. 

 

 

5. tensorboard로 확인하기

 

object_detection> tensorboard --logdir=./training

명령어를 입력하면 loss값을 그래프로 확인할 수 있습니다.