community/mlx_pipeline: fix crash at mlx call #29915

Merged · 3 commits · Feb 21, 2025
15 changes: 12 additions & 3 deletions libs/community/langchain_community/chat_models/mlx.py
@@ -160,6 +160,7 @@ def _stream(

try:
import mlx.core as mx
from mlx_lm.sample_utils import make_logits_processors, make_sampler
from mlx_lm.utils import generate_step

except ImportError:
@@ -176,20 +177,28 @@ def _stream(
repetition_context_size: Optional[int] = model_kwargs.get(
"repetition_context_size", None
)
top_p: float = model_kwargs.get("top_p", 1.0)
min_p: float = model_kwargs.get("min_p", 0.0)
min_tokens_to_keep: int = model_kwargs.get("min_tokens_to_keep", 1)

llm_input = self._to_chat_prompt(messages, tokenize=True, return_tensors="np")

prompt_tokens = mx.array(llm_input[0])

eos_token_id = self.tokenizer.eos_token_id

sampler = make_sampler(temp or 0.0, top_p, min_p, min_tokens_to_keep)

logits_processors = make_logits_processors(
None, repetition_penalty, repetition_context_size
)

for (token, prob), n in zip(
generate_step(
prompt_tokens,
self.llm.model,
temp=temp,
repetition_penalty=repetition_penalty,
repetition_context_size=repetition_context_size,
sampler=sampler,
logits_processors=logits_processors,
),
range(max_new_tokens),
):
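The chat model's streaming path now builds a sampler with `make_sampler` and a list of logits processors with `make_logits_processors`, instead of passing `temp`, `repetition_penalty`, and `repetition_context_size` straight to `generate_step`, which newer mlx_lm releases appear to no longer accept (the crash this PR fixes). A hedged usage sketch of the patched path through LangChain's `ChatMLX` wrapper; the model id and sampling values are illustrative only:

```python
# Illustrative only: exercises the patched streaming path via ChatMLX.
# The model id and sampling values are examples, not requirements.
from langchain_community.chat_models.mlx import ChatMLX
from langchain_community.llms.mlx_pipeline import MLXPipeline
from langchain_core.messages import HumanMessage

llm = MLXPipeline.from_model_id(
    model_id="mlx-community/quantized-gemma-2b",
    pipeline_kwargs={
        "temp": 0.7,
        "top_p": 0.95,
        "min_p": 0.0,
        "min_tokens_to_keep": 1,
        "repetition_penalty": 1.1,
        "repetition_context_size": 64,
    },
)
chat = ChatMLX(llm=llm)

# The chat model falls back to the pipeline's kwargs as its model_kwargs,
# which the patch now routes through make_sampler / make_logits_processors.
for chunk in chat.stream([HumanMessage(content="Say foo:")]):
    print(chunk.content, end="", flush=True)
```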
33 changes: 24 additions & 9 deletions libs/community/langchain_community/llms/mlx_pipeline.py
@@ -101,7 +101,9 @@ def from_model_id(

tokenizer_config = tokenizer_config or {}
if adapter_file:
model, tokenizer = load(model_id, tokenizer_config, adapter_file, lazy)
model, tokenizer = load(
model_id, tokenizer_config, adapter_path=adapter_file, lazy=lazy
)
else:
model, tokenizer = load(model_id, tokenizer_config, lazy=lazy)

@@ -141,6 +143,7 @@ def _call(
) -> str:
try:
from mlx_lm import generate
from mlx_lm.sample_utils import make_logits_processors, make_sampler

except ImportError:
raise ImportError(
@@ -161,18 +164,23 @@
"repetition_context_size", None
)
top_p: float = pipeline_kwargs.get("top_p", 1.0)
min_p: float = pipeline_kwargs.get("min_p", 0.0)
min_tokens_to_keep: int = pipeline_kwargs.get("min_tokens_to_keep", 1)

sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep)
logits_processors = make_logits_processors(
None, repetition_penalty, repetition_context_size
)

return generate(
model=self.model,
tokenizer=self.tokenizer,
prompt=prompt,
temp=temp,
max_tokens=max_tokens,
verbose=verbose,
formatter=formatter,
repetition_penalty=repetition_penalty,
repetition_context_size=repetition_context_size,
top_p=top_p,
sampler=sampler,
logits_processors=logits_processors,
)

def _stream(
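The non-streaming `_call` path gets the same treatment: sampling and repetition settings are wrapped into `sampler` and `logits_processors` objects and handed to `mlx_lm.generate`, rather than being passed as individual keyword arguments. A minimal standalone sketch of that call pattern, assuming an mlx_lm version whose `generate` accepts these two arguments (the model id is only an example):

```python
# Minimal sketch of the call pattern _call moves to; assumes an mlx_lm
# release whose generate() accepts sampler / logits_processors.
from mlx_lm import generate, load
from mlx_lm.sample_utils import make_logits_processors, make_sampler

model, tokenizer = load("mlx-community/quantized-gemma-2b")  # example model

# Positional order mirrors the patch:
# make_sampler(temp, top_p, min_p, min_tokens_to_keep)
sampler = make_sampler(0.8, 0.95, 0.0, 1)
# make_logits_processors(logit_bias, repetition_penalty, repetition_context_size);
# the first argument is a logit-bias mapping (None here).
logits_processors = make_logits_processors(None, 1.1, 64)

text = generate(
    model=model,
    tokenizer=tokenizer,
    prompt="Say foo:",
    max_tokens=10,
    verbose=False,
    sampler=sampler,
    logits_processors=logits_processors,
)
print(text)
```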
@@ -184,6 +192,7 @@ ) -> Iterator[GenerationChunk]:
) -> Iterator[GenerationChunk]:
try:
import mlx.core as mx
from mlx_lm.sample_utils import make_logits_processors, make_sampler
from mlx_lm.utils import generate_step

except ImportError:
@@ -203,6 +212,8 @@
"repetition_context_size", None
)
top_p: float = pipeline_kwargs.get("top_p", 1.0)
min_p: float = pipeline_kwargs.get("min_p", 0.0)
min_tokens_to_keep: int = pipeline_kwargs.get("min_tokens_to_keep", 1)

prompt = self.tokenizer.encode(prompt, return_tensors="np")

@@ -212,14 +223,18 @@
detokenizer = self.tokenizer.detokenizer
detokenizer.reset()

sampler = make_sampler(temp or 0.0, top_p, min_p, min_tokens_to_keep)

logits_processors = make_logits_processors(
None, repetition_penalty, repetition_context_size
)

for (token, prob), n in zip(
generate_step(
prompt=prompt_tokens,
model=self.model,
temp=temp,
repetition_penalty=repetition_penalty,
repetition_context_size=repetition_context_size,
top_p=top_p,
sampler=sampler,
logits_processors=logits_processors,
),
range(max_new_tokens),
):
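`MLXPipeline._stream` mirrors the chat model fix: the old `temp`, `repetition_penalty`, `repetition_context_size`, and `top_p` keyword arguments to `generate_step` are replaced by the `sampler` and `logits_processors` built above. A hedged end-to-end sketch through the LangChain wrapper, with illustrative values only:

```python
# Illustrative only: streams through the patched MLXPipeline._stream.
from langchain_community.llms.mlx_pipeline import MLXPipeline

llm = MLXPipeline.from_model_id(
    model_id="mlx-community/quantized-gemma-2b",  # example model
    pipeline_kwargs={
        "temp": 0.7,
        "top_p": 0.95,
        "min_p": 0.0,
        "min_tokens_to_keep": 1,
        "repetition_penalty": 1.1,
        "repetition_context_size": 64,
    },
)

# For LLM-type models, .stream() yields plain string chunks.
for chunk in llm.stream("Say foo:"):
    print(chunk, end="", flush=True)
```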
23 changes: 23 additions & 0 deletions libs/community/tests/integration_tests/llms/test_mlx_pipeline.py
@@ -1,8 +1,11 @@
"""Test MLX Pipeline wrapper."""

import pytest

from langchain_community.llms.mlx_pipeline import MLXPipeline


@pytest.mark.requires("mlx_lm")
def test_mlx_pipeline_text_generation() -> None:
"""Test valid call to MLX text generation model."""
llm = MLXPipeline.from_model_id(
@@ -13,6 +16,7 @@ def test_mlx_pipeline_text_generation() -> None:
assert isinstance(output, str)


@pytest.mark.requires("mlx_lm")
def test_init_with_model_and_tokenizer() -> None:
"""Test initialization with a HF pipeline."""
from mlx_lm import load
@@ -23,6 +27,7 @@ def test_init_with_model_and_tokenizer() -> None:
assert isinstance(output, str)


@pytest.mark.requires("mlx_lm")
def test_huggingface_pipeline_runtime_kwargs() -> None:
"""Test pipelines specifying the device map parameter."""
llm = MLXPipeline.from_model_id(
@@ -31,3 +36,21 @@ def test_huggingface_pipeline_runtime_kwargs() -> None:
prompt = "Say foo:"
output = llm.invoke(prompt, pipeline_kwargs={"max_tokens": 2})
assert len(output) < 10


@pytest.mark.requires("mlx_lm")
def test_mlx_pipeline_with_params() -> None:
"""Test valid call to MLX text generation model."""
llm = MLXPipeline.from_model_id(
model_id="mlx-community/quantized-gemma-2b",
pipeline_kwargs={
"max_tokens": 10,
"temp": 0.8,
"verbose": False,
"repetition_penalty": 1.1,
"repetition_context_size": 64,
"top_p": 0.95,
},
)
output = llm.invoke("Say foo:")
assert isinstance(output, str)
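A companion test for the newly wired `min_p` / `min_tokens_to_keep` keys could follow the same pattern; the sketch below is hypothetical and not part of this PR:

```python
# Hypothetical companion test (not in this PR), mirroring the tests above.
import pytest

from langchain_community.llms.mlx_pipeline import MLXPipeline


@pytest.mark.requires("mlx_lm")
def test_mlx_pipeline_with_min_p_params() -> None:
    """Test MLX text generation with min-p sampling parameters."""
    llm = MLXPipeline.from_model_id(
        model_id="mlx-community/quantized-gemma-2b",
        pipeline_kwargs={
            "max_tokens": 10,
            "temp": 0.8,
            "min_p": 0.05,
            "min_tokens_to_keep": 1,
        },
    )
    output = llm.invoke("Say foo:")
    assert isinstance(output, str)
```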