"""
Doubao-Seed3D API implementation for sketch-to-3D model conversion.
"""
import base64
import io
import os
import tempfile
import time
import zipfile
from typing import Optional, Dict, Any, Union
from urllib.parse import urlparse

import numpy as np
import requests
from PIL import Image

try:
    from sketchkit.core.sketch import Sketch
except ImportError:
    Sketch = None
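# (Stray "[docs]" link marker removed here; it appears to be residue from a rendered-documentation export and is not module code.)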
def sketch_to_3d_doubao_impl(
    sketch_input: Union['Sketch', str, np.ndarray, Image.Image],
    output_path: Optional[str] = None,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    **kwargs
) -> str:
    """
    Convert a sketch to a 3D model using the Doubao-Seed3D API.

    The sketch is rasterized to a PNG, sent inline as a base64 data URL when
    creating a generation task, and the task is then polled until it reaches
    a terminal status. On success the result file is downloaded and written
    to ``output_path`` (extracted from a zip archive when the service returns
    one).

    Args:
        sketch_input: Input sketch. May be a Sketch object, an image file
            path (str or os.PathLike), a PIL Image, or a numpy array.
            Float arrays whose maximum is <= 1.0 and boolean arrays are
            scaled by 255; other arrays are cast to uint8 as-is.
        output_path: Output file path. If None, a file named
            ``doubao_seed3d_<task_id>.<file_format>`` is created in the
            system temporary directory.
        api_key: Doubao API key. If None, it is read from the environment
            variable DOUBAO_API_KEY.
        api_base: API base URL (if None, uses the built-in default).
        **kwargs: Additional parameters
            - model_name: Model name (default: "doubao-seed3d-1-0-250928")
            - file_format: Output format (default: "glb")
            - subdivision_level: Subdivision level (default: "medium")
            - max_wait_time: Maximum wait time in seconds (default: 600)
            - poll_interval: Polling interval in seconds (default: 5)

    Returns:
        Generated 3D model file path.

    Raises:
        ValueError: If no API key is available, the input array has an
            unsupported shape, the create-task response has no task ID, the
            finished task exposes no result URL, or the downloaded archive
            has no file of the requested format.
        TimeoutError: If the task does not finish within ``max_wait_time``.
        RuntimeError: If the task ends in a failure status, or if any HTTP
            request fails (the original requests exception is chained).
    """
    try:
        raster_image = _doubao_rasterize_input(sketch_input)

        # Resolve API configuration.
        if api_key is None:
            api_key = os.environ.get('DOUBAO_API_KEY')
        if not api_key:
            # Fail early instead of sending "Authorization: Bearer None".
            raise ValueError(
                "Doubao API key is missing: pass api_key or set the "
                "DOUBAO_API_KEY environment variable"
            )

        model_name = kwargs.get('model_name', 'doubao-seed3d-1-0-250928')
        subdivision_level = kwargs.get('subdivision_level', 'medium')
        file_format = kwargs.get('file_format', 'glb')
        max_wait_time = kwargs.get('max_wait_time', 600)
        poll_interval = kwargs.get('poll_interval', 5)
        # Extension used for file names / archive lookup when file_format is falsy.
        extension = file_format or 'glb'

        if api_base is None:
            api_base = "https://ark.cn-beijing.volces.com/api/v3/contents/generations"
        # Tolerate a trailing slash so the joined URLs do not contain "//".
        api_base = api_base.rstrip('/')
        create_task_url = f"{api_base}/tasks"

        # Encode the raster as an in-memory PNG wrapped in a data URL.
        buffer = io.BytesIO()
        Image.fromarray(raster_image).save(buffer, format='PNG')
        image_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
        image_url = f"data:image/png;base64,{image_base64}"

        # Build request payload.
        # Note: doubao-seed3d-1-0 model does not support text prompt, only image
        payload = {
            "model": model_name,
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": image_url
                    }
                }
            ]
        }
        # Try to add parameters to payload (if API supports)
        if subdivision_level:
            payload["subdivision_level"] = subdivision_level
        if file_format:
            payload["file_format"] = file_format

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"
        }

        # Create task
        print("[Doubao Seed3D] Creating task...")
        print(f"[Doubao Seed3D] Model: {model_name}, Format: {file_format}, Subdivision: {subdivision_level}")
        response = requests.post(create_task_url, headers=headers, json=payload, timeout=30)
        response.raise_for_status()
        task_data = response.json()
        task_id = task_data.get('id')
        if not task_id:
            raise ValueError(f"Failed to create task: No task ID found in response. Response: {task_data}")
        print(f"[Doubao Seed3D] Task created, ID: {task_id}")

        # Poll task status. A monotonic clock keeps the elapsed time immune
        # to wall-clock adjustments while waiting.
        query_url = f"{api_base}/tasks/{task_id}"
        started = time.monotonic()
        while True:
            elapsed_time = time.monotonic() - started
            if elapsed_time > max_wait_time:
                raise TimeoutError(f"Task timeout: Wait time exceeded {max_wait_time} seconds")

            print(f"[Doubao Seed3D] Querying task status... (waited {int(elapsed_time)} seconds)")
            response = requests.get(query_url, headers=headers, timeout=30)
            response.raise_for_status()
            task_info = response.json()
            status = task_info.get('status')

            if status in ('completed', 'succeeded'):
                print(f"[Doubao Seed3D] Task completed! Status: {status}")
                result_url = _doubao_extract_result_url(task_info)
                if not result_url:
                    # Print full response for debugging
                    print(f"[Doubao Seed3D] Debug info - Full response: {task_info}")
                    raise ValueError(f"Task completed but result URL not found. Response: {task_info}")

                # Download result file (no auth header: the URL is requested as-is).
                print(f"[Doubao Seed3D] Downloading result file: {result_url}")
                file_response = requests.get(result_url, timeout=60)
                file_response.raise_for_status()
                data = file_response.content

                if output_path is None:
                    output_path = os.path.join(
                        tempfile.gettempdir(), f"doubao_seed3d_{task_id}.{extension}"
                    )

                # The service may return a zip archive. Besides the header and
                # the URL path (query string ignored, so signed URLs work),
                # check the payload's own zip local-file-header signature.
                content_type = file_response.headers.get('content-type', '')
                is_zip = (
                    'zip' in content_type.lower()
                    or urlparse(result_url).path.lower().endswith('.zip')
                    or data[:4] == b'PK\x03\x04'
                )
                return _doubao_save_result(data, output_path, extension, is_zip)

            if status in ('failed', 'error', 'cancelled', 'expired'):
                error_msg = task_info.get('error') or task_info.get('message') or 'Unknown error'
                if status in ('failed', 'error'):
                    raise RuntimeError(f"Task failed: {error_msg}")
                # Terminal but not successful: waiting longer cannot help.
                raise RuntimeError(f"Task {status}: {error_msg}")

            if status not in ('pending', 'processing', 'running', 'queued'):
                # Unknown status, wait and retry
                print(f"[Doubao Seed3D] Warning: Unknown status '{status}', continuing to wait...")
            time.sleep(poll_interval)

    except requests.exceptions.RequestException as e:
        error_msg = str(e)
        err_response = getattr(e, 'response', None)
        # Compare with None explicitly: a Response for a 4xx/5xx is falsy.
        if err_response is not None:
            try:
                error_detail = err_response.json()
                error_msg = f"{error_msg}\nResponse details: {error_detail}"
            except ValueError:
                # Body is not JSON; fall back to the raw text.
                error_msg = f"{error_msg}\nResponse content: {err_response.text}"
        raise RuntimeError(f"Doubao Seed3D API call failed: {error_msg}") from e
    except Exception as e:
        print(f"Doubao Seed3D method failed: {e}")
        import traceback
        traceback.print_exc()
        raise


def _doubao_rasterize_input(sketch_input: Any) -> np.ndarray:
    """Normalize any supported sketch input to a uint8 image array of shape (H, W, C)."""
    if Sketch is not None and isinstance(sketch_input, Sketch):
        # Imported only when a Sketch object is actually passed in.
        from sketchkit.renderer.cairo_renderer import CairoRenderer
        raster = np.asarray(CairoRenderer().render(sketch_input))
        if raster.ndim == 3:
            # Keep at most the first three channels (drops alpha if present).
            raster = raster[..., :3]
        if raster.dtype != np.uint8:
            # NOTE(review): assumes non-uint8 renderer output is in [0, 1] - confirm.
            raster = (raster * 255).astype(np.uint8)
    elif isinstance(sketch_input, (str, os.PathLike)):
        # Context manager releases the file handle once pixels are loaded.
        with Image.open(sketch_input) as img:
            raster = np.array(img.convert('RGB'), dtype=np.uint8)
    elif isinstance(sketch_input, Image.Image):
        raster = np.array(sketch_input.convert('RGB'), dtype=np.uint8)
    else:
        # Assume array-like input.
        raster = np.asarray(sketch_input)
        if raster.dtype == np.bool_:
            # A plain cast would give 0/1, i.e. an almost black image.
            raster = raster.astype(np.uint8) * 255
        elif (
            np.issubdtype(raster.dtype, np.floating)
            and raster.size
            and raster.max() <= 1.0
        ):
            # Treat floats as normalized intensities, like the Sketch branch.
            raster = np.clip(raster, 0.0, 1.0) * 255
        raster = raster.astype(np.uint8)

    if raster.ndim == 3 and raster.shape[2] == 1:
        # (H, W, 1) is not accepted by Image.fromarray; treat as grayscale.
        raster = raster[..., 0]
    if raster.ndim == 2:
        # Grayscale to RGB
        raster = np.stack([raster] * 3, axis=-1)
    if raster.ndim != 3:
        raise ValueError(
            f"Unsupported sketch image shape {raster.shape}; expected (H, W) or (H, W, C)"
        )
    return raster


def _doubao_first_url(node: Any, keys) -> Optional[str]:
    """Return the first truthy value under *keys* in a dict, or in the dicts of a list."""
    candidates = node if isinstance(node, list) else [node]
    for candidate in candidates:
        if not isinstance(candidate, dict):
            continue
        for key in keys:
            value = candidate.get(key)
            if value:
                return value
    return None


def _doubao_extract_result_url(task_info: Dict[str, Any]) -> Optional[str]:
    """Find the result file URL in a finished task response, trying several known layouts."""
    # Top-level fields first.
    url = _doubao_first_url(task_info, ('result_url', 'output_url'))
    if url:
        return url

    # 'content' field (dict, or a list of dicts).
    url = _doubao_first_url(task_info.get('content'), ('file_url', 'url', 'result_url'))
    if url:
        return url

    # 'result' / 'output' may be a dict or the URL string itself.
    result = task_info.get('result') or task_info.get('output')
    if isinstance(result, str):
        return result
    url = _doubao_first_url(result, ('url', 'file_url'))
    if url:
        return url

    # 'contents' field (plural form).
    return _doubao_first_url(task_info.get('contents'), ('url', 'file_url', 'result_url'))


def _doubao_save_result(data: bytes, output_path: str, file_format: str, is_zip: bool) -> str:
    """Write the downloaded payload to output_path, extracting it from a zip if needed."""
    out_dir = os.path.dirname(output_path) or '.'
    os.makedirs(out_dir, exist_ok=True)

    if not is_zip:
        # Save file directly
        with open(output_path, 'wb') as f:
            f.write(data)
        print(f"[Doubao Seed3D] File saved: {output_path}")
        return output_path

    print("[Doubao Seed3D] Detected zip file, extracting...")
    target_ext = f'.{file_format}'
    # Read the archive straight from memory; no temporary file to clean up.
    with zipfile.ZipFile(io.BytesIO(data)) as archive:
        members = archive.namelist()
        wanted = target_ext.lower()
        found_file = next((name for name in members if name.lower().endswith(wanted)), None)
        if found_file is None:
            # Best effort: leave the archive contents next to the requested
            # output so they can be inspected, then report the failure.
            # (Scanning the directory afterwards could only pick up an
            # unrelated pre-existing file, so that is not attempted.)
            print(f"[Doubao Seed3D] {target_ext} file not found, extracting all files...")
            archive.extractall(out_dir)
            raise ValueError(
                f"Extracted files do not contain {target_ext} format file "
                f"(archive members: {members}; extracted to: {out_dir})"
            )
        with open(output_path, 'wb') as f:
            f.write(archive.read(found_file))
    print(f"[Doubao Seed3D] File extracted and saved: {output_path}")
    return output_path