110 lines
3.6 KiB
Python
110 lines
3.6 KiB
Python
"""SSH server implementation for Terminal Chat."""
|
|
|
|
import asyncio
|
|
import asyncssh
|
|
import os
|
|
import sys
|
|
from typing import Optional
|
|
|
|
from .models import UserConnection
|
|
from .chat_server import ChatServer
|
|
from simple_client import SimpleTextChatClient
|
|
|
|
|
|
class SSHChatServer(asyncssh.SSHServer):
    """SSH server that handles authentication for the chat.

    Authentication policy (demo-grade): every username must authenticate, but
    any password is accepted and public-key auth is disabled. In practice,
    anyone who can reach the port can join under any name. Do not expose
    this as-is on an untrusted network.
    """

    def __init__(self, chat_server: ChatServer):
        """Remember the shared chat server.

        Args:
            chat_server: The ChatServer instance this SSH server fronts.
        """
        super().__init__()
        self.chat_server = chat_server

    def begin_auth(self, username: str) -> bool:
        """Require authentication for every username.

        asyncssh treats a True return as "authentication is required". A
        False return would admit the client with no auth exchange at all.
        Because validate_password() accepts any password, every username is
        still admitted once it completes the password step.
        """
        return True

    def password_auth_supported(self) -> bool:
        """Enable password authentication."""
        return True

    def validate_password(self, username: str, password: str) -> bool:
        """Accept any password for any user (demo only).

        In production this must check the password against real credentials.
        """
        return True

    def public_key_auth_supported(self) -> bool:
        """Disable public key auth for simplicity."""
        return False
|
|
|
|
|
|
async def handle_client(process: asyncssh.SSHServerProcess, chat_server: ChatServer):
    """Handle one SSH session: register the user and run the chat client.

    The SSH login name is used as the chat nickname. If that name is already
    present in ``chat_server.users``, a numeric suffix is appended ("alice1",
    "alice2", ...) until the name is free.

    The user is removed from ``chat_server`` when the session ends. The SSH
    channel is closed with exit status 0 on a normal finish, or 1 if the chat
    client raised an exception.

    Args:
        process: The asyncssh server process for this session.
        chat_server: The shared ChatServer the user is registered with.
    """
    username = process.channel.get_connection().get_extra_info('username')

    # Pick a free nickname. There is no await between this check and
    # add_user() below, so no other session can claim the name in between.
    base_username = username
    counter = 1
    while username in chat_server.users:
        username = f"{base_username}{counter}"
        counter += 1

    # Create message queue and user connection
    message_queue = asyncio.Queue()
    user_conn = UserConnection(username, message_queue)
    chat_server.add_user(username, user_conn)

    exit_status = 0
    try:
        # Create and run the simple text chat client
        client = SimpleTextChatClient(username, message_queue, chat_server, process)
        await client.run()
    except Exception as e:
        # Report failure to the SSH client instead of claiming success.
        exit_status = 1
        try:
            process.stderr.write(f"Error: {e}\n")
        except OSError:
            # The channel is already gone (e.g. BrokenPipeError); there is
            # nobody left to report to, so don't mask the original error.
            pass
    finally:
        # Clean up
        if username in chat_server.users:
            chat_server.remove_user(username)
        process.exit(exit_status)
|
|
|
|
|
|
async def start_ssh_server(host: str, port: int, chat_server: ChatServer,
                           host_key_path: str = 'ssh_host_key'):
    """Start the SSH server for the chat application and serve until cancelled.

    Args:
        host: The host address to bind to
        port: The port number to listen on
        chat_server: The ChatServer instance to use
        host_key_path: Path of the server's private host key. An RSA key is
            generated there (owner-only permissions) if the file is missing.
            Defaults to ``ssh_host_key`` relative to the working directory.

    Raises:
        OSError, asyncssh.Error: If the listener cannot be started. These are
            printed and then re-raised.
    """
    # Generate host key if it doesn't exist
    if not os.path.exists(host_key_path):
        print(f"Generating SSH host key at {host_key_path}...")
        key = asyncssh.generate_private_key('ssh-rsa')
        # Create the file atomically with mode 0600. key.write_private_key()
        # uses a plain open(), so the private key's permissions would follow
        # the process umask and could leave it readable by other users.
        flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL | getattr(os, 'O_BINARY', 0)
        try:
            fd = os.open(host_key_path, flags, 0o600)
        except FileExistsError:
            # Created by someone else since the exists() check; use that key.
            pass
        else:
            with os.fdopen(fd, 'wb') as key_file:
                key_file.write(key.export_private_key())
            print("Host key generated successfully")

    async def process_factory(process: asyncssh.SSHServerProcess):
        """Run the chat session for one incoming SSH process."""
        await handle_client(process, chat_server)

    print(f"Starting SSH server on {host}:{port}...")
    print(f"Users can connect with: ssh -p {port} <username>@{host}")
    print("Password: any password will work (press Enter for empty password)")
    print("Press Ctrl+C to stop the server")
    print("")

    # Only listen() can fail to start the server; keep the try that narrow.
    try:
        server = await asyncssh.listen(
            host,
            port,
            server_factory=lambda: SSHChatServer(chat_server),
            server_host_keys=[host_key_path],
            process_factory=process_factory,
            encoding='utf-8',
        )
    except (OSError, asyncssh.Error) as e:
        print(f"Error starting SSH server: {e}")
        raise

    try:
        # Keep the server running until this coroutine is cancelled.
        await asyncio.Event().wait()
    finally:
        # Stop accepting new connections on shutdown.
        server.close()
|