ApostolosK commited on
Commit
8322740
·
verified ·
1 Parent(s): 4817013

Update svm_vgg_preprocessor.py

Browse files
Files changed (1) hide show
  1. svm_vgg_preprocessor.py +24 -16
svm_vgg_preprocessor.py CHANGED
@@ -1,21 +1,23 @@
 
1
  import torch
2
  import numpy as np
3
  from torchvision import transforms
4
  from torchvision.models import vgg16
5
  from PIL import Image
6
 
7
- # Initialize VGG model (matches your training setup)
8
  class FeatureExtractor:
9
  def __init__(self):
10
- self.vgg_model = vgg16(pretrained=True)
11
- self.vgg_model.eval() # Set to evaluation mode
 
12
 
13
- # For FC-CNN features (classifier-based)
14
- self.fc_extractor = torch.nn.Sequential(
15
- *list(self.vgg_model.classifier.children())[:-2] # Remove last 2 layers
 
16
  )
17
 
18
- # Standard VGG preprocessing
19
  self.preprocess = transforms.Compose([
20
  transforms.Resize((224, 224)),
21
  transforms.ToTensor(),
@@ -26,22 +28,28 @@ class FeatureExtractor:
26
  ])
27
 
28
  def extract_fc_cnn_features(self, image_path):
29
- """Extract fully-connected layer features"""
30
- image = Image.open(image_path).convert('RGB')
31
- image_tensor = self.preprocess(image).unsqueeze(0)
 
32
 
33
  with torch.no_grad():
34
- features = self.fc_extractor(image_tensor)
 
 
 
 
 
35
 
36
- return features.squeeze().numpy().flatten()
37
 
38
  def extract_fv_cnn_features(self, image_path):
39
- """Extract convolutional layer features"""
40
- image = Image.open(image_path).convert('RGB')
41
- image_tensor = self.preprocess(image).unsqueeze(0)
42
 
43
  with torch.no_grad():
44
- conv_features = self.vgg_model.features(image_tensor)
45
 
46
  return conv_features.squeeze().numpy().flatten()
47
 
 
1
+ # svm_vgg_preprocessor.py
2
  import torch
3
  import numpy as np
4
  from torchvision import transforms
5
  from torchvision.models import vgg16
6
  from PIL import Image
7
 
 
8
  class FeatureExtractor:
9
  def __init__(self):
10
+ # Load pretrained VGG16
11
+ self.vgg = vgg16(weights='DEFAULT')
12
+ self.vgg.eval()
13
 
14
+ # Feature extractors
15
+ self.conv_features = self.vgg.features
16
+ self.fc_features = torch.nn.Sequential(
17
+ *list(self.vgg.classifier.children())[:-2]
18
  )
19
 
20
+ # Preprocessing
21
  self.preprocess = transforms.Compose([
22
  transforms.Resize((224, 224)),
23
  transforms.ToTensor(),
 
28
  ])
29
 
30
  def extract_fc_cnn_features(self, image_path):
31
+ """Correct FC-CNN feature extraction"""
32
+ # Load and preprocess image
33
+ img = Image.open(image_path).convert('RGB')
34
+ img_tensor = self.preprocess(img).unsqueeze(0)
35
 
36
  with torch.no_grad():
37
+ # Get convolutional features
38
+ conv_out = self.conv_features(img_tensor)
39
+ # Flatten for FC layers
40
+ flattened = torch.flatten(conv_out, 1)
41
+ # Get FC features
42
+ fc_features = self.fc_features(flattened)
43
 
44
+ return fc_features.squeeze().numpy().flatten()
45
 
46
  def extract_fv_cnn_features(self, image_path):
47
+ """FV-CNN feature extraction (unchanged)"""
48
+ img = Image.open(image_path).convert('RGB')
49
+ img_tensor = self.preprocess(img).unsqueeze(0)
50
 
51
  with torch.no_grad():
52
+ conv_features = self.conv_features(img_tensor)
53
 
54
  return conv_features.squeeze().numpy().flatten()
55