{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# JASCO\n", "Welcome to JASCO's demo Jupyter notebook.\n", "Here you will find a self-contained example of how to use JASCO for temporally controlled music generation.\n", "\n", "You can choose a model from the following selection:\n", "1. facebook/jasco-chords-drums-400M - 10s music generation conditioned on text, chords and drums, 400M parameters\n", "2. facebook/jasco-chords-drums-1B - 10s music generation conditioned on text, chords and drums, 1B parameters\n", "\n", "\n", "First, we start by initializing the JASCO model:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "from audiocraft.models import JASCO\n", "\n", "# resolve the chord-to-index mapping to an absolute path and pass it to the loader\n", "chords_mapping_path = os.path.abspath('./assets/chord_to_index_mapping.pkl')\n", "model = JASCO.get_pretrained('facebook/jasco-chords-drums-1B', chords_mapping_path=chords_mapping_path)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, let us configure the generation parameters. Specifically, you can control the following:\n", "* `cfg_coef_all` (float, optional): Coefficient used for classifier-free guidance - fully conditional term. \n", " Defaults to 5.0.\n", "* `cfg_coef_txt` (float, optional): Coefficient used for classifier-free guidance - additional text conditional term. 
\n", " Defaults to 0.0.\n", "\n", "Any parameter you do not pass falls back to its default value. The cell below sets `cfg_coef_all=0.0` and `cfg_coef_txt=5.0`, so only the text-conditional guidance term is active, which suits the text-only example that follows." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "model.set_generation_params(\n", " cfg_coef_all=0.0,\n", " cfg_coef_txt=5.0\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we can go ahead and start generating music given textual prompts." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Text-conditional Generation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from audiocraft.utils.notebook import display_audio\n", "\n", "# set textual prompt\n", "text = \"Funky groove with electric piano playing blue chords rhythmically\"\n", "\n", "# run the model\n", "print(\"Generating...\")\n", "output = model.generate(descriptions=[text], progress=True)\n", "\n", "# display the result\n", "print(f\"Text: {text}\\n\")\n", "display_audio(output, sample_rate=model.compression_model.sample_rate)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can start adding temporal controls! 
We begin with conditioning on chord progressions:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Chords-conditional Generation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from audiocraft.utils.notebook import display_audio\n", "\n", "model.set_generation_params(\n", " cfg_coef_all=1.5,\n", " cfg_coef_txt=2.5\n", ")\n", "\n", "# set textual prompt\n", "text = \"Strings, woodwind, orchestral, symphony.\"\n", "\n", "# define chord progression\n", "chords = [('C', 0.0), ('D', 2.0), ('F', 4.0), ('Ab', 6.0), ('Bb', 7.0), ('C', 8.0)]\n", "\n", "# run the model\n", "print(\"Generating...\")\n", "output = model.generate_music(descriptions=[text], chords=chords, progress=True)\n", "\n", "# display the result\n", "print(f'Text: {text}')\n", "print(f'Chord progression: {chords}')\n", "display_audio(output, sample_rate=model.compression_model.sample_rate)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we can condition the generation on drum tracks:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Drums-conditional Generation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torchaudio\n", "from audiocraft.utils.notebook import display_audio\n", "\n", "\n", "# load drum prompt\n", "drums_waveform, sr = torchaudio.load(\"./assets/sep_drums_1.mp3\")\n", "\n", "# set textual prompt \n", "text = \"distortion guitars, heavy rock, catchy beat\"\n", "\n", "# run the model\n", "print(\"Generating...\")\n", "output = model.generate_music(\n", " descriptions=[text],\n", " drums_wav=drums_waveform,\n", " drums_sample_rate=sr,\n", " progress=True\n", ")\n", "\n", "# display the result\n", "print('drum prompt:')\n", "display_audio(drums_waveform, sample_rate=sr)\n", "print(f'Text: {text}')\n", "display_audio(output, sample_rate=model.compression_model.sample_rate)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We 
can also combine multiple temporal controls! Let's move on to generating with both chords and drums conditioning:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Drums + Chords conditioning" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torchaudio\n", "from audiocraft.utils.notebook import display_audio\n", "\n", "\n", "# load drum prompt\n", "drums_waveform, sr = torchaudio.load(\"./assets/sep_drums_1.mp3\")\n", "\n", "# set textual prompt \n", "text = \"string quartet, orchestral, dramatic\"\n", "\n", "# define chord progression\n", "chords = [('C', 0.0), ('D', 2.0), ('F', 4.0), ('Ab', 6.0), ('Bb', 7.0), ('C', 8.0)]\n", "\n", "# run the model\n", "print(\"Generating...\")\n", "output = model.generate_music(\n", " descriptions=[text],\n", " drums_wav=drums_waveform,\n", " drums_sample_rate=sr,\n", " chords=chords,\n", " progress=True\n", ")\n", "\n", "# display the result\n", "print('drum prompt:')\n", "display_audio(drums_waveform, sample_rate=sr)\n", "print(f'Chord progression: {chords}')\n", "print(f'Text: {text}')\n", "display_audio(output, sample_rate=model.compression_model.sample_rate)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Melody + Drums + Chords conditioning - inference example" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Source melody:\n" ] }, { "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAABj0AAADQCAYAAABcDaP2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAYDElEQVR4nO3df2jc9f0H8NfFtKm2udTUNVlpioXJanGt2GobHPuhmdkoY8UOHJStlqIwErFmbKOwVTYGEQfqCv5iP3R/WCsddGKZSqkuMoxaUwrVzbKB0GCXRJEmNdC0Np/vH6P37fVnkia5yzuPBxz0Pvfu3evzufe97uzT9+eTy7IsCwAAAAAAgCmuotQFAAAAAAAAjAehBwAAAAAAkAShBwAAAAAAkAShBwAAAAAAkAShBwAAAAAAkAShBwAAAAAAkAShBwAAAAAAkAShBwAAAAAAkITKUhdwtuHh4Thy5EhUV1dHLpcrdTkAAAAAAEAJZVkWx44diwULFkRFxcXXcpRd6HHkyJFoaGgodRkAAAAAAEAZ6e7ujoULF150TNmFHtXV1RHxv+Lz+XyJqwEAAAAAAEppYGAgGhoaCvnBxZRd6HH6lFb5fF7oAQAAAAAARESM6JIYLmQOAAAAAAAkQegBAAAAAAAkQegBAAAAAAAkQegBAAAAAAAkQegBAAAAAAAkQegBAAAAAAAkQegBAAAAAAAkQegBAAAAAAAkQegBAAAAAAAkQegBAAAAAAAkQegBAAAAAAAkQegBAAAAAAAkQegBAAAAAAAkQegBAAAAAAAkQegBAAAAAAAkQegBAAAAAAAkQegBAAAAAAAkQegBAAAAAAAkQegBAAAAAAAkQegBAAAAAAAkQegBAAAAAAAkQegBAAAAAAAkQegBAAAAAAAkQegBAAAAAAAkQegBAAAAAAAkQegBAAAAAAAkQegBAAAAAAAkQegBAAAAAAAkQegBAAAAAAAkQegBAAAAAAAkQegBAAAAAAAk4bJCj4ceeihyuVxs3ry5sO348ePR0tIS8+bNizlz5sS6deuit7f3cusEAAAAAAC4qDGHHvv27Yunn346li1bVrT9gQceiJdeeil27twZHR0dceTIkbjzzjsvu1AAAAAAAICLGVPo8dlnn8X69evj97//fVx99dWF7f39/fHHP/4xHnnkkbjttttixYoV8cwzz8Sbb74Zb7311rgVDQAAAAAAcLYxhR4tLS2xZs2aaGpqKtre1dUVJ0+eLNq+ZMmSWLRoUXR2dl5epQAAAAAAABdROdq/sGPHjti/f3/s27fvnMd6enpi5syZMXfu3KLtdXV10dPTc97nGxoaiqGhocL9gYGB0ZYEAAAAAAAwupUe3d3dcf/998dzzz0Xs2bNGpcC2tvbo6ampnBraGgYl+cFAAAAAACml1GFHl1dXdHX1xc33XRTVFZWRmVlZXR0dMS2bduisrIy6urq4sSJE3H06NGiv9fb2xv19fXnfc4tW7ZEf39/4dbd3T3mnQEAAAAAAKavUZ3e6vbbb4+DBw8Wbdu4cWMsWbIkfv7zn0dDQ0PMmDEj9u7dG+vWrYuIiEOHDsXhw4ejsbHxvM9ZVVUVVVVVYywfAAAAAADgf0YVelRXV8cNN9xQtG327Nkxb968wvZNmzZFW1tb1NbWRj6fj/vuuy8aGxtj9erV41c1AAAAAADAWUZ9IfNLefTRR6OioiLWrVsXQ0ND0dzcHE888cR4vwwAAAAAAECRXJZlWamLONPAwEDU1NREf39/5PP5UpcDAAAAAACU0Ghyg1FdyBwAAAAAAKBcCT0AAAAAAIAkCD0AAAAAAIAkCD0AAAAAAIAkVJa6ACZXLpeb8NfIsuySr32hMUy+iZoTI5kH4/WcAACjNRm/iy+H3z2lNdnzw39DAQClUsrfxRP1+8ZKDwAAAAAAIAlCDwAAAAAAIAlObzXNlHJJtOXY5Wmy3xfzAAAoB36TcDHlMj/KpQ4AIF0p/t6w0gMAAAA
AAEiC0AMAAAAAAEiC0AMAAAAAAEiC0AMAAAAAAEiC0AMAAAAAAEhCZakLSFkulxvV+CzLJqgSSmm082CipDi/JvrYpnjMUlcun7eRML+AlJWyH+uv5aUcv5un0xy5nOM/nY4TAIynifr9c+Z383i+Rorf+VZ6AAAAAAAASRB6AAAAAAAASXB6qwmU4tIgRs88mDiOLWczJwDKg37MaeZCaTn+ADD5JuP713f8xVnpAQAAAAAAJEHoAQAAAAAAJMHprc5yoSvfWzKUvgu99yNljqRnqs2Jy633fCZqHyai1gifQ4CRmqg+fJp+PPX4buZs4zknJmMeTKXfwgCUp/H6LvH9UXpWegAAAAAAAEkQegAAAAAAAElwequzWH40fXnvOdtUmxNTqd6pVCtAivRhzmZOcLapNiemWr0AlB/fJemw0gMAAAAAAEiC0AMAAAAAAEiC0AMAAAAAAEjClLqmRy6XG7fnco62NI12jpgHlMJ49rIzTcR8nkq1AqRIH2YsxjJvzAnKSbnM4YnqwRPlzGMwktrL/XN/sX0o99qB8THZfVhvSYeVHgAAAAAAQBKEHgAAAAAAQBKm1OmtLDHiUswRpoKpNE+nUq0AKdKHGQvzhqmuXOZwudQxFlO59tNS2Afg8ugDjJWVHgAAAAAAQBKEHgAAAAAAQBKEHgAAAAAAQBKEHgAAAAAAQBKEHgAAAAAAQBKEHgAAAAAAQBKEHgAAAAAAQBKEHgAAAAAAQBKEHgAAAAAAQBKEHgAAAAAAQBKEHgAAAAAAQBKEHgAAAAAAQBKEHgAAAAAAQBKEHgAAAAAAQBKEHgAAAAAAQBKEHgAAAAAAQBKEHgAAAAAAQBIqS10AAAAAAMB0kMvlCn/OsqyElUC6rPQAAAAAAACSIPQAAAAAAACSIPQAAAAAAACSMKrQo729PW6++eaorq6O+fPnx9q1a+PQoUNFY44fPx4tLS0xb968mDNnTqxbty56e3vHtWgAAAAAgKkmy7LCDZgYowo9Ojo6oqWlJd56663Ys2dPnDx5Mu64444YHBwsjHnggQfipZdeip07d0ZHR0ccOXIk7rzzznEvHAAAAAAA4Ey57DJixY8//jjmz58fHR0d8bWvfS36+/vjC1/4Qmzfvj2+//3vR0TEBx98ENdff310dnbG6tWrL/mcAwMDUVNTE/39/ZHP58daGgAAAAAAkIDR5AaXdU2P/v7+iIiora2NiIiurq44efJkNDU1FcYsWbIkFi1aFJ2dnZfzUgAAAAAAABdVOda/ODw8HJs3b45bb701brjhhoiI6OnpiZkzZ8bcuXOLxtbV1UVPT895n2doaCiGhoYK9wcGBsZaEgAAAAAAMI2NeaVHS0tLvPfee7Fjx47LKqC9vT1qamoKt4aGhst6PgAAAAAAYHoaU+jR2toau3fvjtdffz0WLlxY2F5fXx8nTpyIo0ePFo3v7e2N+vr68z7Xli1bor+/v3Dr7u4eS0kAAAAAAMA0N6rQI8uyaG1tjV27dsVrr70WixcvLnp8xYoVMWPGjNi7d29h26FDh+Lw4cPR2Nh43uesqqqKfD5fdAMAAAAAABitUV3To6WlJbZv3x4vvvhiVFdXF67TUVNTE1deeWXU1NTEpk2boq2tLWprayOfz8d9990XjY2NsXr16gnZAQAAAAAAgIiIXJZl2YgH53Ln3f7MM8/E3XffHRERx48fj5/85Cfx/PPPx9DQUDQ3N8cTTzxxwdNbnW1gYCBqamqiv7/fqg8AAAAAAJjmRpMbjCr0mAxCDwAAAAAA4LTR5AZjupA5AAAAAABAuRnVNT0mU01NzTnbSrko5UKn9rqYyax3pPWV2cKeMbnYvo52/858rhSOTTm50PvkOMPkGcl3w3h+Ji+np47le3aijKT2yT62AOfjtyxQjkb7u24q9y99GKA8WekBAAAAAAAkQegBAAAAAAAkoWxPb1VuFzIv92WK5V7feBrPfZ1Ox22yObZQepP9Obyc15t
qPWOq1QukSS8CytF06k3TaV8BphIrPQAAAAAAgCQIPQAAAAAAgCQIPQAAAAAAgCQIPQAAAAAAgCQIPQAAAAAAgCRUlroAAAAARiaXyxX+nGXZJbczNZz5/p3pQu/xSJkLcK6J6pcj+RxTWiPpo1Pt/Upxn2A8WOkBAAAAAAAkQegBAAAAAAAkwemtAAAApogLnaLCqSumtpG8f95jGB8T9VnyGS1/Kb5HKe4TjAcrPQAAAAAAgCQIPQAAAAAAgCQIPQAAAAAAgCQIPQAAAAAAgCQIPQAAAAAAgCQIPQAAAAAAgCQIPQAAAAAAgCQIPQAAAAAAgCQIPQAAAAAAgCQIPQAAAAAAgCQIPQAAAAAAgCQIPQAAAAAAgCQIPQAAAAAAgCQIPQAAAAAAgCQIPQAAAAAAgCQIPQAAAAAAgCQIPQAAAAAAgCQIPQAAAAAAgCRUlroAAADSlMvlRjQuy7IJrgQAAIDpwkoPAAAAAAAgCUIPAAAAAAAgCU5vBQDAhHDaKgAAACablR4AAAAAAEAShB4AAAAAAEASnN6KcZfL5c67PfVTXJT7fpd7fSlyzBmLC82bsZiouWZuAwAAAJdrJP8GMpZ/a7DSAwAAAAAASILQAwAAAAAASILTWzHupuvpTcp9v8u9vhQ55ozFVJg3U6FGAAAAoLxN1L8vWOkBAAAAAAAkQegBAAAAAAAkQegBAAAAAAAkQegBAAAAAAAkQegBAAAAAAAkQegBAAAAAAAkobLUBQAAAMBkyeVylxyTZdkkVDJ9jOSYn8nxh/Mb7WdpMvi8AuXISg8AAAAAACAJQg8AAAAAACAJExZ6PP7443HttdfGrFmzYtWqVfHOO+9M1EsBAADAiGRZdskb42skx9zxh0sb7WdpMm4A5WhCQo8XXngh2tra4sEHH4z9+/fH8uXLo7m5Ofr6+ibi5QAAAAAAACYm9HjkkUfinnvuiY0bN8bSpUvjqaeeiquuuir+9Kc/TcTLAQAAAAAAjH/oceLEiejq6oqmpqb/f5GKimhqaorOzs7xfjkAAAAAAICIiKgc7yf85JNP4tSpU1FXV1e0va6uLj744INzxg8NDcXQ0FDh/sDAwHiXBAAAAAAATAPjHnqMVnt7e/zqV786Z7vwAwAAAAAAOJ0XZFl2ybHjHnpcc801ccUVV0Rvb2/R9t7e3qivrz9n/JYtW6Ktra1w/6OPPoqlS5dGQ0PDeJcGAAAAAABMUceOHYuampqLjhn30GPmzJmxYsWK2Lt3b6xduzYiIoaHh2Pv3r3R2tp6zviqqqqoqqoq3J8zZ050d3dHlmWxaNGi6O7ujnw+P95lAky6gYGBaGho0NeAZOhrQGr0NSA1+hqQiizL4tixY7FgwYJLjp2Q01u1tbXFhg0bYuXKlXHLLbfEY489FoODg7Fx48ZL/t2KiopYuHBhYblKPp/XlIGk6GtAavQ1IDX6GpAafQ1IwaVWeJw2IaHHXXfdFR9//HFs3bo1enp64sYbb4xXXnnlnIubAwAAAAAAjJcJu5B5a2vreU9nBQAAAAAAMBEqSl3AhVRVVcWDDz5YdL0PgKlMXwNSo68BqdHXgNToa8B0lMuyLCt1EQAAAAAAAJerbFd6AAAAAAAAjIbQAwAAAAAASILQAwAAAAAASILQAwAAAAAASEJZhh6PP/54XHvttTFr1qxYtWpVvPPOO6UuCeC83njjjfjud78bCxYsiFwuF3/961+LHs+yLLZu3Rpf/OIX48orr4ympqb497//XTTm008/jfXr10c+n4+5c+fGpk2b4rPPPpvEvQD4f+3t7XHzzTdHdXV1zJ8/P9auXRuHDh0qGnP8+PFoaWmJefPmxZw5c2LdunXR29tbNObw4cOxZs2auOqqq2L+/Pnx05/+ND7//PPJ3BWAiIh48sknY9myZZHP5yOfz0djY2O8/PLLhcf1NGAqe+ihhyKXy8XmzZsL2/Q1YLoru9DjhRdeiLa2tnjwwQdj//79sXz58mh
ubo6+vr5SlwZwjsHBwVi+fHk8/vjj53384Ycfjm3btsVTTz0Vb7/9dsyePTuam5vj+PHjhTHr16+P999/P/bs2RO7d++ON954I+69997J2gWAIh0dHdHS0hJvvfVW7NmzJ06ePBl33HFHDA4OFsY88MAD8dJLL8XOnTujo6Mjjhw5EnfeeWfh8VOnTsWaNWvixIkT8eabb8af//znePbZZ2Pr1q2l2CVgmlu4cGE89NBD0dXVFe+++27cdttt8b3vfS/ef//9iNDTgKlr37598fTTT8eyZcuKtutrwLSXlZlbbrkla2lpKdw/depUtmDBgqy9vb2EVQFcWkRku3btKtwfHh7O6uvrs9/+9reFbUePHs2qqqqy559/PsuyLPvnP/+ZRUS2b9++wpiXX345y+Vy2UcffTRptQNcSF9fXxYRWUdHR5Zl/+tjM2bMyHbu3FkY869//SuLiKyzszPLsiz729/+llVUVGQ9PT2FMU8++WSWz+ezoaGhyd0BgPO4+uqrsz/84Q96GjBlHTt2LLvuuuuyPXv2ZF//+tez+++/P8syv9UAsizLymqlx4kTJ6KrqyuampoK2yoqKqKpqSk6OztLWBnA6H344YfR09NT1NNqampi1apVhZ7W2dkZc+fOjZUrVxbGNDU1RUVFRbz99tuTXjPA2fr7+yMiora2NiIiurq64uTJk0W9bcmSJbFo0aKi3vaVr3wl6urqCmOam5tjYGCg8H9WA5TCqVOnYseOHTE4OBiNjY16GjBltbS0xJo1a4r6V4TfagAREZWlLuBMn3zySZw6daqo6UZE1NXVxQcffFCiqgDGpqenJyLivD3t9GM9PT0xf/78oscrKyujtra2MAagVIaHh2Pz5s1x6623xg033BAR/+tbM2fOjLlz5xaNPbu3na/3nX4MYLIdPHgwGhsb4/jx4zFnzpzYtWtXLF26NA4cOKCnAVPOjh07Yv/+/bFv375zHvNbDaDMQg8AAMpHS0tLvPfee/GPf/yj1KUAXJYvf/nLceDAgejv74+//OUvsWHDhujo6Ch1WQCj1t3dHffff3/s2bMnZs2aVepyAMpSWZ3e6pprrokrrrgient7i7b39vZGfX19iaoCGJvTfetiPa2+vj76+vqKHv/888/j008/1feAkmptbY3du3fH66+/HgsXLixsr6+vjxMnTsTRo0eLxp/d287X+04/BjDZZs6cGV/60pdixYoV0d7eHsuXL4/f/e53ehow5XR1dUVfX1/cdNNNUVlZGZWVldHR0RHbtm2LysrKqKur09eAaa+sQo+ZM2fGihUrYu/evYVtw8PDsXfv3mhsbCxhZQCjt3jx4qivry/qaQMDA/H2228XelpjY2McPXo0urq6CmNee+21GB4ejlWrVk16zQBZlkVra2vs2rUrXnvttVi8eHHR4ytWrIgZM2YU9bZDhw7F4cOHi3rbwYMHi0LdPXv2RD6fj6VLl07OjgBcxPDwcAwNDelpwJRz++23x8GDB+PAgQOF28qVK2P9+vWFP+trwHRXdqe3amtriw0bNsTKlSvjlltuicceeywGBwdj48aNpS4N4ByfffZZ/Oc//ync//DDD+PAgQNRW1sbixYtis2bN8dvfvObuO6662Lx4sXxy1/+MhYsWBBr166NiIjrr78+vv3tb8c999wTTz31VJw8eTJaW1vjBz/4QSxYsKBEewVMZy0tLbF9+/Z48cUXo7q6unBe55qamrjyyiujpqYmNm3aFG1tbVFbWxv5fD7uu+++aGxsjNWrV0dExB133BFLly6NH/7wh/Hwww9HT09P/OIXv4iWlpaoqqoq5e4B09CWLVviO9/5TixatCiOHTsW27dvj7///e/x6quv6mnAlFNdXV241tpps2fPjnnz5hW262vAdFd2ocddd90VH3/8cWzdujV6enrixhtvjFdeeeWcCywBlIN33303vvnNbxbut7W1RUTEhg0b4tlnn42f/exnMTg4GPfee28
cPXo0vvrVr8Yrr7xSdO7V5557LlpbW+P222+PioqKWLduXWzbtm3S9wUgIuLJJ5+MiIhvfOMbRdufeeaZuPvuuyMi4tFHHy30q6GhoWhubo4nnniiMPaKK66I3bt3x49//ONobGyM2bNnx4YNG+LXv/71ZO0GQEFfX1/86Ec/iv/+979RU1MTy5Yti1dffTW+9a1vRYSeBqRHXwOmu1yWZVmpiwAAAAAAALhcZXVNDwAAAAAAgLESegAAAAAAAEkQegAAAAAAAEkQegAAAAAAAEkQegAAAAAAAEkQegAAAAAAAEkQegAAAAAAAEkQegAAAAAAAEkQegAAAAAAAEkQegAAAAAAAEkQegAAAAAAAEkQegAAAAAAAEn4P8oFZFHDCXiRAAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Chords:\n", "[('N', 0.0), ('C', 0.32), ('Dm7', 3.456), ('Am', 4.608), ('F', 8.32), ('C', 9.216)]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Separated drums:\n" ] }, { "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Generating...\n" ] }, { "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "import torchaudio \n", "from audiocraft.models import JASCO\n", "from demucs import pretrained\n", "from demucs.apply import apply_model\n", "from demucs.audio import convert_audio\n", "import torch\n", "from audiocraft.utils.notebook import display_audio\n", "import matplotlib.pyplot as plt\n", "\n", "# --------------------------\n", "# First, choose file to load\n", "# --------------------------\n", "fnames = ['salience_1', 'salience_2']\n", "chords = [\n", " [('N', 0.0), ('Eb7', 1.088000000), ('C#', 4.352000000), ('D', 4.864000000), ('Dm7', 6.720000000), ('G7', 8.256000000), ('Am7b5/G', 9.152000000)], # for salience 1\n", " [('N', 0.0), ('C', 0.320000000), ('Dm7', 3.456000000), ('Am', 4.608000000), ('F', 8.320000000), ('C', 9.216000000)] # for salience 2\n", "]\n", "file_idx = 1 # either 0 or 1\n", "\n", "\n", "# ------------------------------------\n", "# display audio, melody map and chords\n", "# ------------------------------------\n", "def plot_chromagram(tensor):\n", " # Check if tensor is a PyTorch tensor\n", " if not torch.is_tensor(tensor):\n", " raise ValueError('Input 
should be a PyTorch tensor')\n", " tensor = tensor.numpy().T # C, T\n", " plt.figure(figsize=(20, 20))\n", " plt.imshow(tensor, cmap='binary', interpolation='nearest', origin='lower')\n", " plt.show()\n", "\n", "# load salience and display the corresponding wav\n", "melody_prompt_wav, melody_prompt_sr = torchaudio.load(f\"./assets/{fnames[file_idx]}.wav\")\n", "print(\"Source melody:\")\n", "display_audio(melody_prompt_wav, sample_rate=melody_prompt_sr)\n", "melody = torch.load(f\"./assets/{fnames[file_idx]}.th\", weights_only=True)\n", "plot_chromagram(melody)\n", "print(\"Chords:\")\n", "print(chords[file_idx])\n", "\n", "# --------------------------------------------------\n", "# use demucs to separate the drums stem from src mix\n", "# --------------------------------------------------\n", "def _get_drums_stem(wav: torch.Tensor, sample_rate: int) -> torch.Tensor:\n", " \"\"\"Extract the drums stem from the wav by separating the mix into stems with demucs.\"\"\"\n", " demucs_model = pretrained.get_model('htdemucs').to('cuda')\n", " wav = convert_audio(\n", " wav, sample_rate, demucs_model.samplerate, demucs_model.audio_channels) # type: ignore\n", " stems = apply_model(demucs_model, wav.cuda().unsqueeze(0), device='cuda').squeeze(0)\n", " drum_stem = stems[demucs_model.sources.index('drums')] # keep only the drums stem for drums conditioning\n", " return convert_audio(drum_stem.cpu(), demucs_model.samplerate, sample_rate, 1) # type: ignore\n", "drums_wav = _get_drums_stem(melody_prompt_wav, melody_prompt_sr)\n", "print(\"Separated drums:\")\n", "display_audio(drums_wav, sample_rate=melody_prompt_sr)\n", "\n", "# ----------------------------------\n", "# Generate using the loaded controls\n", "# ----------------------------------\n", "# arbitrary free-form text prompts\n", "texts = [\n", " '90s rock with heavy drums and hammond',\n", " '80s pop with groovy synth bass and drum machine',\n", " 'folk song with leading accordion',\n", "]\n", "\n", 
"print(\"Generating...\")\n", "# replace the dynamic solver with a simple Euler solver\n", "model.set_generation_params(cfg_coef_all=1.5, cfg_coef_txt=2.5, euler=True, euler_steps=50) # fixed-step Euler solver, 50 steps\n", "output = model.generate_music(\n", " descriptions=texts,\n", " chords=chords[file_idx],\n", " drums_wav=drums_wav,\n", " drums_sample_rate=melody_prompt_sr,\n", " melody_salience_matrix=melody.permute(1, 0),\n", " progress=True\n", ")\n", "display_audio(output, sample_rate=model.compression_model.sample_rate)" ] } ], "metadata": { "kernelspec": { "display_name": "jasco_dev", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" } }, "nbformat": 4, "nbformat_minor": 2 }