Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 34 additions & 58 deletions install.py
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might reccomend reconsidering the get-intsalled-version function defaults (no_cache_dir gets passed to the function with true rather than being default) Otherwise good code.

Original file line number Diff line number Diff line change
@@ -1,70 +1,46 @@
import launch
import sys

python = sys.executable
import pkg_resources

def get_installed_version(package_name):
try:
return pkg_resources.get_distribution(package_name).version
except pkg_resources.DistributionNotFound:
return None

def install_package(package_name, version_spec=None, uninstall_first=False, extra_index_url=None, no_cache_dir=False):
package_install_cmd = f"{package_name}{'==' + version_spec if version_spec else ''}"
if extra_index_url:
package_install_cmd += f" --extra-index-url {extra_index_url}"
if no_cache_dir:
package_install_cmd += " --no-cache-dir"

if uninstall_first and launch.is_installed(package_name):
launch.run(["python", "-m", "pip", "uninstall", "-y", package_name], f"removing {package_name}")

launch.run_pip(f"install {package_install_cmd}", package_name, live=True)


def install():
if not launch.is_installed("importlib_metadata"):
launch.run_pip("install importlib_metadata", "importlib_metadata", live=True)
from importlib_metadata import version

if launch.is_installed("tensorrt"):
if not version("tensorrt") == "9.0.1.post11.dev4":
launch.run(
["python", "-m", "pip", "uninstall", "-y", "tensorrt"],
"removing old version of tensorrt",
)
# TensorRT installation or upgrade
tensorrt_version = get_installed_version("tensorrt")
if not tensorrt_version or tensorrt_version != "9.3.0.post12.dev1":
# nvidia-cudnn-cu11 installation
if launch.is_installed("nvidia-cudnn-cu12") and get_installed_version("nvidia-cudnn-cu12") != "8.9.6.50":
install_package("nvidia-cudnn-cu12", "8.9.6.50", uninstall_first=True, no_cache_dir=True)
install_package("tensorrt", "9.3.0.post12.dev1", uninstall_first=True, extra_index_url="https://pypi.nvidia.com", no_cache_dir=True)

if not launch.is_installed("tensorrt"):
print("TensorRT is not installed! Installing...")
launch.run_pip(
"install nvidia-cudnn-cu11==8.9.4.25 --no-cache-dir", "nvidia-cudnn-cu11"
)
launch.run_pip(
"install --pre --extra-index-url https://pypi.nvidia.com tensorrt==9.0.1.post11.dev4 --no-cache-dir",
"tensorrt",
live=True,
)
launch.run(
["python", "-m", "pip", "uninstall", "-y", "nvidia-cudnn-cu11"],
"removing nvidia-cudnn-cu11",
)

if launch.is_installed("nvidia-cudnn-cu11"):
if version("nvidia-cudnn-cu11") == "8.9.4.25":
launch.run(
["python", "-m", "pip", "uninstall", "-y", "nvidia-cudnn-cu11"],
"removing nvidia-cudnn-cu11",
)

# Polygraphy
# Polygraphy installation
if not launch.is_installed("polygraphy"):
print("Polygraphy is not installed! Installing...")
launch.run_pip(
"install polygraphy --extra-index-url https://pypi.ngc.nvidia.com",
"polygraphy",
live=True,
)
install_package("polygraphy", extra_index_url="https://pypi.ngc.nvidia.com", no_cache_dir=True)

# ONNX GS
# ONNX Graph Surgeon installation
if not launch.is_installed("onnx_graphsurgeon"):
print("GS is not installed! Installing...")
launch.run_pip("install protobuf==3.20.2", "protobuf", live=True)
launch.run_pip(
"install onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com",
"onnx-graphsurgeon",
live=True,
)
install_package("protobuf", "3.20.3", no_cache_dir=True)
install_package("onnx-graphsurgeon", extra_index_url="https://pypi.ngc.nvidia.com", no_cache_dir=True)

# OPTIMUM
# Optimum installation
if not launch.is_installed("optimum"):
print("Optimum is not installed! Installing...")
launch.run_pip(
"install optimum",
"optimum",
live=True,
)

install_package("optimum", no_cache_dir=True)

install()
install()
6 changes: 3 additions & 3 deletions scripts/trt.py
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Against whitespace changes

Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def get_profile_idx(self, p, model_name: str, model_type: ModelType) -> (int, in
) # TODO: max_embedding, just ignore?
if len(valid_models) == 0:
gr.Error(
f"""No valid profile found for ({model_name}) LOWRES. Please go to the TensorRT tab and generate an engine with the necessary profile.
f"""No valid profile found for ({model_name}) LOWRES. Please go to the TensorRT tab and generate an engine with the necessary profile.
If using hires.fix, you need an engine for both the base and upscaled resolutions. Otherwise, use the default (torch) U-Net."""
)
return None, None
Expand All @@ -177,7 +177,7 @@ def get_profile_idx(self, p, model_name: str, model_type: ModelType) -> (int, in
) # TODO: max_embedding
if len(valid_models_hr) == 0:
gr.Error(
f"""No valid profile found for ({model_name}) HIRES. Please go to the TensorRT tab and generate an engine with the necessary profile.
f"""No valid profile found for ({model_name}) HIRES. Please go to the TensorRT tab and generate an engine with the necessary profile.
If using hires.fix, you need an engine for both the base and upscaled resolutions. Otherwise, use the default (torch) U-Net."""
)
merged_idx = [i for i, id in enumerate(idx) if id in idx_hr]
Expand Down Expand Up @@ -247,7 +247,7 @@ def process(self, p, *args):

if not sd_unet_option.model_name == p.sd_model_name:
gr.Error(
"""Selected torch model ({}) does not match the selected TensorRT U-Net ({}).
"""Selected torch model ({}) does not match the selected TensorRT U-Net ({}).
Please ensure that both models are the same or select Automatic from the SD UNet dropdown.""".format(
p.sd_model_name, sd_unet_option.model_name
)
Expand Down