An interpretable classifier for high-resolution breast cancer screening images utilizing weakly supervised localization [GMIC]
- paper
An interpretable classifier for high-resolution breast cancer screening images utilizing weakly supervised localization
Medical images differ from natural images in significantly higher resolutions and smaller regions of interest. Because of these differences, neural ne…
www.sciencedirect.com
- code
GitHub - nyukat/GMIC: An interpretable classifier for high-resolution breast cancer screening images utilizing weakly supervised
An interpretable classifier for high-resolution breast cancer screening images utilizing weakly supervised localization - GitHub - nyukat/GMIC: An interpretable classifier for high-resolution breas...
github.com
- Abstrat
- Medical images는 natural image와 다르게 해상도는 더 크고 관심 영역은 더 작다는 점에서 기존에 natural image에 잘 작동하는 neural network는 medical image 분석에 적합하지 않다.
- 이런 특성을 다루기 위해 globally-aware multiple instance classifier(GMIC) 사용한다.
- GMIC는 다음과 같이 3가지 모듈로 이루어진다.
1. low-capacity, memory-efficient, 가장 정보가 많은 부분을 구별하기 위해 whole image를 보는 network (global module)
2. higher-capacity, 선택된 정보가 많은 부분으로부터 디테일을 모으기 위한 network (local module)
3. global & local information을 합쳐 마지막 prediction을 만드는 fusion module
- 다른 기존 모델이 학습에 lesion segmentation을 필요로 하는 것과 다르게 오직 image-level labels로 학습되고 악성 가능성을 나타내는 pixel-level의 saliency map을 생성할 수 있다.
- weakly supervised localization..
- 유방암 데이터셋 (NYU Breast Cancer Screening Dataset)을 사용하였으며, AUC=0.93으로 resnet-34, Faster R-CNN 보다 성능이 높다.
- Introduction
- 방사선사 vs. AI model 비교한 결과 AI auc가 더 높음
- GMIC와 faster R-CNN & ResNet-34를 앙상블하는 것이 성능을 얼마나 높일 수 있는지 실험해 봄
- Segmentation labels로도 GMIC 모델 돌려봄.
→ 두 실험 결과 성능 개선이 미미했으며(marginal) large training set에서는 image-level labels 만으로도 충분히 괜찮은 성능이 나옴
- Methods
- gray scale image 사용 : (H, W, 1)
- image-level label : benign (0) / malignant (1)
▶ Globally-Aware Classification Framework
- 방사선사가 진단하는 것과 유사한 구조로 이루어짐 (global, local)
1. Global module
- global network fg를 사용하여 input image x로부터 feature map hg를 추출
- 방사선사가 처음엔 전체 이미지를 대충 스캔하는 것과 유사 (roughly scanning)
- 1x1 conv layer와 sigmoid 활성화함수를 적용하여 feature map hg를 두 개의 saliency maps A로 변환 (A benign, A malignant)
- saliency maps로 benign과 malignant의 대략적인 위치를 보여줌
- saliency maps(A)에서 각 요소는 label을 예측하기 위한 (i, j) 공간 위치에서의 기여도를 나타냄
*saliency map : conv layer가 어디에 집중해서 결정을 내렸는지 보여주는 시각화 기술 (grad-cam도 여기에 포함됨)
*1x1 conv : 1x1 conv는 우리가 원하는 channel 수를 가질 수 있도록 함(네트워크가 깊어짐). 즉, 논문에서는 각 label에 대한 saliency map을 보고 싶기 때문에 hg를 2channel로 차원 축소.
- gpu 메모리 제한으로 이전 논문에서는 이미지를 down-sampling하고, global network의 복잡도를 줄임.
- 그 결과, 병변 마진과 같은 중요한 visual details를 왜곡하고, 이미지에 포함된 미묘한 패턴을 capture할 수 없음
→ high-capacity local network 사용 (to extract fine-grained details from a set of informative regions)
- A를 사용하여 input image x로부터 K개의 informative patches를 검색함
*retrieve_roi : heuristic patch-selection procedure (뒤에 설명)
2. Local module
- 선택된 patches xk들은 local network fl를 지나 fine-grained visual features hk가 추출됨
- aggregator fa(=attention module?)를 통해 vector z (=attention-weighted representation)로 합쳐짐
3. Fusion module
- global structure hg + local details z -> prediction y^
▶ Model Parameterization
1. Generating the Saliency Maps.
- 계산량이 적으면서 고해상도를 다루기 위해 global network fg를 ResNet-22로 매개변수화
→ 기준 ResNet과 비교하여 ResNet-22는 residual block이 하나 더 있고, 각 conv layer의 필터는 1/4로 더 적다. 따라서 더 깊어진 CNN으로 고해상도 이미지에서 복잡한 features를 capture할 수 있고, 좁아진 network는 hidden units의 개수를 줄여 계산량을 감소시킨다.
- label y는 위치 정보가 없기 때문에 Saliency maps A와 y를 직접 비교하는 것은 어려움 (pooling - fc layer, flatten의 과정이 있어야 함)
- 따라서 aggregation function fagg를 사용하여 A를 image-level class prediction으로 변환하고 이렇게 구해진 yc_global과 y의 loss를 사용하여 fg를 학습시킴
- fagg으로 GAP (global average pooling)와 GMP (global max pooling) 사이 정도의 top t% pooling을 제안함
- GAP는 saliency maps A의 대부분 공간 위치가 배경에 해당하고 훈련 신호를 거의 제공하지 않기 때문에 예측을 희석시킴
- GMP는 하나의 공간 위치를 사용하므로 학습 과정이 느리고 불안정함.
→ 따라서 top t% pooling 사용 (t=hyper-parameter, t=100%=GAP, t=1/hxw=GMP)
2. Acquiring ROI patches. (roi 패치 획득)
- K개의 patches (xk)를 찾기 위해 ‘retrieve_roi’는 greedy algorithm 사용
- 임의의 사각 패치 l마다 계산한 A*의 합이 가장 커지는 bbox를 선택한 다음 input image에 일치하는 위치에다 bbox를 매핑시킴
- 이 때, 각 패치들은 서로 많이 overlap되지 않도록 함
3. Utilizing Information from Patches.
- 이제 뽑아낸 patches (xk)들을 local network fl에 적용할 수 있다.
- Patches의 size=([batch_size, 6, 256, 256])
- 패치 개수 6은 hyper-parameter로써 변경할 수 있음
- 패치를 [batch*6, 1, 256, 256]으로 차원 변경하여 lobal module 입력
- Local network로는 ResNet-18, ResNet-34, ResNet-50 등을 사용
- coarse Saliency maps (=러프한 정보를 가짐)로부터 검색되었기 때문에 각 ROI 패치에 있는 분류와 관련된 정보들이 상당히 다름
- 이런 문제를 다루기 위해 모델이 패치로부터 선택적으로 정보를 통합시킬 수 있는 Gated Attention Mechanism(GA)을 사용
- 즉, 각각의 패치가 가진 분류와 관련된 정보가 상당히 다르므로 분류와 관련된 정보가 높은 패치에 attention을 주겠다.
- 다른 attention 메커니즘과 다르게 GA는 sigmoid 함수를 사용해서 비선형성을 증가시킨다고 함
- 각 패치마다 계산된 Attention score ak와 feature vector hk를 모두 더하면 attention-weighted representation z로 표현됨
- z는 fc layer와 sigmoid 활성화 함수를 지나 yc_local을 생성하고, yc_local과 y의 loss를 사용하여 fl을 학습시킴
4. Information Fusion
- Saliency maps와 ROI patches로의 정보를 합치기 위해 "global max pooling을 적용한 hg"와 "z"를 concatenate 함.
- 합쳐진 representation은 fc layer, sigmoid를 거쳐 final prediction을 생성
▶ Learning the parameters of GMIC
- Saliency maps Ac에 L1 regularization 적용하여 sparser하게 만듦 (중요한 부분만 강조하게)
- 모델은 다음 loss function(BCE)과 SGD를 사용하여 복잡한 위 framework를 end-to-end로 학습
- global, local, fusion 모두 한번에 학습
*L1 regularization (=cam_loss) 코드는 github 참고
▽ Results
Table 2:
- Global and local features를 합치는 것의 효율성을 보기 위해 4가지 조합의 AUC를 비교해본 결과, yc_fusion이 가장 성능이 좋음
Table 4:
- saliency maps의 aggregation function의 효과를 비교하기 위해 GMP, GAP, t=1~20% 파라미터를 비교해본 결과, t=2~10%로 설정하는 것이 좋음
Fig. 10:
- ROI patches의 개수 K를 [1,2,3,4,6,8,10] 준 결과, k>3일 때 성능이 최고점에 이름
▽ Attention module
- sigmoid를 사용하여 비선형성 증가
- K개의 patches 개수만큼 softmax 값이 출력됨