Josh Brown Kramer commited on
Commit
5ae76ab
·
1 Parent(s): 36ad354

More basic example

Browse files
Files changed (2) hide show
  1. app.py +19 -18
  2. requirements.txt +1 -5
app.py CHANGED
@@ -1,27 +1,28 @@
1
  import gradio as gr
2
- import torch
3
- from your_pix2pixhd_code import YourPix2PixHDModel, load_image, tensor2im # Adapt these imports
4
 
5
- # --- 1. Load your pix2pixHD model ---
6
- # You'll need to adapt this part to your specific model loading logic
7
- # This is a simplified example
8
- model = YourPix2PixHDModel()
9
- model.load_state_dict(torch.load('models/your_pix2pixhd_model.pth'))
10
- model.eval()
11
 
12
  # --- 2. Define the prediction function ---
13
  def predict(input_image):
14
- # Pre-process the input image
15
- processed_image = load_image(input_image)
 
16
 
17
- # Run inference
18
- with torch.no_grad():
19
- generated_image_tensor = model(processed_image)
20
 
21
- # Post-process the output tensor to an image
22
- output_image = tensor2im(generated_image_tensor)
23
 
24
- return output_image
25
 
26
  # --- 3. Create the Gradio Interface ---
27
  title = "pix2pixHD Image-to-Image Translation"
@@ -30,8 +31,8 @@ article = "<p style='text-align: center'>Model based on the <a href='https://git
30
 
31
  gr.Interface(
32
  fn=predict,
33
- inputs=gr.Image(type="pil", label="Input Image"),
34
- outputs=gr.Image(type="pil", label="Output Image"),
35
  title=title,
36
  description=description,
37
  article=article,
 
1
  import gradio as gr
2
+ # import torch
3
+ # from your_pix2pixhd_code import YourPix2PixHDModel, load_image, tensor2im # Adapt these imports
4
 
5
+ # # --- 1. Load your pix2pixHD model ---
6
+ # # You'll need to adapt this part to your specific model loading logic
7
+ # # This is a simplified example
8
+ # model = YourPix2PixHDModel()
9
+ # model.load_state_dict(torch.load('models/your_pix2pixhd_model.pth'))
10
+ # model.eval()
11
 
12
  # --- 2. Define the prediction function ---
13
  def predict(input_image):
14
+ return 255 - input_image
15
+ # # Pre-process the input image
16
+ # processed_image = load_image(input_image)
17
 
18
+ # # Run inference
19
+ # with torch.no_grad():
20
+ # generated_image_tensor = model(processed_image)
21
 
22
+ # # Post-process the output tensor to an image
23
+ # output_image = tensor2im(generated_image_tensor)
24
 
25
+ # return output_image
26
 
27
  # --- 3. Create the Gradio Interface ---
28
  title = "pix2pixHD Image-to-Image Translation"
 
31
 
32
  gr.Interface(
33
  fn=predict,
34
+ inputs=gr.Image(type="numpy", label="Input Image"),
35
+ outputs=gr.Image(type="numpy", label="Output Image"),
36
  title=title,
37
  description=description,
38
  article=article,
requirements.txt CHANGED
@@ -1,5 +1 @@
1
- gradio==4.31.0
2
- torch==2.1.0
3
- torchvision==0.16.0
4
- numpy==1.26.4
5
- Pillow==10.2.0
 
1
+ gradio